diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c3c1280..85b74fe 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -49,10 +49,10 @@ for opt, arg in opts: ``` ```python -if rcode == 0: - rcode, message = self.check_start_time(start_time, block_height) -if rcode == 0: - rcode, message = self.check_end_time(end_time, start_time, block_height) +if appointment_data is None: + raise InspectionFailed(errors.APPOINTMENT_EMPTY_FIELD, "empty appointment received") +elif not isinstance(appointment_data, dict): + raise InspectionFailed(errors.APPOINTMENT_WRONG_FIELD, "wrong appointment format") ``` ## Dev Requirements diff --git a/cli/README.md b/cli/README.md index 266e0fc..b04f9df 100644 --- a/cli/README.md +++ b/cli/README.md @@ -44,8 +44,6 @@ This command is used to send appointments to the watchtower. Appointments **must { "tx": tx, "tx_id": tx_id, - "start_time": s, - "end_time": e, "to_self_delay": d } `tx` **must** be the raw penalty transaction that will be encrypted before sent to the watchtower. `type(tx) = hex encoded str` @@ -60,12 +58,6 @@ This command is used to send appointments to the watchtower. Appointments **must The API will return a `application/json` HTTP response code `200/OK` if the appointment is accepted, with the locator encoded in the response text, or a `400/Bad Request` if the appointment is rejected, with the rejection reason encoded in the response text. -### Alpha release restrictions -The alpha release does not have authentication, payments nor rate limiting, therefore some self imposed restrictions apply: - -- `start_time` should be within the next 6 blocks `[current_time+1, current_time+6]`. -- `end_time` cannot be bigger than (roughly) a month. That is `4320` blocks on top of `start_time`. - #### Usage @@ -103,9 +95,7 @@ if `-f, --file` **is** specified, then the command expects a path to a json file "appointment": { "encrypted_blob": eb, - "end_time": e, "locator": appointment_locator, - "start_time": s, "status": "being_watched", "to_self_delay": d } @@ -118,7 +108,6 @@ if `-f, --file` **is** specified, then the command expects a path to a json file "status": "dispute_responded", "appointment": { - "appointment_end": e, "dispute_txid": dispute_txid, "locator": appointment_locator, "penalty_rawtx": penalty_rawtx, @@ -164,10 +153,10 @@ python teos_cli.py register 2. Generate a new dummy appointment. **Note:** this appointment will never be fulfilled (it will eventually expire) since it does not correspond to a valid transaction. However it can be used to interact with the Eye of Satoshi's API. ``` - echo '{"tx": "4615a58815475ab8145b6bb90b1268a0dbb02e344ddd483f45052bec1f15b1951c1ee7f070a0993da395a5ee92ea3a1c184b5ffdb2507164bf1f8c1364155d48bdbc882eee0868ca69864a807f213f538990ad16f56d7dfb28a18e69e3f31ae9adad229e3244073b7d643b4597ec88bf247b9f73f301b0f25ae8207b02b7709c271da98af19f1db276ac48ba64f099644af1ae2c90edb7def5e8589a1bb17cc72ac42ecf07dd29cff91823938fd0d772c2c92b7ab050f8837efd46197c9b2b3f", "tx_id": "0b9510d92a50c1d67c6f7fc5d47908d96b3eccdea093d89bcbaf05bcfebdd951", "start_time": 0, "end_time": 0, "to_self_delay": 20}' > dummy_appointment_data.json + echo '{"tx": "4615a58815475ab8145b6bb90b1268a0dbb02e344ddd483f45052bec1f15b1951c1ee7f070a0993da395a5ee92ea3a1c184b5ffdb2507164bf1f8c1364155d48bdbc882eee0868ca69864a807f213f538990ad16f56d7dfb28a18e69e3f31ae9adad229e3244073b7d643b4597ec88bf247b9f73f301b0f25ae8207b02b7709c271da98af19f1db276ac48ba64f099644af1ae2c90edb7def5e8589a1bb17cc72ac42ecf07dd29cff91823938fd0d772c2c92b7ab050f8837efd46197c9b2b3f", "tx_id": "0b9510d92a50c1d67c6f7fc5d47908d96b3eccdea093d89bcbaf05bcfebdd951", "to_self_delay": 20}' > dummy_appointment_data.json ``` - That will create a json file that follows the appointment data structure filled with dummy data and store it in `dummy_appointment_data.json`. **Note**: You'll need to update the `start_time` and `end_time` to match valid block heights. + That will create a json file that follows the appointment data structure filled with dummy data and store it in `dummy_appointment_data.json`. 3. Send the appointment to the tower API. Which will then start monitoring for matching transactions. diff --git a/cli/__init__.py b/cli/__init__.py index 9316550..5e3f3ac 100644 --- a/cli/__init__.py +++ b/cli/__init__.py @@ -10,7 +10,6 @@ DEFAULT_CONF = { "API_PORT": {"value": 9814, "type": int}, "LOG_FILE": {"value": "teos_cli.log", "type": str, "path": True}, "APPOINTMENTS_FOLDER_NAME": {"value": "appointment_receipts", "type": str, "path": True}, - "CLI_PUBLIC_KEY": {"value": "cli_pk.der", "type": str, "path": True}, "CLI_PRIVATE_KEY": {"value": "cli_sk.der", "type": str, "path": True}, "TEOS_PUBLIC_KEY": {"value": "teos_pk.der", "type": str, "path": True}, } diff --git a/cli/exceptions.py b/cli/exceptions.py index 7498ad9..fd0d395 100644 --- a/cli/exceptions.py +++ b/cli/exceptions.py @@ -1,22 +1,9 @@ -class InvalidParameter(ValueError): - """Raised when a command line parameter is invalid (either missing or wrong)""" - - def __init__(self, msg, **kwargs): - self.reason = msg - self.kwargs = kwargs +from common.exceptions import BasicException -class InvalidKey(Exception): - """Raised when there is an error loading the keys""" - - def __init__(self, msg, **kwargs): - self.reason = msg - self.kwargs = kwargs - - -class TowerResponseError(Exception): +class TowerConnectionError(BasicException): """Raised when the tower responds with an error""" - def __init__(self, msg, **kwargs): - self.reason = msg - self.kwargs = kwargs + +class TowerResponseError(BasicException): + """Raised when the tower responds with an error""" diff --git a/cli/teos_cli.py b/cli/teos_cli.py index fd56dc7..ad6e29e 100644 --- a/cli/teos_cli.py +++ b/cli/teos_cli.py @@ -6,52 +6,50 @@ import requests from sys import argv from uuid import uuid4 from coincurve import PublicKey -from requests import Timeout, ConnectionError from getopt import getopt, GetoptError +from requests import Timeout, ConnectionError from requests.exceptions import MissingSchema, InvalidSchema, InvalidURL +from cli.exceptions import TowerResponseError from cli import DEFAULT_CONF, DATA_DIR, CONF_FILE_NAME, LOG_PREFIX -from cli.exceptions import InvalidKey, InvalidParameter, TowerResponseError from cli.help import show_usage, help_add_appointment, help_get_appointment, help_register, help_get_all_appointments -import common.cryptographer -from common.blob import Blob from common import constants from common.logger import Logger from common.appointment import Appointment from common.config_loader import ConfigLoader from common.cryptographer import Cryptographer from common.tools import setup_logging, setup_data_folder +from common.exceptions import InvalidKey, InvalidParameter, SignatureError from common.tools import is_256b_hex_str, is_locator, compute_locator, is_compressed_pk logger = Logger(actor="Client", log_name_prefix=LOG_PREFIX) -common.cryptographer.logger = Logger(actor="Cryptographer", log_name_prefix=LOG_PREFIX) -def register(compressed_pk, teos_url): +def register(user_id, teos_url): """ Registers the user to the tower. Args: - compressed_pk (:obj:`str`): a 33-byte hex-encoded compressed public key representing the user. + user_id (:obj:`str`): a 33-byte hex-encoded compressed public key representing the user. teos_url (:obj:`str`): the teos base url. Returns: :obj:`dict`: a dictionary containing the tower response if the registration succeeded. Raises: - :obj:`InvalidParameter `: if `compressed_pk` is invalid. + :obj:`InvalidParameter `: if `user_id` is invalid. :obj:`ConnectionError`: if the client cannot connect to the tower. :obj:`TowerResponseError `: if the tower responded with an error, or the response was invalid. """ - if not is_compressed_pk(compressed_pk): + if not is_compressed_pk(user_id): raise InvalidParameter("The cli public key is not valid") # Send request to the server. register_endpoint = "{}/register".format(teos_url) - data = {"public_key": compressed_pk} + data = {"public_key": user_id} logger.info("Registering in the Eye of Satoshi") response = process_post_response(post_request(data, register_endpoint)) @@ -59,7 +57,7 @@ def register(compressed_pk, teos_url): return response -def add_appointment(appointment_data, cli_sk, teos_pk, teos_url): +def add_appointment(appointment_data, user_sk, teos_id, teos_url): """ Manages the add_appointment command. @@ -74,8 +72,8 @@ def add_appointment(appointment_data, cli_sk, teos_pk, teos_url): Args: appointment_data (:obj:`dict`): a dictionary containing the appointment data. - cli_sk (:obj:`PrivateKey`): the client's private key. - teos_pk (:obj:`PublicKey`): the tower's public key. + user_sk (:obj:`PrivateKey`): the user's private key. + teos_id (:obj:`str`): the tower's compressed public key. teos_url (:obj:`str`): the teos base url. Returns: @@ -104,13 +102,9 @@ def add_appointment(appointment_data, cli_sk, teos_pk, teos_url): raise InvalidParameter("The provided data is missing the transaction") appointment_data["locator"] = compute_locator(tx_id) - appointment_data["encrypted_blob"] = Cryptographer.encrypt(Blob(tx), tx_id) + appointment_data["encrypted_blob"] = Cryptographer.encrypt(tx, tx_id) appointment = Appointment.from_dict(appointment_data) - signature = Cryptographer.sign(appointment.serialize(), cli_sk) - - # FIXME: the cryptographer should return exception we can capture - if not signature: - raise ValueError("The provided appointment cannot be signed") + signature = Cryptographer.sign(appointment.serialize(), user_sk) data = {"appointment": appointment.to_dict(), "signature": signature} @@ -125,7 +119,7 @@ def add_appointment(appointment_data, cli_sk, teos_pk, teos_url): raise TowerResponseError("The response does not contain the signature of the appointment") rpk = Cryptographer.recover_pk(appointment.serialize(), signature) - if not Cryptographer.verify_rpk(teos_pk, rpk): + if teos_id != Cryptographer.get_compressed_pk(rpk): raise TowerResponseError("The returned appointment's signature is invalid") logger.info("Appointment accepted and signed by the Eye of Satoshi") @@ -134,14 +128,14 @@ def add_appointment(appointment_data, cli_sk, teos_pk, teos_url): return appointment, signature -def get_appointment(locator, cli_sk, teos_pk, teos_url): +def get_appointment(locator, user_sk, teos_id, teos_url): """ Gets information about an appointment from the tower. Args: locator (:obj:`str`): the appointment locator used to identify it. - cli_sk (:obj:`PrivateKey`): the client's private key. - teos_pk (:obj:`PublicKey`): the tower's public key. + user_sk (:obj:`PrivateKey`): the user's private key. + teos_id (:obj:`PublicKey`): the tower's compressed public key. teos_url (:obj:`str`): the teos base url. Returns: @@ -155,13 +149,13 @@ def get_appointment(locator, cli_sk, teos_pk, teos_url): response was invalid. """ - # FIXME: All responses from the tower should be signed. Not using teos_pk atm. + # FIXME: All responses from the tower should be signed. Not using teos_id atm. if not is_locator(locator): raise InvalidParameter("The provided locator is not valid", locator=locator) message = "get appointment {}".format(locator) - signature = Cryptographer.sign(message.encode(), cli_sk) + signature = Cryptographer.sign(message.encode(), user_sk) data = {"locator": locator, "signature": signature} # Send request to the server. @@ -205,18 +199,17 @@ def get_all_appointments(teos_url): return None -def load_keys(teos_pk_path, cli_sk_path, cli_pk_path): +def load_keys(teos_pk_path, user_sk_path): """ Loads all the keys required so sign, send, and verify the appointment. Args: - teos_pk_path (:obj:`str`): path to the tower public key file. - cli_sk_path (:obj:`str`): path to the client private key file. - cli_pk_path (:obj:`str`): path to the client public key file. + teos_pk_path (:obj:`str`): path to the tower's public key file. + user_sk_path (:obj:`str`): path to the user's private key file. Returns: - :obj:`tuple`: a three-item tuple containing a ``PrivateKey``, a ``PublicKey`` and a ``str`` - representing the tower pk, user sk and user compressed pk respectively. + :obj:`tuple`: a three-item tuple containing a ``str``, a ``PrivateKey`` and a ``str`` + representing the tower id (compressed pk), user sk and user id (compressed pk) respectively. Raises: :obj:`InvalidKey `: if any of the keys is invalid or cannot be loaded. @@ -225,33 +218,30 @@ def load_keys(teos_pk_path, cli_sk_path, cli_pk_path): if not teos_pk_path: raise InvalidKey("TEOS's public key file not found. Please check your settings") - if not cli_sk_path: + if not user_sk_path: raise InvalidKey("Client's private key file not found. Please check your settings") - if not cli_pk_path: - raise InvalidKey("Client's public key file not found. Please check your settings") - try: teos_pk_der = Cryptographer.load_key_file(teos_pk_path) - teos_pk = PublicKey(teos_pk_der) + teos_id = Cryptographer.get_compressed_pk(PublicKey(teos_pk_der)) - except ValueError: - raise InvalidKey("TEOS public key is invalid or cannot be parsed") + except (InvalidParameter, InvalidKey, ValueError): + raise InvalidKey("TEOS public key cannot be loaded") - cli_sk_der = Cryptographer.load_key_file(cli_sk_path) - cli_sk = Cryptographer.load_private_key_der(cli_sk_der) + try: + user_sk_der = Cryptographer.load_key_file(user_sk_path) + user_sk = Cryptographer.load_private_key_der(user_sk_der) - if cli_sk is None: + except (InvalidParameter, InvalidKey): raise InvalidKey("Client private key is invalid or cannot be parsed") try: - cli_pk_der = Cryptographer.load_key_file(cli_pk_path) - compressed_cli_pk = Cryptographer.get_compressed_pk(PublicKey(cli_pk_der)) + user_id = Cryptographer.get_compressed_pk(user_sk.public_key) - except ValueError: - raise InvalidKey("Client public key is invalid or cannot be parsed") + except (InvalidParameter, InvalidKey): + raise InvalidKey("Client public key cannot be loaded") - return teos_pk, cli_sk, compressed_cli_pk + return teos_id, user_sk, user_id def post_request(data, endpoint): @@ -273,10 +263,10 @@ def post_request(data, endpoint): return requests.post(url=endpoint, json=data, timeout=5) except Timeout: - message = "Can't connect to the Eye of Satoshi's API. Connection timeout" + message = "Cannot connect to the Eye of Satoshi's API. Connection timeout" except ConnectionError: - message = "Can't connect to the Eye of Satoshi's API. Server cannot be reached" + message = "Cannot connect to the Eye of Satoshi's API. Server cannot be reached" except (InvalidSchema, MissingSchema, InvalidURL): message = "Invalid URL. No schema, or invalid schema, found ({})".format(endpoint) @@ -412,17 +402,15 @@ def main(command, args, command_line_conf): teos_url = "http://" + teos_url try: - teos_pk, cli_sk, compressed_cli_pk = load_keys( - config.get("TEOS_PUBLIC_KEY"), config.get("CLI_PRIVATE_KEY"), config.get("CLI_PUBLIC_KEY") - ) + teos_id, user_sk, user_id = load_keys(config.get("TEOS_PUBLIC_KEY"), config.get("CLI_PRIVATE_KEY")) if command == "register": - register_data = register(compressed_cli_pk, teos_url) + register_data = register(user_id, teos_url) logger.info("Registration succeeded. Available slots: {}".format(register_data.get("available_slots"))) if command == "add_appointment": appointment_data = parse_add_appointment_args(args) - appointment, signature = add_appointment(appointment_data, cli_sk, teos_pk, teos_url) + appointment, signature = add_appointment(appointment_data, user_sk, teos_id, teos_url) save_appointment_receipt(appointment.to_dict(), signature, config.get("APPOINTMENTS_FOLDER_NAME")) elif command == "get_appointment": @@ -435,7 +423,7 @@ def main(command, args, command_line_conf): if arg_opt in ["-h", "--help"]: sys.exit(help_get_appointment()) - appointment_data = get_appointment(arg_opt, cli_sk, teos_pk, teos_url) + appointment_data = get_appointment(arg_opt, user_sk, teos_id, teos_url) if appointment_data: print(appointment_data) @@ -468,8 +456,8 @@ def main(command, args, command_line_conf): except (FileNotFoundError, IOError, ConnectionError, ValueError) as e: logger.error(str(e)) - except (InvalidKey, InvalidParameter, TowerResponseError) as e: - logger.error(e.reason, **e.kwargs) + except (InvalidKey, InvalidParameter, TowerResponseError, SignatureError) as e: + logger.error(e.msg, **e.kwargs) except Exception as e: logger.error("Unknown error occurred", error=str(e)) diff --git a/common/appointment.py b/common/appointment.py index 7f0f5d4..2534f16 100644 --- a/common/appointment.py +++ b/common/appointment.py @@ -1,8 +1,6 @@ import struct from binascii import unhexlify -from common.encrypted_blob import EncryptedBlob - class Appointment: """ @@ -11,51 +9,42 @@ class Appointment: Args: locator (:obj:`str`): A 16-byte hex-encoded value used by the tower to detect channel breaches. It serves as a trigger for the tower to decrypt and broadcast the penalty transaction. - start_time (:obj:`int`): The block height where the tower is hired to start watching for breaches. - end_time (:obj:`int`): The block height where the tower will stop watching for breaches. to_self_delay (:obj:`int`): The ``to_self_delay`` encoded in the ``csv`` of the ``to_remote`` output of the commitment transaction that this appointment is covering. - encrypted_blob (:obj:`EncryptedBlob `): An ``EncryptedBlob`` object - containing an encrypted penalty transaction. The tower will decrypt it and broadcast the penalty transaction - upon seeing a breach on the blockchain. + encrypted_blob (:obj:`str`): An encrypted blob of data containing a penalty transaction. The tower will decrypt + it and broadcast the penalty transaction upon seeing a breach on the blockchain. """ - def __init__(self, locator, start_time, end_time, to_self_delay, encrypted_blob): + def __init__(self, locator, to_self_delay, encrypted_blob): self.locator = locator - self.start_time = start_time # ToDo: #4-standardize-appointment-fields - self.end_time = end_time # ToDo: #4-standardize-appointment-fields self.to_self_delay = to_self_delay - self.encrypted_blob = EncryptedBlob(encrypted_blob) + self.encrypted_blob = encrypted_blob @classmethod def from_dict(cls, appointment_data): """ Builds an appointment from a dictionary. - This method is useful to load data from a database. - Args: appointment_data (:obj:`dict`): a dictionary containing the following keys: - ``{locator, start_time, end_time, to_self_delay, encrypted_blob}`` + ``{locator, to_self_delay, encrypted_blob}`` Returns: - :obj:`Appointment `: An appointment initialized using the provided data. + :obj:`Appointment `: An appointment initialized using the provided data. Raises: ValueError: If one of the mandatory keys is missing in ``appointment_data``. """ locator = appointment_data.get("locator") - start_time = appointment_data.get("start_time") # ToDo: #4-standardize-appointment-fields - end_time = appointment_data.get("end_time") # ToDo: #4-standardize-appointment-fields to_self_delay = appointment_data.get("to_self_delay") - encrypted_blob_data = appointment_data.get("encrypted_blob") + encrypted_blob = appointment_data.get("encrypted_blob") - if any(v is None for v in [locator, start_time, end_time, to_self_delay, encrypted_blob_data]): + if any(v is None for v in [locator, to_self_delay, encrypted_blob]): raise ValueError("Wrong appointment data, some fields are missing") else: - appointment = cls(locator, start_time, end_time, to_self_delay, encrypted_blob_data) + appointment = cls(locator, to_self_delay, encrypted_blob) return appointment @@ -67,33 +56,18 @@ class Appointment: :obj:`dict`: A dictionary containing the appointment attributes. """ - # ToDO: #3-improve-appointment-structure - appointment = { - "locator": self.locator, - "start_time": self.start_time, - "end_time": self.end_time, - "to_self_delay": self.to_self_delay, - "encrypted_blob": self.encrypted_blob.data, - } - - return appointment + return self.__dict__ def serialize(self): """ Serializes an appointment to be signed. The serialization follows the same ordering as the fields in the appointment: - locator:start_time:end_time:to_self_delay:encrypted_blob + locator:to_self_delay:encrypted_blob All values are big endian. Returns: :obj:`bytes`: The serialized data to be signed. """ - return ( - unhexlify(self.locator) - + struct.pack(">I", self.start_time) - + struct.pack(">I", self.end_time) - + struct.pack(">I", self.to_self_delay) - + unhexlify(self.encrypted_blob.data) - ) + return unhexlify(self.locator) + struct.pack(">I", self.to_self_delay) + unhexlify(self.encrypted_blob) diff --git a/common/blob.py b/common/blob.py deleted file mode 100644 index cd3ed23..0000000 --- a/common/blob.py +++ /dev/null @@ -1,9 +0,0 @@ -import re - - -class Blob: - def __init__(self, data): - if type(data) is not str or re.search(r"^[0-9A-Fa-f]+$", data) is None: - raise ValueError("Non-Hex character found in transaction.") - - self.data = data diff --git a/common/constants.py b/common/constants.py index 904db90..6e8a6ee 100644 --- a/common/constants.py +++ b/common/constants.py @@ -8,5 +8,8 @@ HTTP_BAD_REQUEST = 400 HTTP_NOT_FOUND = 404 HTTP_SERVICE_UNAVAILABLE = 503 +# LN general nomenclature +IRREVOCABLY_RESOLVED = 100 + # Temporary constants, may be changed ENCRYPTED_BLOB_MAX_SIZE_HEX = 2 * 2048 diff --git a/common/cryptographer.py b/common/cryptographer.py index ecef7cd..2e733cd 100644 --- a/common/cryptographer.py +++ b/common/cryptographer.py @@ -7,13 +7,14 @@ from cryptography.exceptions import InvalidTag from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305 from common.tools import is_256b_hex_str +from common.exceptions import InvalidKey, InvalidParameter, SignatureError, EncryptionError LN_MESSAGE_PREFIX = b"Lightning Signed Message:" def sha256d(message): """ - Computes the double sha256 of a given by message. + Computes the double sha256 of a given message. Args: message(:obj:`bytes`): the message to be used as input to the hash function. @@ -87,10 +88,6 @@ def sigrec_decode(sigrec): return rsig + rid -# FIXME: Common has not log file, so it needs to log in the same log as the caller. This is a temporary fix. -logger = None - - class Cryptographer: """ The :class:`Cryptographer` is in charge of all the cryptography in the tower. @@ -99,63 +96,49 @@ class Cryptographer: @staticmethod def check_data_key_format(data, secret): """ - Checks that the data and secret that will be used to by ``encrypt`` / ``decrypt`` are properly - formatted. + Checks that the data and secret that will be used to by ``encrypt`` / ``decrypt`` are properly formatted. Args: data(:obj:`str`): the data to be encrypted. secret(:obj:`str`): the secret used to derive the encryption key. - Returns: - :obj:`bool`: Whether or not the ``key`` and ``data`` are properly formatted. - Raises: - :obj:`ValueError`: if either the ``key`` or ``data`` is not properly formatted. + :obj:`InvalidParameter`: if either the ``key`` and/or ``data`` are not properly formatted. """ if len(data) % 2: - error = "Incorrect (Odd-length) value" - raise ValueError(error) + raise InvalidParameter("Incorrect (Odd-length) data", data=data) if not is_256b_hex_str(secret): - error = "Secret must be a 32-byte hex value (64 hex chars)" - raise ValueError(error) - - return True + raise InvalidParameter("Secret must be a 32-byte hex value (64 hex chars)", secret=secret) @staticmethod - def encrypt(blob, secret): + def encrypt(message, secret): """ - Encrypts a given :obj:`Blob ` data using ``CHACHA20POLY1305``. + Encrypts a given message data using ``CHACHA20POLY1305``. ``SHA256(secret)`` is used as ``key``, and ``0 (12-byte)`` as ``iv``. Args: - blob (:obj:`Blob `): a ``Blob`` object containing a raw penalty transaction. + message (:obj:`str`): a message to be encrypted. Should be the hex-encoded commitment_tx. secret (:obj:`str`): a value to used to derive the encryption key. Should be the dispute txid. Returns: - :obj:`str`: The encrypted data (hex encoded). + :obj:`str`: The encrypted data (hex-encoded). Raises: - :obj:`ValueError`: if either the ``secret`` or ``blob`` is not properly formatted. + :obj:`InvalidParameter`: if either the ``key`` and/or ``data`` are not properly formatted. """ - Cryptographer.check_data_key_format(blob.data, secret) - - # Transaction to be encrypted - # FIXME: The blob data should contain more things that just the transaction. Leaving like this for now. - tx = unhexlify(blob.data) + Cryptographer.check_data_key_format(message, secret) # sk is the H(txid) (32-byte) and nonce is set to 0 (12-byte) sk = sha256(unhexlify(secret)).digest() nonce = bytearray(12) - logger.debug("Encrypting blob", sk=hexlify(sk).decode(), nonce=hexlify(nonce).decode(), blob=blob.data) - # Encrypt the data cipher = ChaCha20Poly1305(sk) - encrypted_blob = cipher.encrypt(nonce=nonce, data=tx, associated_data=None) + encrypted_blob = cipher.encrypt(nonce=nonce, data=unhexlify(message), associated_data=None) encrypted_blob = hexlify(encrypted_blob).decode("utf8") return encrypted_blob @@ -164,46 +147,38 @@ class Cryptographer: # ToDo: #20-test-tx-decrypting-edge-cases def decrypt(encrypted_blob, secret): """ - Decrypts a given :obj:`EncryptedBlob ` using ``CHACHA20POLY1305``. + Decrypts a given encrypted_blob using ``CHACHA20POLY1305``. ``SHA256(secret)`` is used as ``key``, and ``0 (12-byte)`` as ``iv``. Args: - encrypted_blob(:obj:`EncryptedBlob `): an ``EncryptedBlob`` - potentially containing a penalty transaction. - secret (:obj:`str`): a value to used to derive the decryption key. Should be the dispute txid. + encrypted_blob(:obj:`str`): an encrypted blob of data potentially containing a penalty transaction. + secret (:obj:`str`): a value used to derive the decryption key. Should be the dispute txid. Returns: - :obj:`str`: The decrypted data (hex encoded). + :obj:`str`: The decrypted data (hex-encoded). Raises: - :obj:`ValueError`: if either the ``secret`` or ``encrypted_blob`` is not properly formatted. + :obj:`InvalidParameter`: if either the ``key`` and/or ``data`` are not properly formatted. + :obj:`EncryptionError`: if the data cannot be decrypted with the given key. """ - Cryptographer.check_data_key_format(encrypted_blob.data, secret) + Cryptographer.check_data_key_format(encrypted_blob, secret) # sk is the H(txid) (32-byte) and nonce is set to 0 (12-byte) sk = sha256(unhexlify(secret)).digest() nonce = bytearray(12) - logger.info( - "Decrypting blob", - sk=hexlify(sk).decode(), - nonce=hexlify(nonce).decode(), - encrypted_blob=encrypted_blob.data, - ) - # Decrypt cipher = ChaCha20Poly1305(sk) - data = unhexlify(encrypted_blob.data) + data = unhexlify(encrypted_blob) try: blob = cipher.decrypt(nonce=nonce, data=data, associated_data=None) blob = hexlify(blob).decode("utf8") except InvalidTag: - blob = None - logger.error("Can't decrypt blob with the provided key") + raise EncryptionError("Cannot decrypt blob with the provided key", blob=encrypted_blob, key=secret) return blob @@ -216,12 +191,15 @@ class Cryptographer: file_path (:obj:`str`): the path to the key file to be loaded. Returns: - :obj:`bytes` or :obj:`None`: the key file data if the file can be found and read. ``None`` otherwise. + :obj:`bytes`: the key file data if the file can be found and read. + + Raises: + :obj:`InvalidParameter`: if the file_path has wrong format or cannot be found. + :obj:`InvalidKey`: if the key cannot be loaded from the file. It covers temporary I/O errors. """ if not isinstance(file_path, str): - logger.error("Key file path was expected, {} received".format(type(file_path))) - return None + raise InvalidParameter("Key file path was expected, {} received".format(type(file_path))) try: with open(file_path, "rb") as key_file: @@ -229,12 +207,10 @@ class Cryptographer: return key except FileNotFoundError: - logger.error("Key file not found at {}. Please check your settings".format(file_path)) - return None + raise InvalidParameter("Key file not found at {}. Please check your settings".format(file_path)) except IOError as e: - logger.error("I/O error({}): {}".format(e.errno, e.strerror)) - return None + raise InvalidKey("Key file cannot be loaded", exception=e) @staticmethod def load_private_key_der(sk_der): @@ -245,50 +221,52 @@ class Cryptographer: sk_der(:obj:`str`): a private key encoded in ``DER`` format. Returns: - :obj:`PrivateKey` or :obj:`None`: A ``PrivateKey`` object. if the private key can be loaded. `None` - otherwise. + :obj:`PrivateKey`: A ``PrivateKey`` object if the private key can be loaded. + + Raises: + :obj:`InvalidKey`: if a ``PrivateKey`` cannot be loaded from the given data. """ + try: sk = PrivateKey.from_der(sk_der) return sk except ValueError: - logger.error("The provided data cannot be deserialized (wrong size or format)") + raise InvalidKey("The provided key data cannot be deserialized (wrong size or format)") except TypeError: - logger.error("The provided data cannot be deserialized (wrong type)") - - return None + raise InvalidKey("The provided key data cannot be deserialized (wrong type)") @staticmethod def sign(message, sk): """ - Signs a given data using a given secret key using ECDSA over secp256k1. + Signs a given message with a given secret key using ECDSA over secp256k1. Args: message(:obj:`bytes`): the data to be signed. - sk(:obj:`PrivateKey`): the ECDSA secret key used to signed the data. + sk(:obj:`PrivateKey`): the ECDSA secret key to be used to sign the data. Returns: - :obj:`str` or :obj:`None`: The zbase32 signature of the given message is it can be signed. `None` otherwise. + :obj:`str`: The zbase32 signature of the given message is it can be signed. + + Raises: + :obj:`InvalidParameter`: if the message and/or secret key have a wrong value. + :obj:`SignatureError`: if there is an error during the signing process. """ if not isinstance(message, bytes): - logger.error("The message must be bytes. {} received".format(type(message))) - return None + raise InvalidParameter("Wrong value passed as message. Received {}, expected (bytes)".format(type(message))) if not isinstance(sk, PrivateKey): - logger.error("The value passed as sk is not a private key (EllipticCurvePrivateKey)") - return None + raise InvalidParameter("Wrong value passed as sk. Received {}, expected (PrivateKey)".format(type(message))) try: rsig_rid = sk.sign_recoverable(LN_MESSAGE_PREFIX + message, hasher=sha256d) sigrec = sigrec_encode(rsig_rid) zb32_sig = pyzbase32.encode_bytes(sigrec).decode() - except ValueError: - logger.error("Couldn't sign the message") - return None + except ValueError as e: + raise SignatureError("Couldn't sign the message. " + str(e)) return zb32_sig @@ -302,16 +280,20 @@ class Cryptographer: zb32_sig(:obj:`str`): the zbase32 signature of the message. Returns: - :obj:`PublicKey` or :obj:`None`: The recovered public key if it can be recovered. `None` otherwise. + :obj:`PublicKey`: The public key if it can be recovered. + + Raises: + :obj:`InvalidParameter`: if the message and/or signature have a wrong value. + :obj:`SignatureError`: if a public key cannot be recovered from the given signature. """ if not isinstance(message, bytes): - logger.error("The message must be bytes. {} received".format(type(message))) - return None + raise InvalidParameter("Wrong value passed as message. Received {}, expected (bytes)".format(type(message))) if not isinstance(zb32_sig, str): - logger.error("The zbase32_sig must be str. {} received".format(type(zb32_sig))) - return None + raise InvalidParameter( + "Wrong value passed as zbase32_sig. Received {}, expected (str)".format(type(zb32_sig)) + ) sigrec = pyzbase32.decode_bytes(zb32_sig) @@ -323,31 +305,13 @@ class Cryptographer: except ValueError as e: # Several errors fit here: Signature length != 65, wrong recover id and failed to parse signature. # All of them return raise ValueError. - logger.error(str(e)) - return None + raise SignatureError("Cannot recover a public key from the given signature. " + str(e)) except Exception as e: if "failed to recover ECDSA public key" in str(e): - logger.error("Cannot recover public key from signature") + raise SignatureError("Cannot recover a public key from the given signature") else: - logger.error("Unknown exception", error=str(e)) - - return None - - @staticmethod - def verify_rpk(pk, rpk): - """ - Verifies that that a recovered public key matches a given one. - - Args: - pk(:obj:`PublicKey`): a given public key (provided by the user). - rpk(:obj:`PublicKey`): a public key recovered via ``recover_pk``. - - Returns: - :obj:`bool`: True if the public keys match, False otherwise. - """ - - return pk.point() == rpk.point() + raise SignatureError("Unknown exception. " + str(e)) @staticmethod def get_compressed_pk(pk): @@ -358,18 +322,19 @@ class Cryptographer: pk(:obj:`PublicKey`): a given public key. Returns: - :obj:`str` or :obj:`None`: A compressed, hex-encoded, public key (33-byte long) if it can be compressed. - `None` oterwise. + :obj:`str`: A compressed, hex-encoded, public key (33-byte long) if it can be compressed. + + Raises: + :obj:`InvalidParameter`: if the value passed as public key is not a PublicKey object. + :obj:`InvalidKey`: if the public key has not been properly created. """ if not isinstance(pk, PublicKey): - logger.error("The received data is not a PublicKey object") - return None + raise InvalidParameter("Wrong value passed as pk. Received {}, expected (PublicKey)".format(type(pk))) try: compressed_pk = pk.format(compressed=True) return hexlify(compressed_pk).decode("utf-8") except TypeError as e: - logger.error("PublicKey has invalid initializer", error=str(e)) - return None + raise InvalidKey("PublicKey has invalid initializer", error=str(e)) diff --git a/teos/db_manager.py b/common/db_manager.py similarity index 100% rename from teos/db_manager.py rename to common/db_manager.py diff --git a/common/encrypted_blob.py b/common/encrypted_blob.py deleted file mode 100644 index 945b49c..0000000 --- a/common/encrypted_blob.py +++ /dev/null @@ -1,6 +0,0 @@ -class EncryptedBlob: - def __init__(self, data): - self.data = data - - def __eq__(self, other): - return isinstance(other, EncryptedBlob) and self.data == other.data diff --git a/teos/errors.py b/common/errors.py similarity index 50% rename from teos/errors.py rename to common/errors.py index fb7ef1d..786a9d7 100644 --- a/teos/errors.py +++ b/common/errors.py @@ -1,4 +1,4 @@ -# Appointment errors [-1, -64] +# Appointment errors [-1, -32] APPOINTMENT_EMPTY_FIELD = -1 APPOINTMENT_WRONG_FIELD_TYPE = -2 APPOINTMENT_WRONG_FIELD_SIZE = -3 @@ -8,12 +8,14 @@ APPOINTMENT_FIELD_TOO_BIG = -6 APPOINTMENT_WRONG_FIELD = -7 APPOINTMENT_INVALID_SIGNATURE_OR_INSUFFICIENT_SLOTS = -8 -# Registration errors [-65, -128] -REGISTRATION_MISSING_FIELD = -65 -REGISTRATION_WRONG_FIELD_FORMAT = -66 +# Registration errors [-33, -64] +REGISTRATION_MISSING_FIELD = -33 +REGISTRATION_WRONG_FIELD_FORMAT = -34 -# Custom RPC errors -RPC_TX_REORGED_AFTER_BROADCAST = -98 +# General errors [-65, -96] +INVALID_REQUEST_FORMAT = -65 +# Custom RPC errors [255+] +RPC_TX_REORGED_AFTER_BROADCAST = -256 # UNHANDLED -UNKNOWN_JSON_RPC_EXCEPTION = -99 +UNKNOWN_JSON_RPC_EXCEPTION = -257 diff --git a/common/exceptions.py b/common/exceptions.py new file mode 100644 index 0000000..2dbf076 --- /dev/null +++ b/common/exceptions.py @@ -0,0 +1,37 @@ +class BasicException(Exception): + def __init__(self, msg, **kwargs): + self.msg = msg + self.kwargs = kwargs + + def __str__(self): + if len(self.kwargs) > 2: + params = "".join("{}={}, ".format(k, v) for k, v in self.kwargs.items()) + + # Remove the extra 2 characters (space and comma) and add all data to the final message. + message = self.msg + " ({})".format(params[:-2]) + + else: + message = self.msg + + return message + + def to_json(self): + response = {"error": self.msg} + response.update(self.kwargs) + return response + + +class InvalidParameter(BasicException): + """Raised when a command line parameter is invalid (either missing or wrong)""" + + +class InvalidKey(BasicException): + """Raised when there is an error loading the keys""" + + +class EncryptionError(BasicException): + """Raised when there is an error with encryption related functions, covers decryption""" + + +class SignatureError(BasicException): + """Raised when there is an with the signature related functions, covers EC recover""" diff --git a/teos/__init__.py b/teos/__init__.py index b519ebc..1a29ddb 100644 --- a/teos/__init__.py +++ b/teos/__init__.py @@ -18,6 +18,7 @@ DEFAULT_CONF = { "FEED_PORT": {"value": 28332, "type": int}, "MAX_APPOINTMENTS": {"value": 1000000, "type": int}, "DEFAULT_SLOTS": {"value": 100, "type": int}, + "DEFAULT_SUBSCRIPTION_DURATION": {"value": 4320, "type": int}, "EXPIRY_DELTA": {"value": 6, "type": int}, "MIN_TO_SELF_DELAY": {"value": 20, "type": int}, "LOG_FILE": {"value": "teos.log", "type": str, "path": True}, diff --git a/teos/api.py b/teos/api.py index 02bfb79..0bcbddb 100644 --- a/teos/api.py +++ b/teos/api.py @@ -1,22 +1,17 @@ import os import logging -from math import ceil from flask import Flask, request, abort, jsonify from teos import LOG_PREFIX -import teos.errors as errors +import common.errors as errors from teos.inspector import InspectionFailed -from teos.gatekeeper import NotEnoughSlots, IdentificationFailure +from teos.watcher import AppointmentLimitReached +from teos.gatekeeper import NotEnoughSlots, AuthenticationFailure from common.logger import Logger from common.cryptographer import hash_160 -from common.constants import ( - HTTP_OK, - HTTP_BAD_REQUEST, - HTTP_SERVICE_UNAVAILABLE, - HTTP_NOT_FOUND, - ENCRYPTED_BLOB_MAX_SIZE_HEX, -) +from common.exceptions import InvalidParameter +from common.constants import HTTP_OK, HTTP_BAD_REQUEST, HTTP_SERVICE_UNAVAILABLE, HTTP_NOT_FOUND # ToDo: #5-add-async-to-api @@ -54,7 +49,7 @@ def get_request_data_json(request): :obj:`dict`: the dictionary parsed from the json request. Raises: - :obj:`TypeError`: if the request is not json encoded or it does not decodes to a dictionary. + :obj:`InvalidParameter`: if the request is not json encoded or it does not decodes to a dictionary. """ if request.is_json: @@ -62,9 +57,9 @@ def get_request_data_json(request): if isinstance(request_data, dict): return request_data else: - raise TypeError("Invalid request content") + raise InvalidParameter("Invalid request content") else: - raise TypeError("Request is not json encoded") + raise InvalidParameter("Request is not json encoded") class API: @@ -72,19 +67,18 @@ class API: The :class:`API` is in charge of the interface between the user and the tower. It handles and serves user requests. Args: + host (:obj:`str`): the hostname to listen on. + port (:obj:`int`): the port of the webserver. inspector (:obj:`Inspector `): an ``Inspector`` instance to check the correctness of the received appointment data. watcher (:obj:`Watcher `): a ``Watcher`` instance to pass the requests to. - gatekeeper (:obj:`Watcher `): a `Gatekeeper` instance in charge to control the user - access. """ - def __init__(self, host, port, inspector, watcher, gatekeeper): + def __init__(self, host, port, inspector, watcher): self.host = host self.port = port self.inspector = inspector self.watcher = watcher - self.gatekeeper = gatekeeper self.app = app # Adds all the routes to the functions listed above. @@ -103,7 +97,8 @@ class API: Registers a user by creating a subscription. Registration is pretty straightforward for now, since it does not require payments. - The amount of slots cannot be requested by the user yet either. This is linked to the previous point. + The amount of slots and expiry of the subscription cannot be requested by the user yet either. This is linked to + the previous point. Users register by sending a public key to the proper endpoint. This is exploitable atm, but will be solved when payments are introduced. @@ -121,27 +116,32 @@ class API: try: request_data = get_request_data_json(request) - except TypeError as e: + except InvalidParameter as e: logger.info("Received invalid register request", from_addr="{}".format(remote_addr)) - return abort(HTTP_BAD_REQUEST, e) + return jsonify({"error": str(e), "error_code": errors.INVALID_REQUEST_FORMAT}), HTTP_BAD_REQUEST - client_pk = request_data.get("public_key") + user_id = request_data.get("public_key") - if client_pk: + if user_id: try: rcode = HTTP_OK - available_slots = self.gatekeeper.add_update_user(client_pk) - response = {"public_key": client_pk, "available_slots": available_slots} + available_slots, subscription_expiry = self.watcher.gatekeeper.add_update_user(user_id) + response = { + "public_key": user_id, + "available_slots": available_slots, + "subscription_expiry": subscription_expiry, + } - except ValueError as e: + except InvalidParameter as e: rcode = HTTP_BAD_REQUEST - error = "Error {}: {}".format(errors.REGISTRATION_MISSING_FIELD, str(e)) - response = {"error": error} + response = {"error": str(e), "error_code": errors.REGISTRATION_MISSING_FIELD} else: rcode = HTTP_BAD_REQUEST - error = "Error {}: public_key not found in register message".format(errors.REGISTRATION_WRONG_FIELD_FORMAT) - response = {"error": error} + response = { + "error": "public_key not found in register message", + "error_code": errors.REGISTRATION_WRONG_FIELD_FORMAT, + } logger.info("Sending response and disconnecting", from_addr="{}".format(remote_addr), response=response) @@ -157,8 +157,8 @@ class API: Returns: :obj:`tuple`: A tuple containing the response (:obj:`str`) and response code (:obj:`int`). For accepted appointments, the ``rcode`` is always 200 and the response contains the receipt signature (json). For - rejected appointments, the ``rcode`` is a 404 and the value contains an application error, and an error - message. Error messages can be found at :mod:`Errors `. + rejected appointments, the ``rcode`` contains an application error, and an error message. Error messages can + be found at :mod:`Errors `. """ # Getting the real IP if the server is behind a reverse proxy @@ -169,73 +169,28 @@ class API: try: request_data = get_request_data_json(request) - except TypeError as e: - return abort(HTTP_BAD_REQUEST, e) - - # We kind of have the chicken an the egg problem here. Data must be verified and the signature must be checked: - # - If we verify the data first, we may encounter that the signature is wrong and wasted some time. - # - If we check the signature first, we may need to verify some of the information or expose to build - # appointments with potentially wrong data, which may be exploitable. - # - # The first approach seems safer since it only implies a bunch of pretty quick checks. + except InvalidParameter as e: + return jsonify({"error": str(e), "error_code": errors.INVALID_REQUEST_FORMAT}), HTTP_BAD_REQUEST try: appointment = self.inspector.inspect(request_data.get("appointment")) - user_pk = self.gatekeeper.identify_user(appointment.serialize(), request_data.get("signature")) - - # Check if the appointment is an update. Updates will return a summary. - appointment_uuid = hash_160("{}{}".format(appointment.locator, user_pk)) - appointment_summary = self.watcher.get_appointment_summary(appointment_uuid) - - if appointment_summary: - used_slots = ceil(appointment_summary.get("size") / ENCRYPTED_BLOB_MAX_SIZE_HEX) - required_slots = ceil(len(appointment.encrypted_blob.data) / ENCRYPTED_BLOB_MAX_SIZE_HEX) - slot_diff = required_slots - used_slots - - # For updates we only reserve the slot difference provided the new one is bigger. - required_slots = slot_diff if slot_diff > 0 else 0 - - else: - # For regular appointments 1 slot is reserved per ENCRYPTED_BLOB_MAX_SIZE_HEX block. - slot_diff = 0 - required_slots = ceil(len(appointment.encrypted_blob.data) / ENCRYPTED_BLOB_MAX_SIZE_HEX) - - # Slots are reserved before adding the appointments to prevent race conditions. - # DISCUSS: It may be worth using signals here to avoid race conditions anyway. - self.gatekeeper.fill_slots(user_pk, required_slots) - - appointment_added, signature = self.watcher.add_appointment(appointment, user_pk) - - if appointment_added: - # If the appointment is added and the update is smaller than the original, the difference is given back. - if slot_diff < 0: - self.gatekeeper.free_slots(user_pk, abs(slot_diff)) - - rcode = HTTP_OK - response = { - "locator": appointment.locator, - "signature": signature, - "available_slots": self.gatekeeper.registered_users[user_pk].get("available_slots"), - } - - else: - # If the appointment is not added the reserved slots are given back - self.gatekeeper.free_slots(user_pk, required_slots) - rcode = HTTP_SERVICE_UNAVAILABLE - response = {"error": "appointment rejected"} + response = self.watcher.add_appointment(appointment, request_data.get("signature")) + rcode = HTTP_OK except InspectionFailed as e: rcode = HTTP_BAD_REQUEST - error = "appointment rejected. Error {}: {}".format(e.erno, e.reason) - response = {"error": error} + response = {"error": "appointment rejected. {}".format(e.reason), "error_code": e.erno} - except (IdentificationFailure, NotEnoughSlots): + except (AuthenticationFailure, NotEnoughSlots): rcode = HTTP_BAD_REQUEST - error = "appointment rejected. Error {}: {}".format( - errors.APPOINTMENT_INVALID_SIGNATURE_OR_INSUFFICIENT_SLOTS, - "Invalid signature or user does not have enough slots available", - ) - response = {"error": error} + response = { + "error": "appointment rejected. Invalid signature or user does not have enough slots available", + "error_code": errors.APPOINTMENT_INVALID_SIGNATURE_OR_INSUFFICIENT_SLOTS, + } + + except AppointmentLimitReached: + rcode = HTTP_SERVICE_UNAVAILABLE + response = {"error": "appointment rejected"} logger.info("Sending response and disconnecting", from_addr="{}".format(remote_addr), response=response) return jsonify(response), rcode @@ -266,9 +221,9 @@ class API: try: request_data = get_request_data_json(request) - except TypeError as e: + except InvalidParameter as e: logger.info("Received invalid get_appointment request", from_addr="{}".format(remote_addr)) - return abort(HTTP_BAD_REQUEST, e) + return jsonify({"error": str(e), "error_code": errors.INVALID_REQUEST_FORMAT}), HTTP_BAD_REQUEST locator = request_data.get("locator") @@ -278,16 +233,18 @@ class API: message = "get appointment {}".format(locator).encode() signature = request_data.get("signature") - user_pk = self.gatekeeper.identify_user(message, signature) + user_id = self.watcher.gatekeeper.authenticate_user(message, signature) triggered_appointments = self.watcher.db_manager.load_all_triggered_flags() - uuid = hash_160("{}{}".format(locator, user_pk)) + uuid = hash_160("{}{}".format(locator, user_id)) # If the appointment has been triggered, it should be in the locator (default else just in case). if uuid in triggered_appointments: appointment_data = self.watcher.db_manager.load_responder_tracker(uuid) if appointment_data: rcode = HTTP_OK + # Remove user_id field from appointment data since it is an internal field + appointment_data.pop("user_id") response = {"locator": locator, "status": "dispute_responded", "appointment": appointment_data} else: rcode = HTTP_NOT_FOUND @@ -298,12 +255,14 @@ class API: appointment_data = self.watcher.db_manager.load_watcher_appointment(uuid) if appointment_data: rcode = HTTP_OK + # Remove user_id field from appointment data since it is an internal field + appointment_data.pop("user_id") response = {"locator": locator, "status": "being_watched", "appointment": appointment_data} else: rcode = HTTP_NOT_FOUND response = {"locator": locator, "status": "not_found"} - except (InspectionFailed, IdentificationFailure): + except (InspectionFailed, AuthenticationFailure): rcode = HTTP_NOT_FOUND response = {"locator": locator, "status": "not_found"} @@ -335,9 +294,7 @@ class API: return response def start(self): - """ - This function starts the Flask server used to run the API. - """ + """ This function starts the Flask server used to run the API """ # Setting Flask log to ERROR only so it does not mess with our logging. Also disabling flask initial messages logging.getLogger("werkzeug").setLevel(logging.ERROR) diff --git a/teos/appointments_dbm.py b/teos/appointments_dbm.py index cc1fcd0..2ad51f3 100644 --- a/teos/appointments_dbm.py +++ b/teos/appointments_dbm.py @@ -1,11 +1,10 @@ import json import plyvel -from teos.db_manager import DBManager - from teos import LOG_PREFIX from common.logger import Logger +from common.db_manager import DBManager logger = Logger(actor="AppointmentsDBM", log_name_prefix=LOG_PREFIX) @@ -85,7 +84,7 @@ class AppointmentsDBM(DBManager): ``RESPONDER_LAST_BLOCK_KEY``). Returns: - :obj:`str` or :obj:`None`: A 16-byte hex-encoded str representing the last known block hash. + :obj:`str` or :obj:`None`: A 32-byte hex-encoded str representing the last known block hash. Returns ``None`` if the entry is not found. """ @@ -178,7 +177,7 @@ class AppointmentsDBM(DBManager): Args: uuid (:obj:`str`): the identifier of the appointment to be stored. - appointment (:obj:`dict`): an appointment encoded as dictionary. + appointment (:obj:`dict`): an appointment encoded as a dictionary. Returns: :obj:`bool`: True if the appointment was stored in the db. False otherwise. @@ -203,7 +202,7 @@ class AppointmentsDBM(DBManager): Args: uuid (:obj:`str`): the identifier of the appointment to be stored. - tracker (:obj:`dict`): a tracker encoded as dictionary. + tracker (:obj:`dict`): a tracker encoded as a dictionary. Returns: :obj:`bool`: True if the tracker was stored in the db. False otherwise. @@ -248,7 +247,7 @@ class AppointmentsDBM(DBManager): def create_append_locator_map(self, locator, uuid): """ - Creates (or appends to if already exists) a ``locator:uuid`` map. + Creates a ``locator:uuid`` map. If the map already exists, the new ``uuid`` is appended to the existing ones (if it is not already there). @@ -335,7 +334,7 @@ class AppointmentsDBM(DBManager): def batch_delete_watcher_appointments(self, uuids): """ - Deletes an appointment from the database. + Deletes multiple appointments from the database. Args: uuids (:obj:`list`): a list of 16-byte hex-encoded strings identifying the appointments to be deleted. @@ -368,7 +367,7 @@ class AppointmentsDBM(DBManager): def batch_delete_responder_trackers(self, uuids): """ - Deletes an appointment from the database. + Deletes multiple trackers from the database. Args: uuids (:obj:`list`): a list of 16-byte hex-encoded strings identifying the trackers to be deleted. diff --git a/teos/block_processor.py b/teos/block_processor.py index 7b9ea24..136e9f7 100644 --- a/teos/block_processor.py +++ b/teos/block_processor.py @@ -14,7 +14,7 @@ class BlockProcessor: Args: btc_connect_params (:obj:`dict`): a dictionary with the parameters to connect to bitcoind - (rpc user, rpc passwd, host and port) + (rpc user, rpc password, host and port) """ def __init__(self, btc_connect_params): @@ -22,10 +22,10 @@ class BlockProcessor: def get_block(self, block_hash): """ - Gives a block given a block hash by querying ``bitcoind``. + Gets a block given a block hash by querying ``bitcoind``. Args: - block_hash (:obj:`str`): The block hash to be queried. + block_hash (:obj:`str`): the block hash to be queried. Returns: :obj:`dict` or :obj:`None`: A dictionary containing the requested block data if the block is found. @@ -44,7 +44,7 @@ class BlockProcessor: def get_best_block_hash(self): """ - Returns the hash of the current best chain tip. + Gets the hash of the current best chain tip. Returns: :obj:`str` or :obj:`None`: The hash of the block if it can be found. @@ -63,10 +63,10 @@ class BlockProcessor: def get_block_count(self): """ - Returns the block height of the best chain. + Gets the block count of the best chain. Returns: - :obj:`int` or :obj:`None`: The block height if it can be computed. + :obj:`int` or :obj:`None`: The count of the best chain if it can be computed. Returns ``None`` otherwise (not even sure this can actually happen). """ @@ -86,7 +86,7 @@ class BlockProcessor: associated metadata given by ``bitcoind`` (e.g. confirmation count). Args: - raw_tx (:obj:`str`): The hex representation of the transaction. + raw_tx (:obj:`str`): the hex representation of the transaction. Returns: :obj:`dict` or :obj:`None`: The decoding of the given ``raw_tx`` if the transaction is well formatted. @@ -99,7 +99,7 @@ class BlockProcessor: except JSONRPCException as e: tx = None - logger.error("Can't build transaction from decoded data", error=e.error) + logger.error("Cannot build transaction from decoded data", error=e.error) return tx @@ -133,7 +133,7 @@ class BlockProcessor: def get_missed_blocks(self, last_know_block_hash): """ - Compute the blocks between the current best chain tip and a given block hash (``last_know_block_hash``). + Gets the blocks between the current best chain tip and a given block hash (``last_know_block_hash``). This method is used to fetch all the missed information when recovering from a crash. @@ -158,7 +158,7 @@ class BlockProcessor: def is_block_in_best_chain(self, block_hash): """ - Checks whether or not a given block is on the best chain. Blocks are identified by block_hash. + Checks whether a given block is on the best chain or not. Blocks are identified by block_hash. A block that is not in the best chain will either not exists (block = None) or have a confirmation count of -1 (implying that the block was forked out or the chain never grew from that one). diff --git a/teos/builder.py b/teos/builder.py index 831236f..56f489d 100644 --- a/teos/builder.py +++ b/teos/builder.py @@ -1,49 +1,50 @@ +from teos.responder import TransactionTracker +from teos.extended_appointment import ExtendedAppointment + + class Builder: """ - The :class:`Builder` class is in charge of reconstructing data loaded from the database and build the data - structures of the :obj:`Watcher ` and the :obj:`Responder `. + The :class:`Builder` class is in charge of reconstructing data loaded from the appointments database and build the + data structures of the :obj:`Watcher ` and the :obj:`Responder `. """ @staticmethod def build_appointments(appointments_data): """ - Builds an appointments dictionary (``uuid: Appointment``) and a locator_uuid_map (``locator: uuid``) given a - dictionary of appointments from the database. + Builds an appointments dictionary (``uuid:ExtendedAppointment``) and a locator_uuid_map (``locator:uuid``) + given a dictionary of appointments from the database. Args: appointments_data (:obj:`dict`): a dictionary of dictionaries representing all the :obj:`Watcher ` appointments stored in the database. The structure is as follows: - ``{uuid: {locator: str, start_time: int, ...}, uuid: {locator:...}}`` + ``{uuid: {locator: str, ...}, uuid: {locator:...}}`` Returns: :obj:`tuple`: A tuple with two dictionaries. ``appointments`` containing the appointment information in - :obj:`Appointment ` objects and ``locator_uuid_map`` containing a map of - appointment (``uuid:locator``). + :obj:`ExtendedAppointment ` objects and ``locator_uuid_map`` + containing a map of appointment (``uuid:locator``). """ appointments = {} locator_uuid_map = {} for uuid, data in appointments_data.items(): - appointments[uuid] = { - "locator": data.get("locator"), - "end_time": data.get("end_time"), - "size": len(data.get("encrypted_blob")), - } + appointment = ExtendedAppointment.from_dict(data) + appointments[uuid] = appointment.get_summary() - if data.get("locator") in locator_uuid_map: - locator_uuid_map[data.get("locator")].append(uuid) + if appointment.locator in locator_uuid_map: + locator_uuid_map[appointment.locator].append(uuid) else: - locator_uuid_map[data.get("locator")] = [uuid] + locator_uuid_map[appointment.locator] = [uuid] return appointments, locator_uuid_map @staticmethod def build_trackers(tracker_data): """ - Builds a tracker dictionary (``uuid: TransactionTracker``) and a tx_tracker_map (``penalty_txid: uuid``) given + Builds a tracker dictionary (``uuid:TransactionTracker``) and a tx_tracker_map (``penalty_txid:uuid``) given a dictionary of trackers from the database. Args: @@ -64,17 +65,14 @@ class Builder: tx_tracker_map = {} for uuid, data in tracker_data.items(): - trackers[uuid] = { - "penalty_txid": data.get("penalty_txid"), - "locator": data.get("locator"), - "appointment_end": data.get("appointment_end"), - } + tracker = TransactionTracker.from_dict(data) + trackers[uuid] = tracker.get_summary() - if data.get("penalty_txid") in tx_tracker_map: - tx_tracker_map[data.get("penalty_txid")].append(uuid) + if tracker.penalty_txid in tx_tracker_map: + tx_tracker_map[tracker.penalty_txid].append(uuid) else: - tx_tracker_map[data.get("penalty_txid")] = [uuid] + tx_tracker_map[tracker.penalty_txid] = [uuid] return trackers, tx_tracker_map @@ -85,8 +83,8 @@ class Builder: :mod:`Responder ` using backed up data. Args: - block_queue (:obj:`Queue`): a ``Queue`` - missed_blocks (:obj:`list`): list of block hashes missed by the Watchtower (do to a crash or shutdown). + block_queue (:obj:`Queue`): a ``Queue``. + missed_blocks (:obj:`list`): list of block hashes missed by the Watchtower (due to a crash or shutdown). Returns: :obj:`Queue`: A ``Queue`` containing all the missed blocks hashes. diff --git a/teos/carrier.py b/teos/carrier.py index c8bba1c..11b578f 100644 --- a/teos/carrier.py +++ b/teos/carrier.py @@ -3,7 +3,7 @@ from common.logger import Logger from teos.tools import bitcoin_cli import teos.rpc_errors as rpc_errors from teos.utils.auth_proxy import JSONRPCException -from teos.errors import UNKNOWN_JSON_RPC_EXCEPTION, RPC_TX_REORGED_AFTER_BROADCAST +from common.errors import UNKNOWN_JSON_RPC_EXCEPTION, RPC_TX_REORGED_AFTER_BROADCAST logger = Logger(actor="Carrier", log_name_prefix=LOG_PREFIX) diff --git a/teos/chain_monitor.py b/teos/chain_monitor.py index 654fffa..6f1522d 100644 --- a/teos/chain_monitor.py +++ b/teos/chain_monitor.py @@ -20,14 +20,14 @@ class ChainMonitor: Args: watcher_queue (:obj:`Queue`): the queue to be used to send blocks hashes to the ``Watcher``. responder_queue (:obj:`Queue`): the queue to be used to send blocks hashes to the ``Responder``. - block_processor (:obj:`BlockProcessor `): a blockProcessor instance. + block_processor (:obj:`BlockProcessor `): a ``BlockProcessor`` instance. bitcoind_feed_params (:obj:`dict`): a dict with the feed (ZMQ) connection parameters. Attributes: best_tip (:obj:`str`): a block hash representing the current best tip. last_tips (:obj:`list`): a list of last chain tips. Used as a sliding window to avoid notifying about old tips. terminate (:obj:`bool`): a flag to signal the termination of the :class:`ChainMonitor` (shutdown the tower). - check_tip (:obj:`Event`): an event that's triggered at fixed time intervals and controls the polling thread. + check_tip (:obj:`Event`): an event that is triggered at fixed time intervals and controls the polling thread. lock (:obj:`Condition`): a lock used to protect concurrent access to the queues and ``best_tip`` by the zmq and polling threads. zmqSubSocket (:obj:`socket`): a socket to connect to ``bitcoind`` via ``zmq``. diff --git a/teos/cleaner.py b/teos/cleaner.py index a9cba0a..938fcfd 100644 --- a/teos/cleaner.py +++ b/teos/cleaner.py @@ -87,7 +87,7 @@ class Cleaner: @staticmethod def delete_expired_appointments(expired_appointments, appointments, locator_uuid_map, db_manager): """ - Deletes appointments which ``end_time`` has been reached (with no trigger) both from memory + Deletes appointments which ``expiry`` has been reached (with no trigger) both from memory (:obj:`Watcher `) and disk. Args: @@ -181,31 +181,37 @@ class Cleaner: db_manager.create_triggered_appointment_flag(uuid) @staticmethod - def delete_completed_trackers(completed_trackers, height, trackers, tx_tracker_map, db_manager): + def delete_trackers(completed_trackers, height, trackers, tx_tracker_map, db_manager, expired=False): """ - Deletes a completed tracker both from memory (:obj:`Responder `) and disk (from the - Responder's and Watcher's databases). + Deletes completed/expired trackers both from memory (:obj:`Responder `) and disk + (from the Responder's and Watcher's databases). Args: trackers (:obj:`dict`): a dictionary containing all the :obj:`Responder ` trackers. + height (:obj:`int`): the block height at which the trackers were completed. tx_tracker_map (:obj:`dict`): a ``penalty_txid:uuid`` map for the :obj:`Responder ` trackers. - completed_trackers (:obj:`dict`): a dict of completed trackers to be deleted (uuid:confirmations). - height (:obj:`int`): the block height at which the trackers were completed. + completed_trackers (:obj:`dict`): a dict of completed/expired trackers to be deleted (uuid:confirmations). db_manager (:obj:`AppointmentsDBM `): a ``AppointmentsDBM`` instance to interact with the database. + expired (:obj:`bool`): whether the trackers have expired or not. Defaults to False. """ locator_maps_to_update = {} - for uuid, confirmations in completed_trackers.items(): - logger.info( - "Appointment completed. Appointment ended after reaching enough confirmations", - uuid=uuid, - height=height, - confirmations=confirmations, - ) + for uuid in completed_trackers: + + if expired: + logger.info( + "Appointment couldn't be completed. Expiry reached but penalty didn't make it to the chain", + uuid=uuid, + height=height, + ) + else: + logger.info( + "Appointment completed. Penalty transaction was irrevocably confirmed", uuid=uuid, height=height + ) penalty_txid = trackers[uuid].get("penalty_txid") locator = trackers[uuid].get("locator") @@ -229,6 +235,35 @@ class Cleaner: Cleaner.update_delete_db_locator_map(uuids, locator, db_manager) # Delete appointment from the db (from watchers's and responder's db) and remove flag - db_manager.batch_delete_responder_trackers(list(completed_trackers.keys())) - db_manager.batch_delete_watcher_appointments(list(completed_trackers.keys())) - db_manager.batch_delete_triggered_appointment_flag(list(completed_trackers.keys())) + db_manager.batch_delete_responder_trackers(completed_trackers) + db_manager.batch_delete_watcher_appointments(completed_trackers) + db_manager.batch_delete_triggered_appointment_flag(completed_trackers) + + @staticmethod + def delete_gatekeeper_appointments(gatekeeper, appointment_to_delete): + """ + Deletes a list of expired / completed appointments of a given user both from memory and the UserDB. + + Args: + gatekeeper (:obj:`Gatekeeper `): a `Gatekeeper` instance in charge to control + the user access and subscription expiry. + appointment_to_delete (:obj:`dict`): uuid:user_id dict containing the appointments to delete + (expired + completed) + """ + + user_ids = [] + # Remove appointments from memory + for uuid, user_id in appointment_to_delete.items(): + if user_id in gatekeeper.registered_users and uuid in gatekeeper.registered_users[user_id].appointments: + # Remove the appointment from the appointment list and update the available slots + gatekeeper.lock.acquire() + freed_slots = gatekeeper.registered_users[user_id].appointments.pop(uuid) + gatekeeper.registered_users[user_id].available_slots += freed_slots + gatekeeper.lock.release() + + if user_id not in user_ids: + user_ids.append(user_id) + + # Store the updated users in the DB + for user_id in user_ids: + gatekeeper.user_db.store_user(user_id, gatekeeper.registered_users[user_id].to_dict()) diff --git a/teos/extended_appointment.py b/teos/extended_appointment.py new file mode 100644 index 0000000..ca6a88b --- /dev/null +++ b/teos/extended_appointment.py @@ -0,0 +1,46 @@ +from common.appointment import Appointment + + +class ExtendedAppointment(Appointment): + def __init__(self, locator, to_self_delay, encrypted_blob, user_id): + super().__init__(locator, to_self_delay, encrypted_blob) + self.user_id = user_id + + def get_summary(self): + """ + Returns the summary of an appointment, consisting on the locator, the user_id and the appointment size. + + Returns: + :obj:`dict`: the appointment summary. + """ + return {"locator": self.locator, "user_id": self.user_id} + + @classmethod + def from_dict(cls, appointment_data): + """ + Builds an appointment from a dictionary. + + This method is useful to load data from a database. + + Args: + appointment_data (:obj:`dict`): a dictionary containing the following keys: + ``{locator, to_self_delay, encrypted_blob, user_id}`` + + Returns: + :obj:`ExtendedAppointment `: An appointment initialized + using the provided data. + + Raises: + ValueError: If one of the mandatory keys is missing in ``appointment_data``. + """ + + appointment = Appointment.from_dict(appointment_data) + user_id = appointment_data.get("user_id") + + if not user_id: + raise ValueError("Wrong appointment data, user_id is missing") + + else: + appointment = cls(appointment.locator, appointment.to_self_delay, appointment.encrypted_blob, user_id) + + return appointment diff --git a/teos/gatekeeper.py b/teos/gatekeeper.py index 79b5efc..4386fa6 100644 --- a/teos/gatekeeper.py +++ b/teos/gatekeeper.py @@ -1,62 +1,115 @@ +from math import ceil +from threading import Lock + from common.tools import is_compressed_pk from common.cryptographer import Cryptographer +from common.constants import ENCRYPTED_BLOB_MAX_SIZE_HEX +from common.exceptions import InvalidParameter, InvalidKey, SignatureError class NotEnoughSlots(ValueError): """Raised when trying to subtract more slots than a user has available""" - def __init__(self, user_pk, requested_slots): - self.user_pk = user_pk - self.requested_slots = requested_slots + pass -class IdentificationFailure(Exception): +class AuthenticationFailure(Exception): """ - Raised when a user can not be identified. Either the user public key cannot be recovered or the user is + Raised when a user can not be authenticated. Either the user public key cannot be recovered or the user is not found within the registered ones. """ pass +class UserInfo: + def __init__(self, available_slots, subscription_expiry, appointments=None): + self.available_slots = available_slots + self.subscription_expiry = subscription_expiry + + if not appointments: + # A dictionary of the form uuid:required_slots for each user appointment + self.appointments = {} + else: + self.appointments = appointments + + @classmethod + def from_dict(cls, user_data): + available_slots = user_data.get("available_slots") + appointments = user_data.get("appointments") + subscription_expiry = user_data.get("subscription_expiry") + + if any(v is None for v in [available_slots, appointments, subscription_expiry]): + raise ValueError("Wrong appointment data, some fields are missing") + + return cls(available_slots, subscription_expiry, appointments) + + def to_dict(self): + return self.__dict__ + + class Gatekeeper: """ The :class:`Gatekeeper` is in charge of managing the access to the tower. Only registered users are allowed to perform actions. Attributes: - registered_users (:obj:`dict`): a map of user_pk:appointment_slots. + default_slots (:obj:`int`): the number of slots assigned to a user subscription. + default_subscription_duration (:obj:`int`): the expiry assigned to a user subscription. + expiry_delta (:obj:`int`): the grace period given to the user to renew their subscription. + block_processor (:obj:`BlockProcessor `): a ``BlockProcessor`` instance to + get block from bitcoind. + user_db (:obj:`UserDBM `): a ``UserDBM`` instance to interact with the database. + registered_users (:obj:`dict`): a map of user_pk:UserInfo. + lock (:obj:`Lock`): a Threading.Lock object to lock access to the Gatekeeper on updates. + """ - def __init__(self, user_db, default_slots): + def __init__(self, user_db, block_processor, default_slots, default_subscription_duration, expiry_delta): self.default_slots = default_slots + self.default_subscription_duration = default_subscription_duration + self.expiry_delta = expiry_delta + self.block_processor = block_processor self.user_db = user_db - self.registered_users = user_db.load_all_users() + self.registered_users = { + user_id: UserInfo.from_dict(user_data) for user_id, user_data in user_db.load_all_users().items() + } + self.lock = Lock() - def add_update_user(self, user_pk): + def add_update_user(self, user_id): """ Adds a new user or updates the subscription of an existing one, by adding additional slots. Args: - user_pk(:obj:`str`): the public key that identifies the user (33-bytes hex str). + user_id(:obj:`str`): the public key that identifies the user (33-bytes hex str). Returns: - :obj:`int`: the number of available slots in the user subscription. + :obj:`tuple`: a tuple with the number of available slots in the user subscription and the subscription + expiry (in absolute block height). + + Raises: + :obj:`InvalidParameter`: if the user_pk does not match the expected format. """ - if not is_compressed_pk(user_pk): - raise ValueError("Provided public key does not match expected format (33-byte hex string)") + if not is_compressed_pk(user_id): + raise InvalidParameter("Provided public key does not match expected format (33-byte hex string)") - if user_pk not in self.registered_users: - self.registered_users[user_pk] = {"available_slots": self.default_slots} + if user_id not in self.registered_users: + self.registered_users[user_id] = UserInfo( + self.default_slots, self.block_processor.get_block_count() + self.default_subscription_duration + ) else: - self.registered_users[user_pk]["available_slots"] += self.default_slots + # FIXME: For now new calls to register add default_slots to the current count and reset the expiry time + self.registered_users[user_id].available_slots += self.default_slots + self.registered_users[user_id].subscription_expiry = ( + self.block_processor.get_block_count() + self.default_subscription_duration + ) - self.user_db.store_user(user_pk, self.registered_users[user_pk]) + self.user_db.store_user(user_id, self.registered_users[user_id].to_dict()) - return self.registered_users[user_pk]["available_slots"] + return self.registered_users[user_id].available_slots, self.registered_users[user_id].subscription_expiry - def identify_user(self, message, signature): + def authenticate_user(self, message, signature): """ Checks if a request comes from a registered user by ec-recovering their public key from a signed message. @@ -68,50 +121,83 @@ class Gatekeeper: :obj:`str`: a compressed key recovered from the signature and matching a registered user. Raises: - :obj:`IdentificationFailure`: if the user cannot be identified. + :obj:`AuthenticationFailure`: if the user cannot be authenticated. """ - if isinstance(message, bytes) and isinstance(signature, str): + try: rpk = Cryptographer.recover_pk(message, signature) - compressed_pk = Cryptographer.get_compressed_pk(rpk) + user_id = Cryptographer.get_compressed_pk(rpk) - if compressed_pk in self.registered_users: - return compressed_pk + if user_id in self.registered_users: + return user_id else: - raise IdentificationFailure("User not found.") + raise AuthenticationFailure("User not found.") - else: - raise IdentificationFailure("Wrong message or signature.") + except (InvalidParameter, InvalidKey, SignatureError): + raise AuthenticationFailure("Wrong message or signature.") - def fill_slots(self, user_pk, n): + def add_update_appointment(self, user_id, uuid, appointment): """ - Fills a given number os slots of the user subscription. + Adds (or updates) an appointment to a user subscription. The user slots are updated accordingly. + + Slots are taken if a new appointment is given, or an update is given with an appointment bigger than the + existing one. + + Slots are given back if an update is given but the new appointment is smaller than the existing one. Args: - user_pk(:obj:`str`): the public key that identifies the user (33-bytes hex str). - n (:obj:`int`): the number of slots to fill. + user_id (:obj:`str`): the public key that identifies the user (33-bytes hex str). + uuid (:obj:`str`): the appointment uuid. + appointment (:obj:`ExtendedAppointment `: An appointment initialized with the provided data. + :obj:`Extended `: An appointment initialized with + the provided data. Raises: :obj:`InspectionFailed`: if any of the fields is wrong. @@ -67,12 +66,16 @@ class Inspector: raise InspectionFailed(errors.UNKNOWN_JSON_RPC_EXCEPTION, "unexpected error occurred") self.check_locator(appointment_data.get("locator")) - self.check_start_time(appointment_data.get("start_time"), block_height) - self.check_end_time(appointment_data.get("end_time"), appointment_data.get("start_time"), block_height) self.check_to_self_delay(appointment_data.get("to_self_delay")) self.check_blob(appointment_data.get("encrypted_blob")) - return Appointment.from_dict(appointment_data) + # Set user_id to None since we still don't know it, it'll be set by the API after querying the gatekeeper + return ExtendedAppointment( + appointment_data.get("locator"), + appointment_data.get("to_self_delay"), + appointment_data.get("encrypted_blob"), + user_id=None, + ) @staticmethod def check_locator(locator): @@ -102,87 +105,6 @@ class Inspector: elif not is_locator(locator): raise InspectionFailed(errors.APPOINTMENT_WRONG_FIELD_FORMAT, "wrong locator format ({})".format(locator)) - @staticmethod - def check_start_time(start_time, block_height): - """ - Checks if the provided ``start_time`` is correct. - - Start times must be ahead the current best chain tip. - - Args: - start_time (:obj:`int`): the block height at which the tower is requested to start watching for breaches. - block_height (:obj:`int`): the chain height. - - Raises: - :obj:`InspectionFailed`: if any of the fields is wrong. - """ - - if start_time is None: - raise InspectionFailed(errors.APPOINTMENT_EMPTY_FIELD, "empty start_time received") - - elif type(start_time) != int: - raise InspectionFailed( - errors.APPOINTMENT_WRONG_FIELD_TYPE, "wrong start_time data type ({})".format(type(start_time)) - ) - - elif start_time < block_height: - raise InspectionFailed(errors.APPOINTMENT_FIELD_TOO_SMALL, "start_time is in the past") - - elif start_time == block_height: - raise InspectionFailed( - errors.APPOINTMENT_FIELD_TOO_SMALL, - "start_time is too close to current height. Accepted times are: [current_height+1, current_height+6]", - ) - - elif start_time > block_height + 6: - raise InspectionFailed( - errors.APPOINTMENT_FIELD_TOO_BIG, - "start_time is too far in the future. Accepted start times are up to 6 blocks in the future", - ) - - @staticmethod - def check_end_time(end_time, start_time, block_height): - """ - Checks if the provided ``end_time`` is correct. - - End times must be ahead both the ``start_time`` and the current best chain tip. - - Args: - end_time (:obj:`int`): the block height at which the tower is requested to stop watching for breaches. - start_time (:obj:`int`): the block height at which the tower is requested to start watching for breaches. - block_height (:obj:`int`): the chain height. - - Raises: - :obj:`InspectionFailed`: if any of the fields is wrong. - """ - - # TODO: What's too close to the current height is not properly defined. Right now any appointment that ends in - # the future will be accepted (even if it's only one block away). - - if end_time is None: - raise InspectionFailed(errors.APPOINTMENT_EMPTY_FIELD, "empty end_time received") - - elif type(end_time) != int: - raise InspectionFailed( - errors.APPOINTMENT_WRONG_FIELD_TYPE, "wrong end_time data type ({})".format(type(end_time)) - ) - - elif end_time > block_height + BLOCKS_IN_A_MONTH: # 4320 = roughly a month in blocks - raise InspectionFailed( - errors.APPOINTMENT_FIELD_TOO_BIG, "end_time should be within the next month (<= current_height + 4320)" - ) - elif start_time > end_time: - raise InspectionFailed(errors.APPOINTMENT_FIELD_TOO_SMALL, "end_time is smaller than start_time") - - elif start_time == end_time: - raise InspectionFailed(errors.APPOINTMENT_FIELD_TOO_SMALL, "end_time is equal to start_time") - - elif block_height > end_time: - raise InspectionFailed(errors.APPOINTMENT_FIELD_TOO_SMALL, "end_time is in the past") - - elif block_height == end_time: - raise InspectionFailed(errors.APPOINTMENT_FIELD_TOO_SMALL, "end_time is too close to current height") - def check_to_self_delay(self, to_self_delay): """ Checks if the provided ``to_self_delay`` is correct. @@ -219,7 +141,6 @@ class Inspector: ), ) - # ToDo: #6-define-checks-encrypted-blob @staticmethod def check_blob(encrypted_blob): """ diff --git a/teos/responder.py b/teos/responder.py index 862733f..3953352 100644 --- a/teos/responder.py +++ b/teos/responder.py @@ -2,9 +2,11 @@ from queue import Queue from threading import Thread from teos import LOG_PREFIX -from common.logger import Logger from teos.cleaner import Cleaner +from common.logger import Logger +from common.constants import IRREVOCABLY_RESOLVED + CONFIRMATIONS_BEFORE_RETRY = 6 MIN_CONFIRMATIONS = 6 @@ -26,16 +28,15 @@ class TransactionTracker: dispute_txid (:obj:`str`): the id of the transaction that created the channel breach and triggered the penalty. penalty_txid (:obj:`str`): the id of the transaction that was encrypted under ``dispute_txid``. penalty_rawtx (:obj:`str`): the raw transaction that was broadcast as a consequence of the channel breach. - appointment_end (:obj:`int`): the block at which the tower will stop monitoring the blockchain for this - appointment. + user_id(:obj:`str`): the public key that identifies the user (33-bytes hex str). """ - def __init__(self, locator, dispute_txid, penalty_txid, penalty_rawtx, appointment_end): + def __init__(self, locator, dispute_txid, penalty_txid, penalty_rawtx, user_id): self.locator = locator self.dispute_txid = dispute_txid self.penalty_txid = penalty_txid self.penalty_rawtx = penalty_rawtx - self.appointment_end = appointment_end + self.user_id = user_id @classmethod def from_dict(cls, tx_tracker_data): @@ -60,13 +61,13 @@ class TransactionTracker: dispute_txid = tx_tracker_data.get("dispute_txid") penalty_txid = tx_tracker_data.get("penalty_txid") penalty_rawtx = tx_tracker_data.get("penalty_rawtx") - appointment_end = tx_tracker_data.get("appointment_end") + user_id = tx_tracker_data.get("user_id") - if any(v is None for v in [locator, dispute_txid, penalty_txid, penalty_rawtx, appointment_end]): + if any(v is None for v in [locator, dispute_txid, penalty_txid, penalty_rawtx, user_id]): raise ValueError("Wrong transaction tracker data, some fields are missing") else: - tx_tracker = cls(locator, dispute_txid, penalty_txid, penalty_rawtx, appointment_end) + tx_tracker = cls(locator, dispute_txid, penalty_txid, penalty_rawtx, user_id) return tx_tracker @@ -83,11 +84,21 @@ class TransactionTracker: "dispute_txid": self.dispute_txid, "penalty_txid": self.penalty_txid, "penalty_rawtx": self.penalty_rawtx, - "appointment_end": self.appointment_end, + "user_id": self.user_id, } return tx_tracker + def get_summary(self): + """ + Returns the summary of a tracker, consisting on the locator, the user_id and the penalty_txid. + + Returns: + :obj:`dict`: the appointment summary. + """ + + return {"locator": self.locator, "user_id": self.user_id, "penalty_txid": self.penalty_txid} + class Responder: """ @@ -104,7 +115,7 @@ class Responder: Attributes: trackers (:obj:`dict`): A dictionary containing the minimum information about the :obj:`TransactionTracker` - required by the :obj:`Responder` (``penalty_txid``, ``locator`` and ``end_time``). + required by the :obj:`Responder` (``penalty_txid``, ``locator`` and ``user_id``). Each entry is identified by a ``uuid``. tx_tracker_map (:obj:`dict`): A ``penalty_txid:uuid`` map used to allow the :obj:`Responder` to deal with several trackers triggered by the same ``penalty_txid``. @@ -115,19 +126,22 @@ class Responder: is populated by the :obj:`ChainMonitor `. db_manager (:obj:`AppointmentsDBM `): a ``AppointmentsDBM`` instance to interact with the database. + gatekeeper (:obj:`Gatekeeper `): a `Gatekeeper` instance in charge to control the + user access and subscription expiry. carrier (:obj:`Carrier `): a ``Carrier`` instance to send transactions to bitcoind. block_processor (:obj:`BlockProcessor `): a ``BlockProcessor`` instance to get data from bitcoind. last_known_block (:obj:`str`): the last block known by the ``Responder``. """ - def __init__(self, db_manager, carrier, block_processor): + def __init__(self, db_manager, gatekeeper, carrier, block_processor): self.trackers = dict() self.tx_tracker_map = dict() self.unconfirmed_txs = [] self.missed_confirmations = dict() self.block_queue = Queue() self.db_manager = db_manager + self.gatekeeper = gatekeeper self.carrier = carrier self.block_processor = block_processor self.last_known_block = db_manager.load_last_block_hash_responder() @@ -169,7 +183,7 @@ class Responder: return synchronized - def handle_breach(self, uuid, locator, dispute_txid, penalty_txid, penalty_rawtx, appointment_end, block_hash): + def handle_breach(self, uuid, locator, dispute_txid, penalty_txid, penalty_rawtx, user_id, block_hash): """ Requests the :obj:`Responder` to handle a channel breach. This is the entry point of the :obj:`Responder`. @@ -179,8 +193,7 @@ class Responder: dispute_txid (:obj:`str`): the id of the transaction that created the channel breach. penalty_txid (:obj:`str`): the id of the decrypted transaction included in the appointment. penalty_rawtx (:obj:`str`): the raw transaction to be broadcast in response of the breach. - appointment_end (:obj:`int`): the block height at which the :obj:`Responder` will stop monitoring for this - penalty transaction. + user_id(:obj:`str`): the public key that identifies the user (33-bytes hex str). block_hash (:obj:`str`): the block hash at which the breach was seen (used to see if we are on sync). Returns: @@ -191,9 +204,7 @@ class Responder: receipt = self.carrier.send_transaction(penalty_rawtx, penalty_txid) if receipt.delivered: - self.add_tracker( - uuid, locator, dispute_txid, penalty_txid, penalty_rawtx, appointment_end, receipt.confirmations - ) + self.add_tracker(uuid, locator, dispute_txid, penalty_txid, penalty_rawtx, user_id, receipt.confirmations) else: # TODO: Add the missing reasons (e.g. RPC_VERIFY_REJECTED) @@ -204,7 +215,7 @@ class Responder: return receipt - def add_tracker(self, uuid, locator, dispute_txid, penalty_txid, penalty_rawtx, appointment_end, confirmations=0): + def add_tracker(self, uuid, locator, dispute_txid, penalty_txid, penalty_rawtx, user_id, confirmations=0): """ Creates a :obj:`TransactionTracker` after successfully broadcasting a ``penalty_tx``. @@ -217,20 +228,15 @@ class Responder: dispute_txid (:obj:`str`): the id of the transaction that created the channel breach. penalty_txid (:obj:`str`): the id of the decrypted transaction included in the appointment. penalty_rawtx (:obj:`str`): the raw transaction to be broadcast. - appointment_end (:obj:`int`): the block height at which the :obj:`Responder` will stop monitoring for the - tracker. + user_id(:obj:`str`): the public key that identifies the user (33-bytes hex str). confirmations (:obj:`int`): the confirmation count of the ``penalty_tx``. In normal conditions it will be zero, but if the transaction is already on the blockchain this won't be the case. """ - tracker = TransactionTracker(locator, dispute_txid, penalty_txid, penalty_rawtx, appointment_end) + tracker = TransactionTracker(locator, dispute_txid, penalty_txid, penalty_rawtx, user_id) - # We only store the penalty_txid, locator and appointment_end in memory. The rest is dumped into the db. - self.trackers[uuid] = { - "penalty_txid": tracker.penalty_txid, - "locator": locator, - "appointment_end": appointment_end, - } + # We only store the penalty_txid, locator and user_id in memory. The rest is dumped into the db. + self.trackers[uuid] = tracker.get_summary() if penalty_txid in self.tx_tracker_map: self.tx_tracker_map[penalty_txid].append(uuid) @@ -244,9 +250,7 @@ class Responder: self.db_manager.store_responder_tracker(uuid, tracker.to_dict()) - logger.info( - "New tracker added", dispute_txid=dispute_txid, penalty_txid=penalty_txid, appointment_end=appointment_end - ) + logger.info("New tracker added", dispute_txid=dispute_txid, penalty_txid=penalty_txid, user_id=user_id) def do_watch(self): """ @@ -269,17 +273,28 @@ class Responder: if len(self.trackers) > 0 and block is not None: txids = block.get("tx") + completed_trackers = self.get_completed_trackers() + expired_trackers = self.get_expired_trackers(block.get("height")) + trackers_to_delete_gatekeeper = { + uuid: self.trackers[uuid].get("user_id") for uuid in completed_trackers + expired_trackers + } + if self.last_known_block == block.get("previousblockhash"): self.check_confirmations(txids) - - height = block.get("height") - completed_trackers = self.get_completed_trackers(height) - Cleaner.delete_completed_trackers( - completed_trackers, height, self.trackers, self.tx_tracker_map, self.db_manager + Cleaner.delete_trackers( + completed_trackers, block.get("height"), self.trackers, self.tx_tracker_map, self.db_manager ) + Cleaner.delete_trackers( + expired_trackers, + block.get("height"), + self.trackers, + self.tx_tracker_map, + self.db_manager, + expired=True, + ) + Cleaner.delete_gatekeeper_appointments(self.gatekeeper, trackers_to_delete_gatekeeper) - txs_to_rebroadcast = self.get_txs_to_rebroadcast() - self.rebroadcast(txs_to_rebroadcast) + self.rebroadcast(self.get_txs_to_rebroadcast()) # NOTCOVERED else: @@ -295,7 +310,7 @@ class Responder: # Clear the receipts issued in this block self.carrier.issued_receipts = {} - if len(self.trackers) != 0: + if len(self.trackers) == 0: logger.info("No more pending trackers") # Register the last processed block for the responder @@ -326,7 +341,6 @@ class Responder: for tx in self.unconfirmed_txs: if tx in self.missed_confirmations: self.missed_confirmations[tx] += 1 - else: self.missed_confirmations[tx] = 1 @@ -349,26 +363,24 @@ class Responder: return txs_to_rebroadcast - def get_completed_trackers(self, height): + def get_completed_trackers(self): """ - Gets the trackers that has already been fulfilled based on a given height (``end_time`` was reached with a - minimum confirmation count). - - Args: - height (:obj:`int`): the height of the last received block. + Gets the trackers that has already been fulfilled based on a given height (the justice transaction is + irrevocably resolved). Returns: - :obj:`dict`: a dict (``uuid:confirmations``) of the completed trackers. + :obj:`list`: a list of completed trackers uuids. """ - completed_trackers = {} + completed_trackers = [] + # FIXME: This is here for duplicated penalties, we should be able to get rid of it once we prevent duplicates in + # the responder. checked_txs = {} - for uuid, tracker_data in self.trackers.items(): - appointment_end = tracker_data.get("appointment_end") - penalty_txid = tracker_data.get("penalty_txid") - if appointment_end <= height and penalty_txid not in self.unconfirmed_txs: - + # Avoiding dictionary changed size during iteration + for uuid in list(self.trackers.keys()): + penalty_txid = self.trackers[uuid].get("penalty_txid") + if penalty_txid not in self.unconfirmed_txs: if penalty_txid not in checked_txs: tx = self.carrier.get_transaction(penalty_txid) else: @@ -378,16 +390,37 @@ class Responder: confirmations = tx.get("confirmations") checked_txs[penalty_txid] = tx - if confirmations is not None and confirmations >= MIN_CONFIRMATIONS: - # The end of the appointment has been reached - completed_trackers[uuid] = confirmations + if confirmations is not None and confirmations >= IRREVOCABLY_RESOLVED: + completed_trackers.append(uuid) return completed_trackers + def get_expired_trackers(self, height): + """ + Gets trackers than are expired due to the user subscription expiring. + + Only gets those trackers which penalty transaction is not going trough (probably because of low fees), the rest + will be eventually completed once they are irrevocably resolved. + + Args: + height (:obj:`int`): the height of the last received block. + + Returns: + :obj:`list`: a list of the expired trackers uuids. + """ + + expired_trackers = [ + uuid + for uuid in self.gatekeeper.get_expired_appointments(height) + if self.trackers[uuid].get("penalty_txid") in self.unconfirmed_txs + ] + + return expired_trackers + def rebroadcast(self, txs_to_rebroadcast): """ - Rebroadcasts a ``penalty_tx`` that has missed too many confirmations. In the current approach this would loop - forever if the transaction keeps not getting it. + Rebroadcasts a ``penalty_tx`` that has missed too many confirmations. In the current approach this will loop + until the tracker expires if the penalty transactions keeps getting rejected due to fees. Potentially, the fees could be bumped here if the transaction has some tower dedicated outputs (or allows it trough ``ANYONECANPAY`` or something similar). @@ -436,7 +469,8 @@ class Responder: """ - for uuid in self.trackers.keys(): + # Avoiding dictionary changed size during iteration + for uuid in list(self.trackers.keys()): tracker = TransactionTracker.from_dict(self.db_manager.load_responder_tracker(uuid)) # First we check if the dispute transaction is known (exists either in mempool or blockchain) @@ -465,7 +499,7 @@ class Responder: tracker.dispute_txid, tracker.penalty_txid, tracker.penalty_rawtx, - tracker.appointment_end, + tracker.user_id, block_hash, ) diff --git a/teos/teosd.py b/teos/teosd.py index d68d38f..f34a52b 100644 --- a/teos/teosd.py +++ b/teos/teosd.py @@ -3,7 +3,6 @@ from sys import argv, exit from getopt import getopt, GetoptError from signal import signal, SIGINT, SIGQUIT, SIGTERM -import common.cryptographer from common.logger import Logger from common.config_loader import ConfigLoader from common.cryptographer import Cryptographer @@ -25,7 +24,6 @@ from teos.tools import can_connect_to_bitcoind, in_correct_network from teos import LOG_PREFIX, DATA_DIR, DEFAULT_CONF, CONF_FILE_NAME logger = Logger(actor="Daemon", log_name_prefix=LOG_PREFIX) -common.cryptographer.logger = Logger(actor="Cryptographer", log_name_prefix=LOG_PREFIX) def handle_signals(signal_received, frame): @@ -52,13 +50,12 @@ def main(command_line_conf): setup_logging(config.get("LOG_FILE"), LOG_PREFIX) logger.info("Starting TEOS") - db_manager = AppointmentsDBM(config.get("APPOINTMENTS_DB_PATH")) bitcoind_connect_params = {k: v for k, v in config.items() if k.startswith("BTC")} bitcoind_feed_params = {k: v for k, v in config.items() if k.startswith("FEED")} if not can_connect_to_bitcoind(bitcoind_connect_params): - logger.error("Can't connect to bitcoind. Shutting down") + logger.error("Cannot connect to bitcoind. Shutting down") elif not in_correct_network(bitcoind_connect_params, config.get("BTC_NETWORK")): logger.error("bitcoind is running on a different network, check conf.py and bitcoin.conf. Shutting down") @@ -67,20 +64,28 @@ def main(command_line_conf): try: secret_key_der = Cryptographer.load_key_file(config.get("TEOS_SECRET_KEY")) if not secret_key_der: - raise IOError("TEOS private key can't be loaded") + raise IOError("TEOS private key cannot be loaded") + logger.info( + "tower_id = {}".format( + Cryptographer.get_compressed_pk(Cryptographer.load_private_key_der(secret_key_der).public_key) + ) + ) block_processor = BlockProcessor(bitcoind_connect_params) carrier = Carrier(bitcoind_connect_params) - responder = Responder(db_manager, carrier, block_processor) - watcher = Watcher( - db_manager, + gatekeeper = Gatekeeper( + UsersDBM(config.get("USERS_DB_PATH")), block_processor, - responder, - secret_key_der, - config.get("MAX_APPOINTMENTS"), + config.get("DEFAULT_SLOTS"), + config.get("DEFAULT_SUBSCRIPTION_DURATION"), config.get("EXPIRY_DELTA"), ) + db_manager = AppointmentsDBM(config.get("APPOINTMENTS_DB_PATH")) + responder = Responder(db_manager, gatekeeper, carrier, block_processor) + watcher = Watcher( + db_manager, gatekeeper, block_processor, responder, secret_key_der, config.get("MAX_APPOINTMENTS") + ) # Create the chain monitor and start monitoring the chain chain_monitor = ChainMonitor( @@ -153,9 +158,8 @@ def main(command_line_conf): # Fire the API and the ChainMonitor # FIXME: 92-block-data-during-bootstrap-db chain_monitor.monitor_chain() - gatekeeper = Gatekeeper(UsersDBM(config.get("USERS_DB_PATH")), config.get("DEFAULT_SLOTS")) inspector = Inspector(block_processor, config.get("MIN_TO_SELF_DELAY")) - API(config.get("API_BIND"), config.get("API_PORT"), inspector, watcher, gatekeeper).start() + API(config.get("API_BIND"), config.get("API_PORT"), inspector, watcher).start() except Exception as e: logger.error("An error occurred: {}. Shutting down".format(e)) exit(1) diff --git a/teos/tools.py b/teos/tools.py index 269a41d..36e9afd 100644 --- a/teos/tools.py +++ b/teos/tools.py @@ -8,7 +8,6 @@ Tools is a module with general methods that can used by different entities in th """ -# NOTCOVERED def bitcoin_cli(btc_connect_params): """ An ``http`` connection with ``bitcoind`` using the ``json-rpc`` interface. diff --git a/teos/users_dbm.py b/teos/users_dbm.py index d24b421..b3b85e3 100644 --- a/teos/users_dbm.py +++ b/teos/users_dbm.py @@ -2,9 +2,9 @@ import json import plyvel from teos import LOG_PREFIX -from teos.db_manager import DBManager from common.logger import Logger +from common.db_manager import DBManager from common.tools import is_compressed_pk logger = Logger(actor="UsersDBM", log_name_prefix=LOG_PREFIX) @@ -37,42 +37,42 @@ class UsersDBM(DBManager): raise e - def store_user(self, user_pk, user_data): + def store_user(self, user_id, user_data): """ Stores a user record to the database. ``user_pk`` is used as identifier. Args: - user_pk (:obj:`str`): a 33-byte hex-encoded string identifying the user. + user_id (:obj:`str`): a 33-byte hex-encoded string identifying the user. user_data (:obj:`dict`): the user associated data, as a dictionary. Returns: :obj:`bool`: True if the user was stored in the database, False otherwise. """ - if is_compressed_pk(user_pk): + if is_compressed_pk(user_id): try: - self.create_entry(user_pk, json.dumps(user_data)) - logger.info("Adding user to Gatekeeper's db", user_pk=user_pk) + self.create_entry(user_id, json.dumps(user_data)) + logger.info("Adding user to Gatekeeper's db", user_id=user_id) return True except json.JSONDecodeError: - logger.info("Could't add user to db. Wrong user data format", user_pk=user_pk, user_data=user_data) + logger.info("Could't add user to db. Wrong user data format", user_id=user_id, user_data=user_data) return False except TypeError: - logger.info("Could't add user to db", user_pk=user_pk, user_data=user_data) + logger.info("Could't add user to db", user_id=user_id, user_data=user_data) return False else: - logger.info("Could't add user to db. Wrong pk format", user_pk=user_pk, user_data=user_data) + logger.info("Could't add user to db. Wrong pk format", user_id=user_id, user_data=user_data) return False - def load_user(self, user_pk): + def load_user(self, user_id): """ Loads a user record from the database using the ``user_pk`` as identifier. Args: - user_pk (:obj:`str`): a 33-byte hex-encoded string identifying the user. + user_id (:obj:`str`): a 33-byte hex-encoded string identifying the user. Returns: :obj:`dict`: A dictionary containing the user data if the ``key`` is found. @@ -81,31 +81,31 @@ class UsersDBM(DBManager): """ try: - data = self.load_entry(user_pk) + data = self.load_entry(user_id) data = json.loads(data) except (TypeError, json.decoder.JSONDecodeError): data = None return data - def delete_user(self, user_pk): + def delete_user(self, user_id): """ Deletes a user record from the database. Args: - user_pk (:obj:`str`): a 33-byte hex-encoded string identifying the user. + user_id (:obj:`str`): a 33-byte hex-encoded string identifying the user. Returns: :obj:`bool`: True if the user was deleted from the database or it was non-existent, False otherwise. """ try: - self.delete_entry(user_pk) - logger.info("Deleting user from Gatekeeper's db", uuid=user_pk) + self.delete_entry(user_id) + logger.info("Deleting user from Gatekeeper's db", uuid=user_id) return True except TypeError: - logger.info("Cannot delete user from db, user key has wrong type", uuid=user_pk) + logger.info("Cannot delete user from db, user key has wrong type", uuid=user_id) return False def load_all_users(self): @@ -122,7 +122,7 @@ class UsersDBM(DBManager): for k, v in self.db.iterator(): # Get uuid and appointment_data from the db - user_pk = k.decode("utf-8") - data[user_pk] = json.loads(v) + user_id = k.decode("utf-8") + data[user_id] = json.loads(v) return data diff --git a/teos/watcher.py b/teos/watcher.py index f5e7cf8..ff5ff40 100644 --- a/teos/watcher.py +++ b/teos/watcher.py @@ -1,27 +1,32 @@ from queue import Queue from threading import Thread -import common.cryptographer from common.logger import Logger from common.tools import compute_locator -from common.appointment import Appointment +from common.exceptions import BasicException +from common.exceptions import EncryptionError from common.cryptographer import Cryptographer, hash_160 +from common.exceptions import InvalidParameter, SignatureError from teos import LOG_PREFIX from teos.cleaner import Cleaner +from teos.extended_appointment import ExtendedAppointment logger = Logger(actor="Watcher", log_name_prefix=LOG_PREFIX) -common.cryptographer.logger = Logger(actor="Cryptographer", log_name_prefix=LOG_PREFIX) + + +class AppointmentLimitReached(BasicException): + """Raised when the tower maximum appointment count has been reached""" class Watcher: """ The :class:`Watcher` is in charge of watching for channel breaches for the appointments accepted by the tower. - The :class:`Watcher` keeps track of the accepted appointments in ``appointments`` and, for new received block, + The :class:`Watcher` keeps track of the accepted appointments in ``appointments`` and, for new received blocks, checks if any breach has happened by comparing the txids with the appointment locators. If a breach is seen, the - :obj:`EncryptedBlob ` of the corresponding appointment is decrypted and the - data is passed to the :obj:`Responder `. + ``encrypted_blob`` of the corresponding appointment is decrypted and the data is passed to the + :obj:`Responder `. If an appointment reaches its end with no breach, the data is simply deleted. @@ -36,40 +41,40 @@ class Watcher: responder (:obj:`Responder `): a ``Responder`` instance. sk_der (:obj:`bytes`): a DER encoded private key used to sign appointment receipts (signaling acceptance). max_appointments (:obj:`int`): the maximum amount of appointments accepted by the ``Watcher`` at the same time. - expiry_delta (:obj:`int`): the additional time the ``Watcher`` will keep an expired appointment around. Attributes: - appointments (:obj:`dict`): a dictionary containing a summary of the appointments (:obj:`Appointment - ` instances) accepted by the tower (``locator``, ``end_time``, and ``size``). - It's populated trough ``add_appointment``. + appointments (:obj:`dict`): a dictionary containing a summary of the appointments (:obj:`ExtendedAppointment + ` instances) accepted by the tower (``locator`` and + ``user_id``). It's populated trough ``add_appointment``. locator_uuid_map (:obj:`dict`): a ``locator:uuid`` map used to allow the :obj:`Watcher` to deal with several appointments with the same ``locator``. block_queue (:obj:`Queue`): A queue used by the :obj:`Watcher` to receive block hashes from ``bitcoind``. It is populated by the :obj:`ChainMonitor `. db_manager (:obj:`AppointmentsDBM `): a ``AppointmentsDBM`` instance to interact with the database. + gatekeeper (:obj:`Gatekeeper `): a `Gatekeeper` instance in charge to control the + user access and subscription expiry. block_processor (:obj:`BlockProcessor `): a ``BlockProcessor`` instance to get block from bitcoind. responder (:obj:`Responder `): a ``Responder`` instance. signing_key (:mod:`PrivateKey`): a private key used to sign accepted appointments. max_appointments (:obj:`int`): the maximum amount of appointments accepted by the ``Watcher`` at the same time. - expiry_delta (:obj:`int`): the additional time the ``Watcher`` will keep an expired appointment around. last_known_block (:obj:`str`): the last block known by the ``Watcher``. Raises: - ValueError: if `teos_sk_file` is not found. + :obj:`InvalidKey `: if teos sk cannot be loaded. """ - def __init__(self, db_manager, block_processor, responder, sk_der, max_appointments, expiry_delta): + def __init__(self, db_manager, gatekeeper, block_processor, responder, sk_der, max_appointments): self.appointments = dict() self.locator_uuid_map = dict() self.block_queue = Queue() self.db_manager = db_manager + self.gatekeeper = gatekeeper self.block_processor = block_processor self.responder = responder self.max_appointments = max_appointments - self.expiry_delta = expiry_delta self.signing_key = Cryptographer.load_private_key_der(sk_der) self.last_known_block = db_manager.load_last_block_hash_watcher() @@ -81,30 +86,15 @@ class Watcher: return watcher_thread - def get_appointment_summary(self, uuid): - """ - Returns the summary of an appointment. The summary consists of the data kept in memory: - {locator, end_time, and size} - - Args: - uuid (:obj:`str`): a 16-byte hex string identifying the appointment. - - Returns: - :obj:`dict` or :obj:`None`: a dictionary with the appointment summary, or ``None`` if the appointment is not - found. - """ - return self.appointments.get(uuid) - - def add_appointment(self, appointment, user_pk): + def add_appointment(self, appointment, signature): """ Adds a new appointment to the ``appointments`` dictionary if ``max_appointments`` has not been reached. ``add_appointment`` is the entry point of the ``Watcher``. Upon receiving a new appointment it will start monitoring the blockchain (``do_watch``) until ``appointments`` is empty. - Once a breach is seen on the blockchain, the :obj:`Watcher` will decrypt the corresponding - :obj:`EncryptedBlob ` and pass the information to the - :obj:`Responder `. + Once a breach is seen on the blockchain, the :obj:`Watcher` will decrypt the corresponding ``encrypted_blob`` + and pass the information to the :obj:`Responder `. The tower may store multiple appointments with the same ``locator`` to avoid DoS attacks based on data rewriting. `locators`` should be derived from the ``dispute_txid``, but that task is performed by the user, and @@ -112,53 +102,65 @@ class Watcher: identified by ``uuid`` and stored in ``appointments`` and ``locator_uuid_map``. Args: - appointment (:obj:`Appointment `): the appointment to be added to the - :obj:`Watcher`. - user_pk(:obj:`str`): the public key that identifies the user who sent the appointment (33-bytes hex str). + appointment (:obj:`ExtendedAppointment `): the appointment to + be added to the :obj:`Watcher`. + signature (:obj:`str`): the user's appointment signature (hex-encoded). Returns: - :obj:`tuple`: A tuple signaling if the appointment has been added or not (based on ``max_appointments``). - The structure looks as follows: + :obj:`dict`: The tower response as a dict, containing: locator, signature, available_slots and + subscription_expiry. - - ``(True, signature)`` if the appointment has been accepted. - - ``(False, None)`` otherwise. + Raises: + :obj:`AppointmentLimitReached`: If the tower cannot hold more appointments (cap reached). + :obj:`AuthenticationFailure `: If the user cannot be authenticated. + :obj:`NotEnoughSlots `: If the user does not have enough available slots, + so the appointment is rejected. """ - if len(self.appointments) < self.max_appointments: + if len(self.appointments) >= self.max_appointments: + message = "Maximum appointments reached, appointment rejected" + logger.info(message, locator=appointment.locator) + raise AppointmentLimitReached(message) - # The uuids are generated as the RIPMED160(locator||user_pubkey), that way the tower does not need to know - # anything about the user from this point on (no need to store user_pk in the database). - # If an appointment is requested by the user the uuid can be recomputed and queried straightaway (no maps). - uuid = hash_160("{}{}".format(appointment.locator, user_pk)) - self.appointments[uuid] = { - "locator": appointment.locator, - "end_time": appointment.end_time, - "size": len(appointment.encrypted_blob.data), - } + user_id = self.gatekeeper.authenticate_user(appointment.serialize(), signature) + # The user_id needs to be added to the ExtendedAppointment once the former has been authenticated + appointment.user_id = user_id - if appointment.locator in self.locator_uuid_map: - # If the uuid is already in the map it means this is an update. - if uuid not in self.locator_uuid_map[appointment.locator]: - self.locator_uuid_map[appointment.locator].append(uuid) + # The uuids are generated as the RIPMED160(locator||user_pubkey). + # If an appointment is requested by the user the uuid can be recomputed and queried straightaway (no maps). + uuid = hash_160("{}{}".format(appointment.locator, user_id)) - else: - self.locator_uuid_map[appointment.locator] = [uuid] + # Add the appointment to the Gatekeeper + available_slots = self.gatekeeper.add_update_appointment(user_id, uuid, appointment) + self.appointments[uuid] = appointment.get_summary() - self.db_manager.store_watcher_appointment(uuid, appointment.to_dict()) - self.db_manager.create_append_locator_map(appointment.locator, uuid) + if appointment.locator in self.locator_uuid_map: + # If the uuid is already in the map it means this is an update. + if uuid not in self.locator_uuid_map[appointment.locator]: + self.locator_uuid_map[appointment.locator].append(uuid) + else: + # Otherwise two users have sent an appointment with the same locator, so we need to store both. + self.locator_uuid_map[appointment.locator] = [uuid] - appointment_added = True + self.db_manager.store_watcher_appointment(uuid, appointment.to_dict()) + self.db_manager.create_append_locator_map(appointment.locator, uuid) + + try: signature = Cryptographer.sign(appointment.serialize(), self.signing_key) - logger.info("New appointment accepted", locator=appointment.locator) - - else: - appointment_added = False + except (InvalidParameter, SignatureError): + # This should never happen since data is sanitized, just in case to avoid a crash + logger.error("Data couldn't be signed", appointment=appointment.to_dict()) signature = None - logger.info("Maximum appointments reached, appointment rejected", locator=appointment.locator) + logger.info("New appointment accepted", locator=appointment.locator) - return appointment_added, signature + return { + "locator": appointment.locator, + "signature": signature, + "available_slots": available_slots, + "subscription_expiry": self.gatekeeper.registered_users[user_id].subscription_expiry, + } def do_watch(self): """ @@ -181,17 +183,20 @@ class Watcher: if len(self.appointments) > 0 and block is not None: txids = block.get("tx") - expired_appointments = [ - uuid - for uuid, appointment_data in self.appointments.items() - if block["height"] > appointment_data.get("end_time") + self.expiry_delta - ] + expired_appointments = self.gatekeeper.get_expired_appointments(block["height"]) + # Make sure we only try to delete what is on the Watcher (some appointments may have been triggered) + expired_appointments = list(set(expired_appointments).intersection(self.appointments.keys())) + + # Keep track of the expired appointments before deleting them from memory + appointments_to_delete_gatekeeper = { + uuid: self.appointments[uuid].get("user_id") for uuid in expired_appointments + } Cleaner.delete_expired_appointments( expired_appointments, self.appointments, self.locator_uuid_map, self.db_manager ) - valid_breaches, invalid_breaches = self.filter_valid_breaches(self.get_breaches(txids)) + valid_breaches, invalid_breaches = self.filter_breaches(self.get_breaches(txids)) triggered_flags = [] appointments_to_delete = [] @@ -210,7 +215,7 @@ class Watcher: breach["dispute_txid"], breach["penalty_txid"], breach["penalty_rawtx"], - self.appointments[uuid].get("end_time"), + self.appointments[uuid].get("user_id"), block_hash, ) @@ -226,10 +231,18 @@ class Watcher: appointments_to_delete.extend(invalid_breaches) self.db_manager.batch_create_triggered_appointment_flag(triggered_flags) + # Update the dictionary with the completed appointments + appointments_to_delete_gatekeeper.update( + {uuid: self.appointments[uuid].get("user_id") for uuid in appointments_to_delete} + ) + Cleaner.delete_completed_appointments( appointments_to_delete, self.appointments, self.locator_uuid_map, self.db_manager ) + # Remove expired and completed appointments from the Gatekeeper + Cleaner.delete_gatekeeper_appointments(self.gatekeeper, appointments_to_delete_gatekeeper) + if len(self.appointments) != 0: logger.info("No more pending appointments") @@ -264,13 +277,12 @@ class Watcher: return breaches - def filter_valid_breaches(self, breaches): + def filter_breaches(self, breaches): """ - Filters what of the found breaches contain valid transaction data. + Filters the valid from the invalid channel breaches. - The :obj:`Watcher` cannot if a given :obj:`EncryptedBlob ` contains a valid - transaction until a breach if seen. Blobs that contain arbitrary data are dropped and not sent to the - :obj:`Responder `. + The :obj:`Watcher` cannot if a given ``encrypted_blob`` contains a valid transaction until a breach if seen. + Blobs that contain arbitrary data are dropped and not sent to the :obj:`Responder `. Args: breaches (:obj:`dict`): a dictionary containing channel breaches (``locator:txid``). @@ -290,20 +302,20 @@ class Watcher: for locator, dispute_txid in breaches.items(): for uuid in self.locator_uuid_map[locator]: - appointment = Appointment.from_dict(self.db_manager.load_watcher_appointment(uuid)) + appointment = ExtendedAppointment.from_dict(self.db_manager.load_watcher_appointment(uuid)) - if appointment.encrypted_blob.data in decrypted_blobs: - penalty_tx, penalty_rawtx = decrypted_blobs[appointment.encrypted_blob.data] + if appointment.encrypted_blob in decrypted_blobs: + penalty_tx, penalty_rawtx = decrypted_blobs[appointment.encrypted_blob] else: try: penalty_rawtx = Cryptographer.decrypt(appointment.encrypted_blob, dispute_txid) - except ValueError: + except EncryptionError: penalty_rawtx = None penalty_tx = self.block_processor.decode_raw_transaction(penalty_rawtx) - decrypted_blobs[appointment.encrypted_blob.data] = (penalty_tx, penalty_rawtx) + decrypted_blobs[appointment.encrypted_blob] = (penalty_tx, penalty_rawtx) if penalty_tx is not None: valid_breaches[uuid] = { diff --git a/test/cli/unit/test_teos_cli.py b/test/cli/unit/test_teos_cli.py index c0f9b42..b3cb0f2 100644 --- a/test/cli/unit/test_teos_cli.py +++ b/test/cli/unit/test_teos_cli.py @@ -3,31 +3,26 @@ import json import shutil import pytest import responses -from binascii import hexlify from coincurve import PrivateKey from requests.exceptions import ConnectionError, Timeout -from common.blob import Blob -import common.cryptographer -from common.logger import Logger from common.tools import compute_locator from common.appointment import Appointment from common.cryptographer import Cryptographer +from common.exceptions import InvalidParameter, InvalidKey import cli.teos_cli as teos_cli -from cli.exceptions import InvalidParameter, InvalidKey, TowerResponseError +from cli.exceptions import TowerResponseError from test.cli.unit.conftest import get_random_value_hex, get_config -common.cryptographer.logger = Logger(actor="Cryptographer", log_name_prefix=teos_cli.LOG_PREFIX) - config = get_config() # dummy keys for the tests -dummy_cli_sk = PrivateKey.from_int(1) -dummy_cli_compressed_pk = dummy_cli_sk.public_key.format(compressed=True) +dummy_user_sk = PrivateKey.from_int(1) +dummy_user_id = Cryptographer.get_compressed_pk(dummy_user_sk.public_key) dummy_teos_sk = PrivateKey.from_int(2) -dummy_teos_pk = dummy_teos_sk.public_key +dummy_teos_id = Cryptographer.get_compressed_pk(dummy_teos_sk.public_key) another_sk = PrivateKey.from_int(3) teos_url = "http://{}:{}".format(config.get("API_CONNECT"), config.get("API_PORT")) @@ -50,9 +45,7 @@ dummy_appointment_dict = { "start_time": dummy_appointment_data.get("start_time"), "end_time": dummy_appointment_data.get("end_time"), "to_self_delay": dummy_appointment_data.get("to_self_delay"), - "encrypted_blob": Cryptographer.encrypt( - Blob(dummy_appointment_data.get("tx")), dummy_appointment_data.get("tx_id") - ), + "encrypted_blob": Cryptographer.encrypt(dummy_appointment_data.get("tx"), dummy_appointment_data.get("tx_id")), } dummy_appointment = Appointment.from_dict(dummy_appointment_dict) @@ -66,14 +59,13 @@ def get_signature(message, sk): @responses.activate def test_register(): # Simulate a register response - compressed_pk_hex = hexlify(dummy_cli_compressed_pk).decode("utf-8") - response = {"public_key": compressed_pk_hex, "available_slots": 100} + response = {"public_key": dummy_user_id, "available_slots": 100} responses.add(responses.POST, register_endpoint, json=response, status=200) - result = teos_cli.register(compressed_pk_hex, teos_url) + result = teos_cli.register(dummy_user_id, teos_url) assert len(responses.calls) == 1 assert responses.calls[0].request.url == register_endpoint - assert result.get("public_key") == compressed_pk_hex and result.get("available_slots") == response.get( + assert result.get("public_key") == dummy_user_id and result.get("available_slots") == response.get( "available_slots" ) @@ -88,7 +80,7 @@ def test_add_appointment(): "available_slots": 100, } responses.add(responses.POST, add_appointment_endpoint, json=response, status=200) - result = teos_cli.add_appointment(dummy_appointment_data, dummy_cli_sk, dummy_teos_pk, teos_url) + result = teos_cli.add_appointment(dummy_appointment_data, dummy_user_sk, dummy_teos_id, teos_url) assert len(responses.calls) == 1 assert responses.calls[0].request.url == add_appointment_endpoint @@ -109,7 +101,7 @@ def test_add_appointment_with_invalid_signature(monkeypatch): responses.add(responses.POST, add_appointment_endpoint, json=response, status=200) with pytest.raises(TowerResponseError): - teos_cli.add_appointment(dummy_appointment_data, dummy_cli_sk, dummy_teos_pk, teos_url) + teos_cli.add_appointment(dummy_appointment_data, dummy_user_sk, dummy_teos_id, teos_url) @responses.activate @@ -122,7 +114,7 @@ def test_get_appointment(): } responses.add(responses.POST, get_appointment_endpoint, json=response, status=200) - result = teos_cli.get_appointment(dummy_appointment_dict.get("locator"), dummy_cli_sk, dummy_teos_pk, teos_url) + result = teos_cli.get_appointment(dummy_appointment_dict.get("locator"), dummy_user_sk, dummy_teos_id, teos_url) assert len(responses.calls) == 1 assert responses.calls[0].request.url == get_appointment_endpoint @@ -137,7 +129,7 @@ def test_get_appointment_err(): responses.add(responses.POST, get_appointment_endpoint, body=ConnectionError()) with pytest.raises(ConnectionError): - teos_cli.get_appointment(locator, dummy_cli_sk, dummy_teos_pk, teos_url) + teos_cli.get_appointment(locator, dummy_user_sk, dummy_teos_id, teos_url) def test_load_keys(): @@ -146,36 +138,32 @@ def test_load_keys(): public_key_file_path = "pk_test_file" empty_file_path = "empty_file" with open(private_key_file_path, "wb") as f: - f.write(dummy_cli_sk.to_der()) + f.write(dummy_user_sk.to_der()) with open(public_key_file_path, "wb") as f: - f.write(dummy_cli_compressed_pk) + f.write(dummy_user_sk.public_key.format(compressed=True)) with open(empty_file_path, "wb"): pass - # Now we can test the function passing the using this files (we'll use the same pk for both) - r = teos_cli.load_keys(public_key_file_path, private_key_file_path, public_key_file_path) + # Now we can test the function passing the using this files + r = teos_cli.load_keys(public_key_file_path, private_key_file_path) assert isinstance(r, tuple) assert len(r) == 3 # If any param does not match the expected, we should get an InvalidKey exception with pytest.raises(InvalidKey): - teos_cli.load_keys(None, private_key_file_path, public_key_file_path) + teos_cli.load_keys(None, private_key_file_path) with pytest.raises(InvalidKey): - teos_cli.load_keys(public_key_file_path, None, public_key_file_path) - with pytest.raises(InvalidKey): - teos_cli.load_keys(public_key_file_path, private_key_file_path, None) + teos_cli.load_keys(public_key_file_path, None) # The same should happen if we pass a public key where a private should be, for instance with pytest.raises(InvalidKey): - teos_cli.load_keys(private_key_file_path, public_key_file_path, private_key_file_path) + teos_cli.load_keys(private_key_file_path, public_key_file_path) # Same if any of the files is empty with pytest.raises(InvalidKey): - teos_cli.load_keys(empty_file_path, private_key_file_path, public_key_file_path) + teos_cli.load_keys(empty_file_path, private_key_file_path) with pytest.raises(InvalidKey): - teos_cli.load_keys(public_key_file_path, empty_file_path, public_key_file_path) - with pytest.raises(InvalidKey): - teos_cli.load_keys(public_key_file_path, private_key_file_path, empty_file_path) + teos_cli.load_keys(public_key_file_path, empty_file_path) # Remove the tmp files os.remove(private_key_file_path) diff --git a/test/common/unit/conftest.py b/test/common/unit/conftest.py index 3752ac0..aa9b0ed 100644 --- a/test/common/unit/conftest.py +++ b/test/common/unit/conftest.py @@ -1,5 +1,8 @@ import pytest import random +from shutil import rmtree + +from common.db_manager import DBManager @pytest.fixture(scope="session", autouse=True) @@ -7,6 +10,17 @@ def prng_seed(): random.seed(0) +@pytest.fixture(scope="module") +def db_manager(): + manager = DBManager("test_db") + # Add last know block for the Responder in the db + + yield manager + + manager.db.close() + rmtree("test_db") + + def get_random_value_hex(nbytes): pseudo_random_value = random.getrandbits(8 * nbytes) prv_hex = "{:x}".format(pseudo_random_value) diff --git a/test/common/unit/test_appointment.py b/test/common/unit/test_appointment.py index d5738a4..4a5162d 100644 --- a/test/common/unit/test_appointment.py +++ b/test/common/unit/test_appointment.py @@ -1,14 +1,13 @@ import struct import binascii +import pytest from pytest import fixture from common.appointment import Appointment -from common.encrypted_blob import EncryptedBlob +from common.constants import LOCATOR_LEN_BYTES from test.common.unit.conftest import get_random_value_hex -from common.constants import LOCATOR_LEN_BYTES - # Not much to test here, adding it for completeness @fixture @@ -31,42 +30,28 @@ def appointment_data(): def test_init_appointment(appointment_data): # The appointment has no checks whatsoever, since the inspector is the one taking care or that, and the only one # creating appointments. - # DISCUSS: whether this makes sense by design or checks should be ported from the inspector to the appointment - # 35-appointment-checks appointment = Appointment( - appointment_data["locator"], - appointment_data["start_time"], - appointment_data["end_time"], - appointment_data["to_self_delay"], - appointment_data["encrypted_blob"], + appointment_data["locator"], appointment_data["to_self_delay"], appointment_data["encrypted_blob"] ) assert ( appointment_data["locator"] == appointment.locator - and appointment_data["start_time"] == appointment.start_time - and appointment_data["end_time"] == appointment.end_time and appointment_data["to_self_delay"] == appointment.to_self_delay - and EncryptedBlob(appointment_data["encrypted_blob"]) == appointment.encrypted_blob + and appointment_data["encrypted_blob"] == appointment.encrypted_blob ) def test_to_dict(appointment_data): appointment = Appointment( - appointment_data["locator"], - appointment_data["start_time"], - appointment_data["end_time"], - appointment_data["to_self_delay"], - appointment_data["encrypted_blob"], + appointment_data["locator"], appointment_data["to_self_delay"], appointment_data["encrypted_blob"] ) dict_appointment = appointment.to_dict() assert ( appointment_data["locator"] == dict_appointment["locator"] - and appointment_data["start_time"] == dict_appointment["start_time"] - and appointment_data["end_time"] == dict_appointment["end_time"] and appointment_data["to_self_delay"] == dict_appointment["to_self_delay"] - and EncryptedBlob(appointment_data["encrypted_blob"]) == EncryptedBlob(dict_appointment["encrypted_blob"]) + and appointment_data["encrypted_blob"] == dict_appointment["encrypted_blob"] ) @@ -80,13 +65,9 @@ def test_from_dict(appointment_data): prev_val = appointment_data[key] appointment_data[key] = None - try: + with pytest.raises(ValueError, match="Wrong appointment data"): Appointment.from_dict(appointment_data) - assert False - - except ValueError: appointment_data[key] = prev_val - assert True def test_serialize(appointment_data): @@ -101,13 +82,9 @@ def test_serialize(appointment_data): assert isinstance(serialized_appointment, bytes) locator = serialized_appointment[:16] - start_time = serialized_appointment[16:20] - end_time = serialized_appointment[20:24] - to_self_delay = serialized_appointment[24:28] - encrypted_blob = serialized_appointment[28:] + to_self_delay = serialized_appointment[16:20] + encrypted_blob = serialized_appointment[20:] assert binascii.hexlify(locator).decode() == appointment.locator - assert struct.unpack(">I", start_time)[0] == appointment.start_time - assert struct.unpack(">I", end_time)[0] == appointment.end_time assert struct.unpack(">I", to_self_delay)[0] == appointment.to_self_delay - assert binascii.hexlify(encrypted_blob).decode() == appointment.encrypted_blob.data + assert binascii.hexlify(encrypted_blob).decode() == appointment.encrypted_blob diff --git a/test/common/unit/test_blob.py b/test/common/unit/test_blob.py deleted file mode 100644 index cf3e07f..0000000 --- a/test/common/unit/test_blob.py +++ /dev/null @@ -1,18 +0,0 @@ -from binascii import unhexlify - -from common.blob import Blob -from test.common.unit.conftest import get_random_value_hex - - -def test_init_blob(): - data = get_random_value_hex(64) - blob = Blob(data) - assert isinstance(blob, Blob) - - # Wrong data - try: - Blob(unhexlify(get_random_value_hex(64))) - assert False, "Able to create blob with wrong data" - - except ValueError: - assert True diff --git a/test/common/unit/test_cryptographer.py b/test/common/unit/test_cryptographer.py index bb60125..55c7454 100644 --- a/test/common/unit/test_cryptographer.py +++ b/test/common/unit/test_cryptographer.py @@ -1,19 +1,14 @@ import os -from binascii import unhexlify +import pytest +from coincurve import PrivateKey, PublicKey from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.asymmetric import ec from cryptography.hazmat.primitives import serialization -from coincurve import PrivateKey, PublicKey -import common.cryptographer -from common.blob import Blob -from common.logger import Logger +from common.exceptions import InvalidKey, InvalidParameter, EncryptionError, SignatureError from common.cryptographer import Cryptographer -from common.encrypted_blob import EncryptedBlob from test.common.unit.conftest import get_random_value_hex -common.cryptographer.logger = Logger(actor="Cryptographer", log_name_prefix="") - data = "6097cdf52309b1b2124efeed36bd34f46dc1c25ad23ac86f28380f746254f777" key = "b2e984a570f6f49bc38ace178e09147b0aa296cbb7c92eb01412f7e2d07b5659" encrypted_data = "8f31028097a8bf12a92e088caab5cf3fcddf0d35ed2b72c24b12269373efcdea04f9d2a820adafe830c20ff132d89810" @@ -33,103 +28,77 @@ def test_check_data_key_format_wrong_data(): data = get_random_value_hex(64)[:-1] key = get_random_value_hex(32) - try: + with pytest.raises(InvalidParameter, match="Odd-length"): Cryptographer.check_data_key_format(data, key) - assert False - - except ValueError as e: - assert "Odd-length" in str(e) def test_check_data_key_format_wrong_key(): data = get_random_value_hex(64) key = get_random_value_hex(33) - try: + with pytest.raises(InvalidParameter, match="32-byte hex value"): Cryptographer.check_data_key_format(data, key) - assert False - - except ValueError as e: - assert "32-byte hex" in str(e) def test_check_data_key_format(): data = get_random_value_hex(64) key = get_random_value_hex(32) - assert Cryptographer.check_data_key_format(data, key) is True + # Correct format does not raise anything + Cryptographer.check_data_key_format(data, key) def test_encrypt_odd_length_data(): - blob = Blob(get_random_value_hex(64)[-1]) + blob = get_random_value_hex(64)[-1] key = get_random_value_hex(32) - try: + with pytest.raises(InvalidParameter, match="Odd-length"): Cryptographer.encrypt(blob, key) - assert False - - except ValueError: - assert True def test_encrypt_wrong_key_size(): - blob = Blob(get_random_value_hex(64)) + blob = get_random_value_hex(64) key = get_random_value_hex(31) - try: + with pytest.raises(InvalidParameter, match="32-byte hex value"): Cryptographer.encrypt(blob, key) - assert False - - except ValueError: - assert True def test_encrypt(): - blob = Blob(data) - - assert Cryptographer.encrypt(blob, key) == encrypted_data + assert Cryptographer.encrypt(data, key) == encrypted_data def test_decrypt_invalid_tag(): random_key = get_random_value_hex(32) random_encrypted_data = get_random_value_hex(64) - random_encrypted_blob = EncryptedBlob(random_encrypted_data) + random_encrypted_blob = random_encrypted_data - # Trying to decrypt random data should result in an InvalidTag exception. Our decrypt function - # returns None - hex_tx = Cryptographer.decrypt(random_encrypted_blob, random_key) - assert hex_tx is None + # Trying to decrypt random data should result in an EncryptionError + with pytest.raises(EncryptionError, match="Cannot decrypt blob with the provided key"): + Cryptographer.decrypt(random_encrypted_blob, random_key) def test_decrypt_odd_length_data(): random_key = get_random_value_hex(32) random_encrypted_data_odd = get_random_value_hex(64)[:-1] - random_encrypted_blob_odd = EncryptedBlob(random_encrypted_data_odd) + random_encrypted_blob_odd = random_encrypted_data_odd - try: + with pytest.raises(InvalidParameter, match="Odd-length"): Cryptographer.decrypt(random_encrypted_blob_odd, random_key) - assert False - - except ValueError: - assert True def test_decrypt_wrong_key_size(): random_key = get_random_value_hex(31) random_encrypted_data_odd = get_random_value_hex(64) - random_encrypted_blob_odd = EncryptedBlob(random_encrypted_data_odd) + random_encrypted_blob_odd = random_encrypted_data_odd - try: + with pytest.raises(InvalidParameter, match="32-byte hex value"): Cryptographer.decrypt(random_encrypted_blob_odd, random_key) - assert False - - except ValueError: - assert True def test_decrypt(): # Valid data should run with no InvalidTag and verify - assert Cryptographer.decrypt(EncryptedBlob(encrypted_data), key) == data + assert Cryptographer.decrypt(encrypted_data, key) == data def test_load_key_file(): @@ -144,29 +113,33 @@ def test_load_key_file(): with open("key_test_file", "wb") as f: f.write(dummy_sk_der) - appt_data = Cryptographer.load_key_file("key_test_file") - assert appt_data - + Cryptographer.load_key_file("key_test_file") os.remove("key_test_file") # If file doesn't exist, function should return None - assert Cryptographer.load_key_file("nonexistent_file") is None + with pytest.raises(InvalidParameter, match="file not found"): + Cryptographer.load_key_file("nonexistent_file") - # If something that's not a file_path is passed as parameter the method should also return None - assert Cryptographer.load_key_file(0) is None and Cryptographer.load_key_file(None) is None + with pytest.raises(InvalidParameter, match="file path was expected"): + Cryptographer.load_key_file(0) + + with pytest.raises(InvalidParameter, match="file path was expected"): + Cryptographer.load_key_file(None) def test_load_private_key_der(): # load_private_key_der expects a byte encoded data. Any other should fail and return None for wtype in WRONG_TYPES: - assert Cryptographer.load_private_key_der(wtype) is None + with pytest.raises(InvalidKey, match="(wrong type)"): + Cryptographer.load_private_key_der(wtype) # On the other hand, any random formatter byte array would also fail (zeros for example) - assert Cryptographer.load_private_key_der(bytes(32)) is None + with pytest.raises(InvalidKey, match="(wrong size or format)"): + Cryptographer.load_private_key_der(bytes(32)) # A proper formatted key should load sk_der = generate_keypair()[0].to_der() - assert Cryptographer.load_private_key_der(sk_der) is not None + Cryptographer.load_private_key_der(sk_der) def test_sign(): @@ -189,13 +162,14 @@ def test_sign_ground_truth(): sig = Cryptographer.sign(message, sk) rpk = Cryptographer.recover_pk(message, sig) - assert Cryptographer.verify_rpk(PublicKey(unhexlify(c_lightning_rpk)), rpk) + assert c_lightning_rpk == Cryptographer.get_compressed_pk(rpk) def test_sign_wrong_sk(): # If a sk is not passed, sign will return None for wtype in WRONG_TYPES: - assert Cryptographer.sign(b"", wtype) is None + with pytest.raises(InvalidParameter, match="Wrong value passed as sk"): + Cryptographer.sign(b"", wtype) def test_recover_pk(): @@ -209,12 +183,13 @@ def test_recover_pk(): def test_recover_pk_invalid_sigrec(): - message = "Hey, it's me" + message = "Hey, it's me".encode("utf-8") signature = "ddbfb019e4d56155b4175066c2b615ab765d317ae7996d188b4a5fae4cc394adf98fef46034d0553149392219ca6d37dca9abdfa6366a8e54b28f19d3e5efa8a14b556205dc7f33a" # The given signature, when zbase32 decoded, has a fist byte with value lower than 31. # The first byte of the signature should be 31 + SigRec, so this should fail - assert Cryptographer.recover_pk(message, signature) is None + with pytest.raises(SignatureError, match="Wrong SigRec"): + Cryptographer.recover_pk(message, signature) def test_recover_pk_ground_truth(): @@ -225,9 +200,10 @@ def test_recover_pk_ground_truth(): rpk = Cryptographer.recover_pk(message, zsig) - assert Cryptographer.verify_rpk(PublicKey(unhexlify(org_pk)), rpk) + assert org_pk == Cryptographer.get_compressed_pk(rpk) +# FIXME: needs further testing def test_recover_pk_wrong_inputs(): str_message = "Test message" message = bytes(20) @@ -235,35 +211,18 @@ def test_recover_pk_wrong_inputs(): sig = bytes(20) # Wrong input type - assert Cryptographer.recover_pk(message, str_sig) is None - assert Cryptographer.recover_pk(str_message, str_sig) is None - assert Cryptographer.recover_pk(str_message, sig) is None - assert Cryptographer.recover_pk(message, str_sig) is None + with pytest.raises(InvalidParameter, match="Wrong value passed as zbase32_sig"): + Cryptographer.recover_pk(message, sig) - # Wrong input size or format - assert Cryptographer.recover_pk(message, sig) is None - assert Cryptographer.recover_pk(message, bytes(104)) is None + with pytest.raises(InvalidParameter, match="Wrong value passed as message"): + Cryptographer.recover_pk(str_message, str_sig) + with pytest.raises(InvalidParameter, match="Wrong value passed as message"): + Cryptographer.recover_pk(str_message, sig) -def test_verify_pk(): - sk, _ = generate_keypair() - message = b"Test message" - - zbase32_sig = Cryptographer.sign(message, sk) - rpk = Cryptographer.recover_pk(message, zbase32_sig) - - assert Cryptographer.verify_rpk(sk.public_key, rpk) - - -def test_verify_pk_wrong(): - sk, _ = generate_keypair() - sk2, _ = generate_keypair() - message = b"Test message" - - zbase32_sig = Cryptographer.sign(message, sk) - rpk = Cryptographer.recover_pk(message, zbase32_sig) - - assert not Cryptographer.verify_rpk(sk2.public_key, rpk) + # Wrong input size + with pytest.raises(SignatureError, match="Serialized signature must be 65 bytes long"): + Cryptographer.recover_pk(message, str_sig) def test_get_compressed_pk(): @@ -275,16 +234,16 @@ def test_get_compressed_pk(): def test_get_compressed_pk_wrong_key(): - # pk should be properly initialized. Initializing from int will case it to not be recoverable + # pk should be properly initialized. Initializing from int will cause it to not be recoverable pk = PublicKey(0) - compressed_pk = Cryptographer.get_compressed_pk(pk) - assert compressed_pk is None + with pytest.raises(InvalidKey, match="PublicKey has invalid initializer"): + Cryptographer.get_compressed_pk(pk) def test_get_compressed_pk_wrong_type(): # Passing a value that is not a PublicKey will make it to fail too pk = get_random_value_hex(33) - compressed_pk = Cryptographer.get_compressed_pk(pk) - assert compressed_pk is None + with pytest.raises(InvalidParameter, match="Wrong value passed as pk"): + Cryptographer.get_compressed_pk(pk) diff --git a/test/teos/unit/test_db_manager.py b/test/common/unit/test_db_manager.py similarity index 97% rename from test/teos/unit/test_db_manager.py rename to test/common/unit/test_db_manager.py index 2ee337d..bbcb030 100644 --- a/test/teos/unit/test_db_manager.py +++ b/test/common/unit/test_db_manager.py @@ -2,8 +2,8 @@ import os import shutil import pytest -from teos.db_manager import DBManager -from test.teos.unit.conftest import get_random_value_hex +from common.db_manager import DBManager +from test.common.unit.conftest import get_random_value_hex def open_create_db(db_path): diff --git a/test/common/unit/test_encrypted_blob.py b/test/common/unit/test_encrypted_blob.py deleted file mode 100644 index d67cfe1..0000000 --- a/test/common/unit/test_encrypted_blob.py +++ /dev/null @@ -1,16 +0,0 @@ -from common.encrypted_blob import EncryptedBlob -from test.common.unit.conftest import get_random_value_hex - - -def test_init_encrypted_blob(): - # No much to test here, basically that the object is properly created - data = get_random_value_hex(64) - assert EncryptedBlob(data).data == data - - -def test_equal(): - data = get_random_value_hex(64) - e_blob1 = EncryptedBlob(data) - e_blob2 = EncryptedBlob(data) - - assert e_blob1 == e_blob2 and id(e_blob1) != id(e_blob2) diff --git a/test/teos/e2e/conftest.py b/test/teos/e2e/conftest.py index 6d786cc..d4bb314 100644 --- a/test/teos/e2e/conftest.py +++ b/test/teos/e2e/conftest.py @@ -11,7 +11,6 @@ from common.config_loader import ConfigLoader getcontext().prec = 10 -END_TIME_DELTA = 10 @pytest.fixture(scope="session") @@ -123,16 +122,8 @@ def create_penalty_tx(bitcoin_cli, decoded_commitment_tx, destination=None): return signed_penalty_tx.get("hex") -def build_appointment_data(bitcoin_cli, commitment_tx_id, penalty_tx): - current_height = bitcoin_cli.getblockcount() - - appointment_data = { - "tx": penalty_tx, - "tx_id": commitment_tx_id, - "start_time": current_height + 1, - "end_time": current_height + 1 + END_TIME_DELTA, - "to_self_delay": 20, - } +def build_appointment_data(commitment_tx_id, penalty_tx): + appointment_data = {"tx": penalty_tx, "tx_id": commitment_tx_id, "to_self_delay": 20} return appointment_data diff --git a/test/teos/e2e/test_basic_e2e.py b/test/teos/e2e/test_basic_e2e.py index 59c4ee7..dbcfa0f 100644 --- a/test/teos/e2e/test_basic_e2e.py +++ b/test/teos/e2e/test_basic_e2e.py @@ -8,15 +8,16 @@ from coincurve import PrivateKey from cli.exceptions import TowerResponseError from cli import teos_cli, DATA_DIR, DEFAULT_CONF, CONF_FILE_NAME -import common.cryptographer -from common.blob import Blob -from common.logger import Logger from common.tools import compute_locator from common.appointment import Appointment from common.cryptographer import Cryptographer + +from teos import DEFAULT_CONF as TEOS_CONF +from teos import DATA_DIR as TEOS_DATA_DIR +from teos import CONF_FILE_NAME as TEOS_CONF_FILE_NAME from teos.utils.auth_proxy import JSONRPCException + from test.teos.e2e.conftest import ( - END_TIME_DELTA, build_appointment_data, get_random_value_hex, create_penalty_tx, @@ -26,7 +27,8 @@ from test.teos.e2e.conftest import ( ) cli_config = get_config(DATA_DIR, CONF_FILE_NAME, DEFAULT_CONF) -common.cryptographer.logger = Logger(actor="Cryptographer", log_name_prefix="") +teos_config = get_config(TEOS_DATA_DIR, TEOS_CONF_FILE_NAME, TEOS_CONF) + teos_base_endpoint = "http://{}:{}".format(cli_config.get("API_CONNECT"), cli_config.get("API_PORT")) teos_add_appointment_endpoint = "{}/add_appointment".format(teos_base_endpoint) @@ -36,9 +38,7 @@ teos_get_all_appointments_endpoint = "{}/get_all_appointments".format(teos_base_ # Run teosd teosd_process = run_teosd() -teos_pk, cli_sk, compressed_cli_pk = teos_cli.load_keys( - cli_config.get("TEOS_PUBLIC_KEY"), cli_config.get("CLI_PRIVATE_KEY"), cli_config.get("CLI_PUBLIC_KEY") -) +teos_id, user_sk, user_id = teos_cli.load_keys(cli_config.get("TEOS_PUBLIC_KEY"), cli_config.get("CLI_PRIVATE_KEY")) def broadcast_transaction_and_mine_block(bitcoin_cli, commitment_tx, addr): @@ -47,13 +47,13 @@ def broadcast_transaction_and_mine_block(bitcoin_cli, commitment_tx, addr): bitcoin_cli.generatetoaddress(1, addr) -def get_appointment_info(locator, sk=cli_sk): +def get_appointment_info(locator, sk=user_sk): sleep(1) # Let's add a bit of delay so the state can be updated - return teos_cli.get_appointment(locator, sk, teos_pk, teos_base_endpoint) + return teos_cli.get_appointment(locator, sk, teos_id, teos_base_endpoint) -def add_appointment(appointment_data, sk=cli_sk): - return teos_cli.add_appointment(appointment_data, sk, teos_pk, teos_base_endpoint) +def add_appointment(appointment_data, sk=user_sk): + return teos_cli.add_appointment(appointment_data, sk, teos_id, teos_base_endpoint) def get_all_appointments(): @@ -67,7 +67,7 @@ def test_commands_non_registered(bitcoin_cli): # Add appointment commitment_tx, penalty_tx = create_txs(bitcoin_cli) commitment_tx_id = bitcoin_cli.decoderawtransaction(commitment_tx).get("txid") - appointment_data = build_appointment_data(bitcoin_cli, commitment_tx_id, penalty_tx) + appointment_data = build_appointment_data(commitment_tx_id, penalty_tx) with pytest.raises(TowerResponseError): assert add_appointment(appointment_data) @@ -79,12 +79,12 @@ def test_commands_non_registered(bitcoin_cli): def test_commands_registered(bitcoin_cli): # Test registering and trying again - teos_cli.register(compressed_cli_pk, teos_base_endpoint) + teos_cli.register(user_id, teos_base_endpoint) # Add appointment commitment_tx, penalty_tx = create_txs(bitcoin_cli) commitment_tx_id = bitcoin_cli.decoderawtransaction(commitment_tx).get("txid") - appointment_data = build_appointment_data(bitcoin_cli, commitment_tx_id, penalty_tx) + appointment_data = build_appointment_data(commitment_tx_id, penalty_tx) appointment, available_slots = add_appointment(appointment_data) assert isinstance(appointment, Appointment) and isinstance(available_slots, str) @@ -97,14 +97,15 @@ def test_commands_registered(bitcoin_cli): def test_appointment_life_cycle(bitcoin_cli): # First of all we need to register - teos_cli.register(compressed_cli_pk, teos_base_endpoint) + response = teos_cli.register(user_id, teos_base_endpoint) + available_slots = response.get("available_slots") # After that we can build an appointment and send it to the tower commitment_tx, penalty_tx = create_txs(bitcoin_cli) commitment_tx_id = bitcoin_cli.decoderawtransaction(commitment_tx).get("txid") - appointment_data = build_appointment_data(bitcoin_cli, commitment_tx_id, penalty_tx) + appointment_data = build_appointment_data(commitment_tx_id, penalty_tx) locator = compute_locator(commitment_tx_id) - appointment, available_slots = add_appointment(appointment_data) + appointment, signature = add_appointment(appointment_data) # Get the information from the tower to check that it matches appointment_info = get_appointment_info(locator) @@ -141,14 +142,18 @@ def test_appointment_life_cycle(bitcoin_cli): # If the transaction is not found. assert False - # Now let's mine some blocks so the appointment reaches its end. - for _ in range(END_TIME_DELTA): - bitcoin_cli.generatetoaddress(1, new_addr) + # Now let's mine some blocks so the appointment reaches its end. We need 100 + EXPIRY_DELTA -1 + bitcoin_cli.generatetoaddress(100 + teos_config.get("EXPIRY_DELTA") - 1, new_addr) # The appointment is no longer in the tower with pytest.raises(TowerResponseError): get_appointment_info(locator) + # Check that the appointment is not in the Gatekeeper by checking the available slots (should have increase by 1) + # We can do so by topping up the subscription (FIXME: find a better way to check this). + response = teos_cli.register(user_id, teos_base_endpoint) + assert response.get("available_slots") == available_slots + teos_config.get("DEFAULT_SLOTS") + 1 + def test_multiple_appointments_life_cycle(bitcoin_cli): # Tests that get_all_appointments returns all the appointments the tower is storing at various stages in the @@ -160,7 +165,7 @@ def test_multiple_appointments_life_cycle(bitcoin_cli): # Create five appointments. for commitment_tx, penalty_tx in zip(commitment_txs, penalty_txs): commitment_tx_id = bitcoin_cli.decoderawtransaction(commitment_tx).get("txid") - appointment_data = build_appointment_data(bitcoin_cli, commitment_tx_id, penalty_tx) + appointment_data = build_appointment_data(commitment_tx_id, penalty_tx) locator = compute_locator(commitment_tx_id) appointment = { @@ -194,9 +199,8 @@ def test_multiple_appointments_life_cycle(bitcoin_cli): assert set(responder_locators) == set(breached_appointments) new_addr = bitcoin_cli.getnewaddress() - # Now let's mine some blocks so the appointment reaches its end. - for _ in range(END_TIME_DELTA): - bitcoin_cli.generatetoaddress(1, new_addr) + # Now let's mine some blocks so the appointment reaches its end. We need 100 + EXPIRY_DELTA -1 + bitcoin_cli.generatetoaddress(100 + teos_config.get("EXPIRY_DELTA") - 1, new_addr) # The appointment is no longer in the tower with pytest.raises(TowerResponseError): @@ -214,7 +218,7 @@ def test_appointment_malformed_penalty(bitcoin_cli): mod_penalty_tx = mod_penalty_tx.copy(tx_ins=[tx_in]) commitment_tx_id = bitcoin_cli.decoderawtransaction(commitment_tx).get("txid") - appointment_data = build_appointment_data(bitcoin_cli, commitment_tx_id, mod_penalty_tx.hex()) + appointment_data = build_appointment_data(commitment_tx_id, mod_penalty_tx.hex()) locator = compute_locator(commitment_tx_id) appointment, _ = add_appointment(appointment_data) @@ -240,15 +244,15 @@ def test_appointment_wrong_decryption_key(bitcoin_cli): commitment_tx, penalty_tx = create_txs(bitcoin_cli) # The appointment data is built using a random 32-byte value. - appointment_data = build_appointment_data(bitcoin_cli, get_random_value_hex(32), penalty_tx) + appointment_data = build_appointment_data(get_random_value_hex(32), penalty_tx) - # We can't use teos_cli.add_appointment here since it computes the locator internally, so let's do it manually. + # We cannot use teos_cli.add_appointment here since it computes the locator internally, so let's do it manually. # We will encrypt the blob using the random value and derive the locator from the commitment tx. appointment_data["locator"] = compute_locator(bitcoin_cli.decoderawtransaction(commitment_tx).get("txid")) - appointment_data["encrypted_blob"] = Cryptographer.encrypt(Blob(penalty_tx), get_random_value_hex(32)) + appointment_data["encrypted_blob"] = Cryptographer.encrypt(penalty_tx, get_random_value_hex(32)) appointment = Appointment.from_dict(appointment_data) - signature = Cryptographer.sign(appointment.serialize(), cli_sk) + signature = Cryptographer.sign(appointment.serialize(), user_sk) data = {"appointment": appointment.to_dict(), "signature": signature} # Send appointment to the server. @@ -258,7 +262,7 @@ def test_appointment_wrong_decryption_key(bitcoin_cli): # Check that the server has accepted the appointment signature = response_json.get("signature") rpk = Cryptographer.recover_pk(appointment.serialize(), signature) - assert Cryptographer.verify_rpk(teos_pk, rpk) is True + assert teos_id == Cryptographer.get_compressed_pk(rpk) assert response_json.get("locator") == appointment.locator # Trigger the appointment @@ -277,7 +281,7 @@ def test_two_identical_appointments(bitcoin_cli): commitment_tx, penalty_tx = create_txs(bitcoin_cli) commitment_tx_id = bitcoin_cli.decoderawtransaction(commitment_tx).get("txid") - appointment_data = build_appointment_data(bitcoin_cli, commitment_tx_id, penalty_tx) + appointment_data = build_appointment_data(commitment_tx_id, penalty_tx) locator = compute_locator(commitment_tx_id) # Send the appointment twice @@ -305,22 +309,22 @@ def test_two_identical_appointments(bitcoin_cli): # commitment_tx, penalty_tx = create_txs(bitcoin_cli) # commitment_tx_id = bitcoin_cli.decoderawtransaction(commitment_tx).get("txid") # -# appointment_data = build_appointment_data(bitcoin_cli, commitment_tx_id, penalty_tx) +# appointment_data = build_appointment_data(commitment_tx_id, penalty_tx) # locator = compute_locator(commitment_tx_id) # # # tmp keys from a different user -# tmp_sk = PrivateKey() -# tmp_compressed_pk = hexlify(tmp_sk.public_key.format(compressed=True)).decode("utf-8") -# teos_cli.register(tmp_compressed_pk, teos_base_endpoint) +# tmp_user_sk = PrivateKey() +# tmp_user_id = hexlify(tmp_user_sk.public_key.format(compressed=True)).decode("utf-8") +# teos_cli.register(tmp_user_id, teos_base_endpoint) # # # Send the appointment twice # assert add_appointment(appointment_data) is True -# assert add_appointment(appointment_data, sk=tmp_sk) is True +# assert add_appointment(appointment_data, sk=tmp_user_sk) is True # # # Check that we can get it from both users # appointment_info = get_appointment_info(locator) # assert appointment_info.get("status") == "being_watched" -# appointment_info = get_appointment_info(locator, sk=tmp_sk) +# appointment_info = get_appointment_info(locator, sk=tmp_user_sk) # assert appointment_info.get("status") == "being_watched" # # # Broadcast the commitment transaction and mine a block @@ -330,7 +334,7 @@ def test_two_identical_appointments(bitcoin_cli): # # The last appointment should have made it to the Responder # sleep(1) # appointment_info = get_appointment_info(locator) -# appointment_dup_info = get_appointment_info(locator, sk=tmp_sk) +# appointment_dup_info = get_appointment_info(locator, sk=tmp_user_sk) # # # One of the two request must be None, while the other must be valid # assert (appointment_info is None and appointment_dup_info is not None) or ( @@ -353,17 +357,17 @@ def test_two_appointment_same_locator_different_penalty_different_users(bitcoin_ new_addr = bitcoin_cli.getnewaddress() penalty_tx2 = create_penalty_tx(bitcoin_cli, decoded_commitment_tx, new_addr) - appointment1_data = build_appointment_data(bitcoin_cli, commitment_tx_id, penalty_tx1) - appointment2_data = build_appointment_data(bitcoin_cli, commitment_tx_id, penalty_tx2) + appointment1_data = build_appointment_data(commitment_tx_id, penalty_tx1) + appointment2_data = build_appointment_data(commitment_tx_id, penalty_tx2) locator = compute_locator(commitment_tx_id) # tmp keys for a different user - tmp_sk = PrivateKey() - tmp_compressed_pk = hexlify(tmp_sk.public_key.format(compressed=True)).decode("utf-8") - teos_cli.register(tmp_compressed_pk, teos_base_endpoint) + tmp_user_sk = PrivateKey() + tmp_user_id = hexlify(tmp_user_sk.public_key.format(compressed=True)).decode("utf-8") + teos_cli.register(tmp_user_id, teos_base_endpoint) appointment, _ = add_appointment(appointment1_data) - appointment_2, _ = add_appointment(appointment2_data, sk=tmp_sk) + appointment_2, _ = add_appointment(appointment2_data, sk=tmp_user_sk) # Broadcast the commitment transaction and mine a block new_addr = bitcoin_cli.getnewaddress() @@ -374,7 +378,7 @@ def test_two_appointment_same_locator_different_penalty_different_users(bitcoin_ appointment_info = None with pytest.raises(TowerResponseError): appointment_info = get_appointment_info(locator) - appointment2_info = get_appointment_info(locator, sk=tmp_sk) + appointment2_info = get_appointment_info(locator, sk=tmp_user_sk) if appointment_info is None: appointment_info = appointment2_info @@ -392,7 +396,7 @@ def test_appointment_shutdown_teos_trigger_back_online(bitcoin_cli): commitment_tx, penalty_tx = create_txs(bitcoin_cli) commitment_tx_id = bitcoin_cli.decoderawtransaction(commitment_tx).get("txid") - appointment_data = build_appointment_data(bitcoin_cli, commitment_tx_id, penalty_tx) + appointment_data = build_appointment_data(commitment_tx_id, penalty_tx) locator = compute_locator(commitment_tx_id) appointment, _ = add_appointment(appointment_data) @@ -425,7 +429,7 @@ def test_appointment_shutdown_teos_trigger_while_offline(bitcoin_cli): commitment_tx, penalty_tx = create_txs(bitcoin_cli) commitment_tx_id = bitcoin_cli.decoderawtransaction(commitment_tx).get("txid") - appointment_data = build_appointment_data(bitcoin_cli, commitment_tx_id, penalty_tx) + appointment_data = build_appointment_data(commitment_tx_id, penalty_tx) locator = compute_locator(commitment_tx_id) appointment, _ = add_appointment(appointment_data) diff --git a/test/teos/unit/conftest.py b/test/teos/unit/conftest.py index cbf3fc5..f49bbb3 100644 --- a/test/teos/unit/conftest.py +++ b/test/teos/unit/conftest.py @@ -10,26 +10,20 @@ from bitcoind_mock.bitcoind import BitcoindMock from bitcoind_mock.conf import BTC_RPC_HOST, BTC_RPC_PORT from bitcoind_mock.transaction import create_dummy_transaction +from teos import DEFAULT_CONF from teos.carrier import Carrier -from teos.tools import bitcoin_cli from teos.users_dbm import UsersDBM from teos.gatekeeper import Gatekeeper -from teos import LOG_PREFIX, DEFAULT_CONF from teos.responder import TransactionTracker from teos.block_processor import BlockProcessor from teos.appointments_dbm import AppointmentsDBM +from teos.extended_appointment import ExtendedAppointment -import common.cryptographer -from common.blob import Blob -from common.logger import Logger from common.tools import compute_locator -from common.appointment import Appointment from common.constants import LOCATOR_LEN_HEX from common.config_loader import ConfigLoader from common.cryptographer import Cryptographer -common.cryptographer.logger = Logger(actor="Cryptographer", log_name_prefix=LOG_PREFIX) - # Set params to connect to regtest for testing DEFAULT_CONF["BTC_RPC_PORT"]["value"] = 18443 DEFAULT_CONF["BTC_NETWORK"]["value"] = "regtest" @@ -86,8 +80,14 @@ def block_processor(): @pytest.fixture(scope="module") -def gatekeeper(user_db_manager): - return Gatekeeper(user_db_manager, get_config().get("DEFAULT_SLOTS")) +def gatekeeper(user_db_manager, block_processor): + return Gatekeeper( + user_db_manager, + block_processor, + get_config().get("DEFAULT_SLOTS"), + get_config().get("DEFAULT_SUBSCRIPTION_DURATION"), + get_config().get("EXPIRY_DELTA"), + ) def generate_keypair(): @@ -103,11 +103,21 @@ def get_random_value_hex(nbytes): return prv_hex.zfill(2 * nbytes) -def generate_block(): +def generate_block_w_delay(): requests.post(url="http://{}:{}/generate".format(BTC_RPC_HOST, BTC_RPC_PORT), timeout=5) sleep(0.5) +def generate_blocks_w_delay(n): + for _ in range(n): + generate_block() + sleep(0.2) + + +def generate_block(): + requests.post(url="http://{}:{}/generate".format(BTC_RPC_HOST, BTC_RPC_PORT), timeout=5) + + def generate_blocks(n): for _ in range(n): generate_block() @@ -118,39 +128,23 @@ def fork(block_hash): requests.post(fork_endpoint, json={"parent": block_hash}) -def generate_dummy_appointment(real_height=True, start_time_offset=5, end_time_offset=30): - if real_height: - current_height = bitcoin_cli(bitcoind_connect_params).getblockcount() - - else: - current_height = 10 - +def generate_dummy_appointment(): dispute_tx = create_dummy_transaction() dispute_txid = dispute_tx.tx_id.hex() penalty_tx = create_dummy_transaction(dispute_txid) - dummy_appointment_data = { - "tx": penalty_tx.hex(), - "tx_id": dispute_txid, - "start_time": current_height + start_time_offset, - "end_time": current_height + end_time_offset, - "to_self_delay": 20, - } - locator = compute_locator(dispute_txid) - blob = Blob(dummy_appointment_data.get("tx")) - - encrypted_blob = Cryptographer.encrypt(blob, dummy_appointment_data.get("tx_id")) + dummy_appointment_data = {"tx": penalty_tx.hex(), "tx_id": dispute_txid, "to_self_delay": 20} + encrypted_blob = Cryptographer.encrypt(dummy_appointment_data.get("tx"), dummy_appointment_data.get("tx_id")) appointment_data = { "locator": locator, - "start_time": dummy_appointment_data.get("start_time"), - "end_time": dummy_appointment_data.get("end_time"), "to_self_delay": dummy_appointment_data.get("to_self_delay"), "encrypted_blob": encrypted_blob, + "user_id": get_random_value_hex(16), } - return Appointment.from_dict(appointment_data), dispute_tx.hex() + return ExtendedAppointment.from_dict(appointment_data), dispute_tx.hex() def generate_dummy_tracker(): @@ -164,7 +158,7 @@ def generate_dummy_tracker(): dispute_txid=dispute_txid, penalty_txid=penalty_txid, penalty_rawtx=penalty_rawtx, - appointment_end=100, + user_id=get_random_value_hex(16), ) return TransactionTracker.from_dict(tracker_data) diff --git a/test/teos/unit/test_api.py b/test/teos/unit/test_api.py index db59a3b..daac059 100644 --- a/test/teos/unit/test_api.py +++ b/test/teos/unit/test_api.py @@ -3,9 +3,10 @@ from shutil import rmtree from binascii import hexlify from teos.api import API -import teos.errors as errors +import common.errors as errors from teos.watcher import Watcher from teos.inspector import Inspector +from teos.gatekeeper import UserInfo from teos.appointments_dbm import AppointmentsDBM from teos.responder import Responder, TransactionTracker @@ -39,8 +40,8 @@ appointments = {} locator_dispute_tx_map = {} -client_sk, client_pk = generate_keypair() -compressed_client_pk = hexlify(client_pk.format(compressed=True)).decode("utf-8") +user_sk, user_pk = generate_keypair() +user_id = hexlify(user_pk.format(compressed=True)).decode("utf-8") @pytest.fixture() @@ -58,10 +59,10 @@ def get_all_db_manager(): def api(db_manager, carrier, block_processor, gatekeeper, run_bitcoind): sk, pk = generate_keypair() - responder = Responder(db_manager, carrier, block_processor) - watcher = Watcher(db_manager, block_processor, responder, sk.to_der(), MAX_APPOINTMENTS, config.get("EXPIRY_DELTA")) + responder = Responder(db_manager, gatekeeper, carrier, block_processor) + watcher = Watcher(db_manager, gatekeeper, block_processor, responder, sk.to_der(), MAX_APPOINTMENTS) inspector = Inspector(block_processor, config.get("MIN_TO_SELF_DELAY")) - api = API(config.get("API_HOST"), config.get("API_PORT"), inspector, watcher, gatekeeper) + api = API(config.get("API_HOST"), config.get("API_PORT"), inspector, watcher) return api @@ -85,47 +86,52 @@ def appointment(): return appointment -def add_appointment(client, appointment_data, user_pk): +def add_appointment(client, appointment_data, user_id): r = client.post(add_appointment_endpoint, json=appointment_data) if r.status_code == HTTP_OK: locator = appointment_data.get("appointment").get("locator") - uuid = hash_160("{}{}".format(locator, user_pk)) + uuid = hash_160("{}{}".format(locator, user_id)) appointments[uuid] = appointment_data["appointment"] return r -def test_register(client): - data = {"public_key": compressed_client_pk} +def test_register(client, api): + current_height = api.watcher.block_processor.get_block_count() + data = {"public_key": user_id} r = client.post(register_endpoint, json=data) assert r.status_code == HTTP_OK - assert r.json.get("public_key") == compressed_client_pk + assert r.json.get("public_key") == user_id assert r.json.get("available_slots") == config.get("DEFAULT_SLOTS") + assert r.json.get("subscription_expiry") == current_height + config.get("DEFAULT_SUBSCRIPTION_DURATION") -def test_register_top_up(client): - # Calling register more than once will give us DEFAULT_SLOTS * number_of_calls slots +def test_register_top_up(client, api): + # Calling register more than once will give us DEFAULT_SLOTS * number_of_calls slots. + # It will also refresh the expiry. temp_sk, tmp_pk = generate_keypair() - tmp_pk_hex = hexlify(tmp_pk.format(compressed=True)).decode("utf-8") + tmp_user_id = hexlify(tmp_pk.format(compressed=True)).decode("utf-8") + current_height = api.watcher.block_processor.get_block_count() - data = {"public_key": tmp_pk_hex} + data = {"public_key": tmp_user_id} for i in range(10): r = client.post(register_endpoint, json=data) assert r.status_code == HTTP_OK - assert r.json.get("public_key") == tmp_pk_hex + assert r.json.get("public_key") == tmp_user_id assert r.json.get("available_slots") == config.get("DEFAULT_SLOTS") * (i + 1) + assert r.json.get("subscription_expiry") == current_height + config.get("DEFAULT_SUBSCRIPTION_DURATION") def test_register_no_client_pk(client): - data = {"public_key": compressed_client_pk + compressed_client_pk} + data = {} r = client.post(register_endpoint, json=data) assert r.status_code == HTTP_BAD_REQUEST def test_register_wrong_client_pk(client): - data = {} + data = {"public_key": user_id + user_id} r = client.post(register_endpoint, json=data) assert r.status_code == HTTP_BAD_REQUEST @@ -133,60 +139,62 @@ def test_register_wrong_client_pk(client): def test_register_no_json(client): r = client.post(register_endpoint, data="random_message") assert r.status_code == HTTP_BAD_REQUEST + assert errors.INVALID_REQUEST_FORMAT == r.json.get("error_code") def test_register_json_no_inner_dict(client): r = client.post(register_endpoint, json="random_message") assert r.status_code == HTTP_BAD_REQUEST + assert errors.INVALID_REQUEST_FORMAT == r.json.get("error_code") def test_add_appointment(api, client, appointment): - # Simulate the user registration - api.gatekeeper.registered_users[compressed_client_pk] = {"available_slots": 1} + # Simulate the user registration (end time does not matter here) + api.watcher.gatekeeper.registered_users[user_id] = UserInfo(available_slots=1, subscription_expiry=0) # Properly formatted appointment - appointment_signature = Cryptographer.sign(appointment.serialize(), client_sk) - r = add_appointment( - client, {"appointment": appointment.to_dict(), "signature": appointment_signature}, compressed_client_pk - ) + appointment_signature = Cryptographer.sign(appointment.serialize(), user_sk) + r = add_appointment(client, {"appointment": appointment.to_dict(), "signature": appointment_signature}, user_id) assert r.status_code == HTTP_OK assert r.json.get("available_slots") == 0 def test_add_appointment_no_json(api, client, appointment): - # Simulate the user registration - api.gatekeeper.registered_users[compressed_client_pk] = {"available_slots": 1} + # Simulate the user registration (end time does not matter here) + api.watcher.gatekeeper.registered_users[user_id] = UserInfo(available_slots=1, subscription_expiry=0) - # Properly formatted appointment + # No JSON data r = client.post(add_appointment_endpoint, data="random_message") assert r.status_code == HTTP_BAD_REQUEST + assert "Request is not json encoded" in r.json.get("error") + assert errors.INVALID_REQUEST_FORMAT == r.json.get("error_code") def test_add_appointment_json_no_inner_dict(api, client, appointment): - # Simulate the user registration - api.gatekeeper.registered_users[compressed_client_pk] = {"available_slots": 1} + # Simulate the user registration (end time does not matter here) + api.watcher.gatekeeper.registered_users[user_id] = UserInfo(available_slots=1, subscription_expiry=0) - # Properly formatted appointment + # JSON data with no inner dict (invalid data format) r = client.post(add_appointment_endpoint, json="random_message") assert r.status_code == HTTP_BAD_REQUEST + assert "Invalid request content" in r.json.get("error") + assert errors.INVALID_REQUEST_FORMAT == r.json.get("error_code") def test_add_appointment_wrong(api, client, appointment): - # Simulate the user registration - api.gatekeeper.registered_users[compressed_client_pk] = 1 + # Simulate the user registration (end time does not matter here) + api.watcher.gatekeeper.registered_users[user_id] = UserInfo(available_slots=1, subscription_expiry=0) - # Incorrect appointment + # Incorrect appointment (properly formatted, wrong data) appointment.to_self_delay = 0 - appointment_signature = Cryptographer.sign(appointment.serialize(), client_sk) - r = add_appointment( - client, {"appointment": appointment.to_dict(), "signature": appointment_signature}, compressed_client_pk - ) + appointment_signature = Cryptographer.sign(appointment.serialize(), user_sk) + r = add_appointment(client, {"appointment": appointment.to_dict(), "signature": appointment_signature}, user_id) assert r.status_code == HTTP_BAD_REQUEST - assert "Error {}:".format(errors.APPOINTMENT_FIELD_TOO_SMALL) in r.json.get("error") + assert errors.APPOINTMENT_FIELD_TOO_SMALL == r.json.get("error_code") def test_add_appointment_not_registered(api, client, appointment): - # Properly formatted appointment + # Properly formatted appointment, user is not registered tmp_sk, tmp_pk = generate_keypair() tmp_compressed_pk = hexlify(tmp_pk.format(compressed=True)).decode("utf-8") @@ -195,49 +203,43 @@ def test_add_appointment_not_registered(api, client, appointment): client, {"appointment": appointment.to_dict(), "signature": appointment_signature}, tmp_compressed_pk ) assert r.status_code == HTTP_BAD_REQUEST - assert "Error {}:".format(errors.APPOINTMENT_INVALID_SIGNATURE_OR_INSUFFICIENT_SLOTS) in r.json.get("error") + assert errors.APPOINTMENT_INVALID_SIGNATURE_OR_INSUFFICIENT_SLOTS == r.json.get("error_code") def test_add_appointment_registered_no_free_slots(api, client, appointment): - # Empty the user slots - api.gatekeeper.registered_users[compressed_client_pk] = {"available_slots": 0} + # Empty the user slots (end time does not matter here) + api.watcher.gatekeeper.registered_users[user_id] = UserInfo(available_slots=0, subscription_expiry=0) - # Properly formatted appointment - appointment_signature = Cryptographer.sign(appointment.serialize(), client_sk) - r = add_appointment( - client, {"appointment": appointment.to_dict(), "signature": appointment_signature}, compressed_client_pk - ) + # Properly formatted appointment, user has no available slots + appointment_signature = Cryptographer.sign(appointment.serialize(), user_sk) + r = add_appointment(client, {"appointment": appointment.to_dict(), "signature": appointment_signature}, user_id) assert r.status_code == HTTP_BAD_REQUEST - assert "Error {}:".format(errors.APPOINTMENT_INVALID_SIGNATURE_OR_INSUFFICIENT_SLOTS) in r.json.get("error") + assert errors.APPOINTMENT_INVALID_SIGNATURE_OR_INSUFFICIENT_SLOTS == r.json.get("error_code") def test_add_appointment_registered_not_enough_free_slots(api, client, appointment): - # Give some slots to the user - api.gatekeeper.registered_users[compressed_client_pk] = 1 + # Give some slots to the user (end time does not matter here) + api.watcher.gatekeeper.registered_users[user_id] = UserInfo(available_slots=1, subscription_expiry=0) - # Properly formatted appointment - appointment_signature = Cryptographer.sign(appointment.serialize(), client_sk) + # Properly formatted appointment, user has not enough slots + appointment_signature = Cryptographer.sign(appointment.serialize(), user_sk) # Let's create a big blob - appointment.encrypted_blob.data = TWO_SLOTS_BLOTS + appointment.encrypted_blob = TWO_SLOTS_BLOTS - r = add_appointment( - client, {"appointment": appointment.to_dict(), "signature": appointment_signature}, compressed_client_pk - ) + r = add_appointment(client, {"appointment": appointment.to_dict(), "signature": appointment_signature}, user_id) assert r.status_code == HTTP_BAD_REQUEST - assert "Error {}:".format(errors.APPOINTMENT_INVALID_SIGNATURE_OR_INSUFFICIENT_SLOTS) in r.json.get("error") + assert errors.APPOINTMENT_INVALID_SIGNATURE_OR_INSUFFICIENT_SLOTS == r.json.get("error_code") def test_add_appointment_multiple_times_same_user(api, client, appointment, n=MULTIPLE_APPOINTMENTS): - # Multiple appointments with the same locator should be valid and counted as updates - appointment_signature = Cryptographer.sign(appointment.serialize(), client_sk) + # Multiple appointments with the same locator should be valid and count as updates + appointment_signature = Cryptographer.sign(appointment.serialize(), user_sk) - # Simulate registering enough slots - api.gatekeeper.registered_users[compressed_client_pk] = {"available_slots": n} + # Simulate registering enough slots (end time does not matter here) + api.watcher.gatekeeper.registered_users[user_id] = UserInfo(available_slots=n, subscription_expiry=0) for _ in range(n): - r = add_appointment( - client, {"appointment": appointment.to_dict(), "signature": appointment_signature}, compressed_client_pk - ) + r = add_appointment(client, {"appointment": appointment.to_dict(), "signature": appointment_signature}, user_id) assert r.status_code == HTTP_OK assert r.json.get("available_slots") == n - 1 @@ -246,6 +248,7 @@ def test_add_appointment_multiple_times_same_user(api, client, appointment, n=MU def test_add_appointment_multiple_times_different_users(api, client, appointment, n=MULTIPLE_APPOINTMENTS): + # If the same appointment comes from different users, all are kept # Create user keys and appointment signatures user_keys = [generate_keypair() for _ in range(n)] signatures = [Cryptographer.sign(appointment.serialize(), key[0]) for key in user_keys] @@ -254,7 +257,7 @@ def test_add_appointment_multiple_times_different_users(api, client, appointment # Add one slot per public key for pair in user_keys: tmp_compressed_pk = hexlify(pair[1].format(compressed=True)).decode("utf-8") - api.gatekeeper.registered_users[tmp_compressed_pk] = {"available_slots": 2} + api.watcher.gatekeeper.registered_users[tmp_compressed_pk] = UserInfo(available_slots=2, subscription_expiry=0) # Send the appointments for compressed_pk, signature in zip(compressed_pks, signatures): @@ -268,77 +271,61 @@ def test_add_appointment_multiple_times_different_users(api, client, appointment def test_add_appointment_update_same_size(api, client, appointment): # Update an appointment by one of the same size and check that no additional slots are filled - api.gatekeeper.registered_users[compressed_client_pk] = {"available_slots": 1} + api.watcher.gatekeeper.registered_users[user_id] = UserInfo(available_slots=1, subscription_expiry=0) - appointment_signature = Cryptographer.sign(appointment.serialize(), client_sk) - # # Since we will replace the appointment, we won't added to appointments - r = add_appointment( - client, {"appointment": appointment.to_dict(), "signature": appointment_signature}, compressed_client_pk - ) + appointment_signature = Cryptographer.sign(appointment.serialize(), user_sk) + r = add_appointment(client, {"appointment": appointment.to_dict(), "signature": appointment_signature}, user_id) assert r.status_code == HTTP_OK and r.json.get("available_slots") == 0 # The user has no additional slots, but it should be able to update # Let's just reverse the encrypted blob for example - appointment.encrypted_blob.data = appointment.encrypted_blob.data[::-1] - appointment_signature = Cryptographer.sign(appointment.serialize(), client_sk) - r = add_appointment( - client, {"appointment": appointment.to_dict(), "signature": appointment_signature}, compressed_client_pk - ) + appointment.encrypted_blob = appointment.encrypted_blob[::-1] + appointment_signature = Cryptographer.sign(appointment.serialize(), user_sk) + r = add_appointment(client, {"appointment": appointment.to_dict(), "signature": appointment_signature}, user_id) assert r.status_code == HTTP_OK and r.json.get("available_slots") == 0 def test_add_appointment_update_bigger(api, client, appointment): # Update an appointment by one bigger, and check additional slots are filled - api.gatekeeper.registered_users[compressed_client_pk] = {"available_slots": 2} + api.watcher.gatekeeper.registered_users[user_id] = UserInfo(available_slots=2, subscription_expiry=0) - appointment_signature = Cryptographer.sign(appointment.serialize(), client_sk) - r = add_appointment( - client, {"appointment": appointment.to_dict(), "signature": appointment_signature}, compressed_client_pk - ) + appointment_signature = Cryptographer.sign(appointment.serialize(), user_sk) + r = add_appointment(client, {"appointment": appointment.to_dict(), "signature": appointment_signature}, user_id) assert r.status_code == HTTP_OK and r.json.get("available_slots") == 1 # The user has one slot, so it should be able to update as long as it only takes 1 additional slot - appointment.encrypted_blob.data = TWO_SLOTS_BLOTS - appointment_signature = Cryptographer.sign(appointment.serialize(), client_sk) - r = add_appointment( - client, {"appointment": appointment.to_dict(), "signature": appointment_signature}, compressed_client_pk - ) + appointment.encrypted_blob = TWO_SLOTS_BLOTS + appointment_signature = Cryptographer.sign(appointment.serialize(), user_sk) + r = add_appointment(client, {"appointment": appointment.to_dict(), "signature": appointment_signature}, user_id) assert r.status_code == HTTP_OK and r.json.get("available_slots") == 0 # Check that it'll fail if no enough slots are available # Double the size from before - appointment.encrypted_blob.data = TWO_SLOTS_BLOTS + TWO_SLOTS_BLOTS - appointment_signature = Cryptographer.sign(appointment.serialize(), client_sk) - r = add_appointment( - client, {"appointment": appointment.to_dict(), "signature": appointment_signature}, compressed_client_pk - ) + appointment.encrypted_blob = TWO_SLOTS_BLOTS + TWO_SLOTS_BLOTS + appointment_signature = Cryptographer.sign(appointment.serialize(), user_sk) + r = add_appointment(client, {"appointment": appointment.to_dict(), "signature": appointment_signature}, user_id) assert r.status_code == HTTP_BAD_REQUEST def test_add_appointment_update_smaller(api, client, appointment): # Update an appointment by one bigger, and check slots are freed - api.gatekeeper.registered_users[compressed_client_pk] = {"available_slots": 2} - + api.watcher.gatekeeper.registered_users[user_id] = UserInfo(available_slots=2, subscription_expiry=0) # This should take 2 slots - appointment.encrypted_blob.data = TWO_SLOTS_BLOTS - appointment_signature = Cryptographer.sign(appointment.serialize(), client_sk) - r = add_appointment( - client, {"appointment": appointment.to_dict(), "signature": appointment_signature}, compressed_client_pk - ) + appointment.encrypted_blob = TWO_SLOTS_BLOTS + appointment_signature = Cryptographer.sign(appointment.serialize(), user_sk) + r = add_appointment(client, {"appointment": appointment.to_dict(), "signature": appointment_signature}, user_id) assert r.status_code == HTTP_OK and r.json.get("available_slots") == 0 # Let's update with one just small enough - appointment.encrypted_blob.data = "A" * (ENCRYPTED_BLOB_MAX_SIZE_HEX - 2) - appointment_signature = Cryptographer.sign(appointment.serialize(), client_sk) - r = add_appointment( - client, {"appointment": appointment.to_dict(), "signature": appointment_signature}, compressed_client_pk - ) + appointment.encrypted_blob = "A" * (ENCRYPTED_BLOB_MAX_SIZE_HEX - 2) + appointment_signature = Cryptographer.sign(appointment.serialize(), user_sk) + r = add_appointment(client, {"appointment": appointment.to_dict(), "signature": appointment_signature}, user_id) assert r.status_code == HTTP_OK and r.json.get("available_slots") == 1 def test_add_too_many_appointment(api, client): # Give slots to the user - api.gatekeeper.registered_users[compressed_client_pk] = {"available_slots": 200} + api.watcher.gatekeeper.registered_users[user_id] = UserInfo(available_slots=200, subscription_expiry=0) free_appointment_slots = MAX_APPOINTMENTS - len(api.watcher.appointments) @@ -346,10 +333,8 @@ def test_add_too_many_appointment(api, client): appointment, dispute_tx = generate_dummy_appointment() locator_dispute_tx_map[appointment.locator] = dispute_tx - appointment_signature = Cryptographer.sign(appointment.serialize(), client_sk) - r = add_appointment( - client, {"appointment": appointment.to_dict(), "signature": appointment_signature}, compressed_client_pk - ) + appointment_signature = Cryptographer.sign(appointment.serialize(), user_sk) + r = add_appointment(client, {"appointment": appointment.to_dict(), "signature": appointment_signature}, user_id) if i < free_appointment_slots: assert r.status_code == HTTP_OK @@ -360,14 +345,18 @@ def test_add_too_many_appointment(api, client): def test_get_appointment_no_json(api, client, appointment): r = client.post(add_appointment_endpoint, data="random_message") assert r.status_code == HTTP_BAD_REQUEST + assert "Request is not json encoded" in r.json.get("error") + assert errors.INVALID_REQUEST_FORMAT == r.json.get("error_code") def test_get_appointment_json_no_inner_dict(api, client, appointment): r = client.post(add_appointment_endpoint, json="random_message") assert r.status_code == HTTP_BAD_REQUEST + assert "Invalid request content" in r.json.get("error") + assert errors.INVALID_REQUEST_FORMAT == r.json.get("error_code") -def test_request_random_appointment_registered_user(client, user_sk=client_sk): +def test_get_random_appointment_registered_user(client, user_sk=user_sk): locator = get_random_value_hex(LOCATOR_LEN_BYTES) message = "get appointment {}".format(locator) signature = Cryptographer.sign(message.encode("utf-8"), user_sk) @@ -381,60 +370,62 @@ def test_request_random_appointment_registered_user(client, user_sk=client_sk): assert received_appointment.get("status") == "not_found" -def test_request_appointment_not_registered_user(client): +def test_get_appointment_not_registered_user(client): # Not registered users have no associated appointments, so this should fail tmp_sk, tmp_pk = generate_keypair() # The tower is designed so a not found appointment and a request from a non-registered user return the same error to # prevent probing. - test_request_random_appointment_registered_user(client, tmp_sk) + test_get_random_appointment_registered_user(client, tmp_sk) -def test_request_appointment_in_watcher(api, client, appointment): +def test_get_appointment_in_watcher(api, client, appointment): # Mock the appointment in the Watcher - uuid = hash_160("{}{}".format(appointment.locator, compressed_client_pk)) + uuid = hash_160("{}{}".format(appointment.locator, user_id)) api.watcher.db_manager.store_watcher_appointment(uuid, appointment.to_dict()) # Next we can request it message = "get appointment {}".format(appointment.locator) - signature = Cryptographer.sign(message.encode("utf-8"), client_sk) + signature = Cryptographer.sign(message.encode("utf-8"), user_sk) data = {"locator": appointment.locator, "signature": signature} r = client.post(get_appointment_endpoint, json=data) assert r.status_code == HTTP_OK - # Check that the appointment is on the watcher + # Check that the appointment is on the Watcher assert r.json.get("status") == "being_watched" # Check the the sent appointment matches the received one + appointment_dict = appointment.to_dict() + appointment_dict.pop("user_id") assert r.json.get("locator") == appointment.locator assert appointment.to_dict() == r.json.get("appointment") -def test_request_appointment_in_responder(api, client, appointment): +def test_get_appointment_in_responder(api, client, appointment): # Mock the appointment in the Responder tracker_data = { "locator": appointment.locator, "dispute_txid": get_random_value_hex(32), "penalty_txid": get_random_value_hex(32), "penalty_rawtx": get_random_value_hex(250), - "appointment_end": appointment.end_time, + "user_id": get_random_value_hex(16), } tx_tracker = TransactionTracker.from_dict(tracker_data) - uuid = hash_160("{}{}".format(appointment.locator, compressed_client_pk)) + uuid = hash_160("{}{}".format(appointment.locator, user_id)) api.watcher.db_manager.create_triggered_appointment_flag(uuid) api.watcher.responder.db_manager.store_responder_tracker(uuid, tx_tracker.to_dict()) # Request back the data message = "get appointment {}".format(appointment.locator) - signature = Cryptographer.sign(message.encode("utf-8"), client_sk) + signature = Cryptographer.sign(message.encode("utf-8"), user_sk) data = {"locator": appointment.locator, "signature": signature} # Next we can request it r = client.post(get_appointment_endpoint, json=data) assert r.status_code == HTTP_OK - # Check that the appointment is on the watcher + # Check that the appointment is on the Responder assert r.json.get("status") == "dispute_responded" # Check the the sent appointment matches the received one @@ -442,10 +433,9 @@ def test_request_appointment_in_responder(api, client, appointment): assert tx_tracker.dispute_txid == r.json.get("appointment").get("dispute_txid") assert tx_tracker.penalty_txid == r.json.get("appointment").get("penalty_txid") assert tx_tracker.penalty_rawtx == r.json.get("appointment").get("penalty_rawtx") - assert tx_tracker.appointment_end == r.json.get("appointment").get("appointment_end") -def test_get_all_appointments_watcher(api, client, get_all_db_manager, appointment): +def test_get_all_appointments_watcher(api, client, get_all_db_manager): # Let's reset the dbs so we can test this clean api.watcher.db_manager = get_all_db_manager api.watcher.responder.db_manager = get_all_db_manager @@ -459,6 +449,7 @@ def test_get_all_appointments_watcher(api, client, get_all_db_manager, appointme non_triggered_appointments = {} for _ in range(10): uuid = get_random_value_hex(16) + appointment, _ = generate_dummy_appointment() appointment.locator = get_random_value_hex(16) non_triggered_appointments[uuid] = appointment.to_dict() api.watcher.db_manager.store_watcher_appointment(uuid, appointment.to_dict()) @@ -466,12 +457,13 @@ def test_get_all_appointments_watcher(api, client, get_all_db_manager, appointme triggered_appointments = {} for _ in range(10): uuid = get_random_value_hex(16) + appointment, _ = generate_dummy_appointment() appointment.locator = get_random_value_hex(16) triggered_appointments[uuid] = appointment.to_dict() api.watcher.db_manager.store_watcher_appointment(uuid, appointment.to_dict()) api.watcher.db_manager.create_triggered_appointment_flag(uuid) - # We should only get check the non-triggered appointments + # We should only get the non-triggered appointments r = client.get(get_all_appointment_endpoint) assert r.status_code == HTTP_OK @@ -501,7 +493,7 @@ def test_get_all_appointments_responder(api, client, get_all_db_manager): "dispute_txid": get_random_value_hex(32), "penalty_txid": get_random_value_hex(32), "penalty_rawtx": get_random_value_hex(250), - "appointment_end": 20, + "user_id": get_random_value_hex(16), } tracker = TransactionTracker.from_dict(tracker_data) tx_trackers[uuid] = tracker.to_dict() diff --git a/test/teos/unit/test_appointments_dbm.py b/test/teos/unit/test_appointments_dbm.py index 48928f6..7c29f55 100644 --- a/test/teos/unit/test_appointments_dbm.py +++ b/test/teos/unit/test_appointments_dbm.py @@ -19,7 +19,7 @@ from test.teos.unit.conftest import get_random_value_hex, generate_dummy_appoint @pytest.fixture(scope="module") def watcher_appointments(): - return {uuid4().hex: generate_dummy_appointment(real_height=False)[0] for _ in range(10)} + return {uuid4().hex: generate_dummy_appointment()[0] for _ in range(10)} @pytest.fixture(scope="module") @@ -215,7 +215,7 @@ def test_store_load_triggered_appointment(db_manager): assert db_watcher_appointments == db_watcher_appointments_with_triggered # Create an appointment flagged as triggered - triggered_appointment, _ = generate_dummy_appointment(real_height=False) + triggered_appointment, _ = generate_dummy_appointment() uuid = uuid4().hex assert db_manager.store_watcher_appointment(uuid, triggered_appointment.to_dict()) is True db_manager.create_triggered_appointment_flag(uuid) diff --git a/test/teos/unit/test_builder.py b/test/teos/unit/test_builder.py index 756cc5e..f804e54 100644 --- a/test/teos/unit/test_builder.py +++ b/test/teos/unit/test_builder.py @@ -4,6 +4,7 @@ from queue import Queue from teos.builder import Builder from teos.watcher import Watcher +from teos.tools import bitcoin_cli from teos.responder import Responder from test.teos.unit.conftest import ( @@ -11,9 +12,9 @@ from test.teos.unit.conftest import ( generate_dummy_appointment, generate_dummy_tracker, generate_block, - bitcoin_cli, get_config, bitcoind_connect_params, + generate_keypair, ) config = get_config() @@ -24,7 +25,7 @@ def test_build_appointments(): # Create some appointment data for i in range(10): - appointment, _ = generate_dummy_appointment(real_height=False) + appointment, _ = generate_dummy_appointment() uuid = uuid4().hex appointments_data[uuid] = appointment.to_dict() @@ -32,7 +33,7 @@ def test_build_appointments(): # Add some additional appointments that share the same locator to test all the builder's cases if i % 2 == 0: locator = appointment.locator - appointment, _ = generate_dummy_appointment(real_height=False) + appointment, _ = generate_dummy_appointment() uuid = uuid4().hex appointment.locator = locator @@ -45,8 +46,7 @@ def test_build_appointments(): for uuid, appointment in appointments.items(): assert uuid in appointments_data.keys() assert appointments_data[uuid].get("locator") == appointment.get("locator") - assert appointments_data[uuid].get("end_time") == appointment.get("end_time") - assert len(appointments_data[uuid].get("encrypted_blob")) == appointment.get("size") + assert appointments_data[uuid].get("user_id") == appointment.get("user_id") assert uuid in locator_uuid_map[appointment.get("locator")] @@ -75,7 +75,7 @@ def test_build_trackers(): assert tracker.get("penalty_txid") == trackers_data[uuid].get("penalty_txid") assert tracker.get("locator") == trackers_data[uuid].get("locator") - assert tracker.get("appointment_end") == trackers_data[uuid].get("appointment_end") + assert tracker.get("user_id") == trackers_data[uuid].get("user_id") assert uuid in tx_tracker_map[tracker.get("penalty_txid")] @@ -94,14 +94,14 @@ def test_populate_block_queue(): assert len(blocks) == 0 -def test_update_states_empty_list(db_manager, carrier, block_processor): +def test_update_states_empty_list(db_manager, gatekeeper, carrier, block_processor): w = Watcher( db_manager=db_manager, + gatekeeper=gatekeeper, block_processor=block_processor, - responder=Responder(db_manager, carrier, block_processor), - sk_der=None, + responder=Responder(db_manager, gatekeeper, carrier, block_processor), + sk_der=generate_keypair()[0].to_der(), max_appointments=config.get("MAX_APPOINTMENTS"), - expiry_delta=config.get("EXPIRY_DELTA"), ) missed_blocks_watcher = [] @@ -115,14 +115,14 @@ def test_update_states_empty_list(db_manager, carrier, block_processor): Builder.update_states(w, missed_blocks_responder, missed_blocks_watcher) -def test_update_states_responder_misses_more(run_bitcoind, db_manager, carrier, block_processor): +def test_update_states_responder_misses_more(run_bitcoind, db_manager, gatekeeper, carrier, block_processor): w = Watcher( db_manager=db_manager, + gatekeeper=gatekeeper, block_processor=block_processor, - responder=Responder(db_manager, carrier, block_processor), - sk_der=None, + responder=Responder(db_manager, gatekeeper, carrier, block_processor), + sk_der=generate_keypair()[0].to_der(), max_appointments=config.get("MAX_APPOINTMENTS"), - expiry_delta=config.get("EXPIRY_DELTA"), ) blocks = [] @@ -139,15 +139,15 @@ def test_update_states_responder_misses_more(run_bitcoind, db_manager, carrier, assert w.responder.last_known_block == blocks[-1] -def test_update_states_watcher_misses_more(db_manager, carrier, block_processor): +def test_update_states_watcher_misses_more(db_manager, gatekeeper, carrier, block_processor): # Same as before, but data is now in the Responder w = Watcher( db_manager=db_manager, + gatekeeper=gatekeeper, block_processor=block_processor, - responder=Responder(db_manager, carrier, block_processor), - sk_der=None, + responder=Responder(db_manager, gatekeeper, carrier, block_processor), + sk_der=generate_keypair()[0].to_der(), max_appointments=config.get("MAX_APPOINTMENTS"), - expiry_delta=config.get("EXPIRY_DELTA"), ) blocks = [] diff --git a/test/teos/unit/test_cleaner.py b/test/teos/unit/test_cleaner.py index ad2e263..5ea02a1 100644 --- a/test/teos/unit/test_cleaner.py +++ b/test/teos/unit/test_cleaner.py @@ -1,8 +1,9 @@ import random from uuid import uuid4 -from teos.responder import TransactionTracker from teos.cleaner import Cleaner +from teos.gatekeeper import UserInfo +from teos.responder import TransactionTracker from common.appointment import Appointment from test.teos.unit.conftest import get_random_value_hex @@ -23,7 +24,7 @@ def set_up_appointments(db_manager, total_appointments): uuid = uuid4().hex locator = get_random_value_hex(LOCATOR_LEN_BYTES) - appointment = Appointment(locator, None, None, None, None) + appointment = Appointment(locator, None, None) appointments[uuid] = {"locator": appointment.locator} locator_uuid_map[locator] = [uuid] @@ -156,7 +157,8 @@ def test_flag_triggered_appointments(db_manager): assert set(triggered_appointments).issubset(db_appointments) -def test_delete_completed_trackers_db_match(db_manager): +def test_delete_trackers_db_match(db_manager): + # Completed and expired trackers are deleted using the same method. The only difference is the logging message height = 0 for _ in range(ITERATIONS): @@ -165,12 +167,12 @@ def test_delete_completed_trackers_db_match(db_manager): completed_trackers = {tracker: 6 for tracker in selected_trackers} - Cleaner.delete_completed_trackers(completed_trackers, height, trackers, tx_tracker_map, db_manager) + Cleaner.delete_trackers(completed_trackers, height, trackers, tx_tracker_map, db_manager) assert not set(completed_trackers).issubset(trackers.keys()) -def test_delete_completed_trackers_no_db_match(db_manager): +def test_delete_trackers_no_db_match(db_manager): height = 0 for _ in range(ITERATIONS): @@ -203,5 +205,38 @@ def test_delete_completed_trackers_no_db_match(db_manager): completed_trackers = {tracker: 6 for tracker in selected_trackers} # We should be able to delete the correct ones and not fail in the others - Cleaner.delete_completed_trackers(completed_trackers, height, trackers, tx_tracker_map, db_manager) + Cleaner.delete_trackers(completed_trackers, height, trackers, tx_tracker_map, db_manager) assert not set(completed_trackers).issubset(trackers.keys()) + + +def test_delete_gatekeeper_appointments(gatekeeper): + # delete_gatekeeper_appointments should delete the appointments from user as long as both exist + + appointments_not_to_delete = {} + appointments_to_delete = {} + # Let's add some users and appointments to the Gatekeeper + for _ in range(10): + user_id = get_random_value_hex(16) + # The UserInfo params do not matter much here + gatekeeper.registered_users[user_id] = UserInfo(available_slots=100, subscription_expiry=0) + for _ in range(random.randint(0, 10)): + # Add some appointments + uuid = get_random_value_hex(16) + gatekeeper.registered_users[user_id].appointments[uuid] = 1 + + if random.randint(0, 1) % 2: + appointments_to_delete[uuid] = user_id + else: + appointments_not_to_delete[uuid] = user_id + + # Now let's delete half of them + Cleaner.delete_gatekeeper_appointments(gatekeeper, appointments_to_delete) + + all_appointments_gatekeeper = [] + # Let's get all the appointments in the Gatekeeper + for user_id, user in gatekeeper.registered_users.items(): + all_appointments_gatekeeper.extend(user.appointments) + + # Check that the first half of the appointments are not in the Gatekeeper, but the second half is + assert not set(appointments_to_delete).issubset(all_appointments_gatekeeper) + assert set(appointments_not_to_delete).issubset(all_appointments_gatekeeper) diff --git a/test/teos/unit/test_extended_appointment.py b/test/teos/unit/test_extended_appointment.py new file mode 100644 index 0000000..bc883fe --- /dev/null +++ b/test/teos/unit/test_extended_appointment.py @@ -0,0 +1,63 @@ +import pytest +from pytest import fixture + +from common.constants import LOCATOR_LEN_BYTES +from teos.extended_appointment import ExtendedAppointment + +from test.common.unit.conftest import get_random_value_hex + + +# Parent methods are not tested. +@fixture +def appointment_data(): + locator = get_random_value_hex(LOCATOR_LEN_BYTES) + to_self_delay = 20 + user_id = get_random_value_hex(16) + encrypted_blob_data = get_random_value_hex(100) + + return { + "locator": locator, + "to_self_delay": to_self_delay, + "encrypted_blob": encrypted_blob_data, + "user_id": user_id, + } + + +def test_init_appointment(appointment_data): + # The appointment has no checks whatsoever, since the inspector is the one taking care or that, and the only one + # creating appointments. + appointment = ExtendedAppointment( + appointment_data["locator"], + appointment_data["to_self_delay"], + appointment_data["encrypted_blob"], + appointment_data["user_id"], + ) + + assert ( + appointment_data["locator"] == appointment.locator + and appointment_data["to_self_delay"] == appointment.to_self_delay + and appointment_data["encrypted_blob"] == appointment.encrypted_blob + and appointment_data["user_id"] == appointment.user_id + ) + + +def test_get_summary(appointment_data): + assert ExtendedAppointment.from_dict(appointment_data).get_summary() == { + "locator": appointment_data["locator"], + "user_id": appointment_data["user_id"], + } + + +def test_from_dict(appointment_data): + # The appointment should be build if we don't miss any field + appointment = ExtendedAppointment.from_dict(appointment_data) + assert isinstance(appointment, ExtendedAppointment) + + # Otherwise it should fail + for key in appointment_data.keys(): + prev_val = appointment_data[key] + appointment_data[key] = None + + with pytest.raises(ValueError, match="Wrong appointment data"): + ExtendedAppointment.from_dict(appointment_data) + appointment_data[key] = prev_val diff --git a/test/teos/unit/test_gatekeeper.py b/test/teos/unit/test_gatekeeper.py index bfb916f..a688fce 100644 --- a/test/teos/unit/test_gatekeeper.py +++ b/test/teos/unit/test_gatekeeper.py @@ -1,57 +1,76 @@ import pytest -from teos.gatekeeper import IdentificationFailure, NotEnoughSlots +from teos.users_dbm import UsersDBM +from teos.block_processor import BlockProcessor +from teos.gatekeeper import AuthenticationFailure, NotEnoughSlots, UserInfo from common.cryptographer import Cryptographer +from common.exceptions import InvalidParameter +from common.constants import ENCRYPTED_BLOB_MAX_SIZE_HEX -from test.teos.unit.conftest import get_random_value_hex, generate_keypair, get_config +from test.teos.unit.conftest import get_random_value_hex, generate_keypair, get_config, generate_dummy_appointment config = get_config() -def test_init(gatekeeper): +def test_init(gatekeeper, run_bitcoind): assert isinstance(gatekeeper.default_slots, int) and gatekeeper.default_slots == config.get("DEFAULT_SLOTS") + assert isinstance( + gatekeeper.default_subscription_duration, int + ) and gatekeeper.default_subscription_duration == config.get("DEFAULT_SUBSCRIPTION_DURATION") + assert isinstance(gatekeeper.expiry_delta, int) and gatekeeper.expiry_delta == config.get("EXPIRY_DELTA") + assert isinstance(gatekeeper.block_processor, BlockProcessor) + assert isinstance(gatekeeper.user_db, UsersDBM) assert isinstance(gatekeeper.registered_users, dict) and len(gatekeeper.registered_users) == 0 def test_add_update_user(gatekeeper): # add_update_user adds DEFAULT_SLOTS to a given user as long as the identifier is {02, 03}| 32-byte hex str - user_pk = "02" + get_random_value_hex(32) + # it also add DEFAULT_SUBSCRIPTION_DURATION + current_block_height to the user + user_id = "02" + get_random_value_hex(32) for _ in range(10): - current_slots = gatekeeper.registered_users.get(user_pk) - current_slots = current_slots.get("available_slots") if current_slots is not None else 0 + user = gatekeeper.registered_users.get(user_id) + current_slots = user.available_slots if user is not None else 0 - gatekeeper.add_update_user(user_pk) + gatekeeper.add_update_user(user_id) - assert gatekeeper.registered_users.get(user_pk).get("available_slots") == current_slots + config.get( - "DEFAULT_SLOTS" + assert gatekeeper.registered_users.get(user_id).available_slots == current_slots + config.get("DEFAULT_SLOTS") + assert gatekeeper.registered_users[ + user_id + ].subscription_expiry == gatekeeper.block_processor.get_block_count() + config.get( + "DEFAULT_SUBSCRIPTION_DURATION" ) # The same can be checked for multiple users for _ in range(10): # The user identifier is changed every call - user_pk = "03" + get_random_value_hex(32) + user_id = "03" + get_random_value_hex(32) - gatekeeper.add_update_user(user_pk) - assert gatekeeper.registered_users.get(user_pk).get("available_slots") == config.get("DEFAULT_SLOTS") + gatekeeper.add_update_user(user_id) + assert gatekeeper.registered_users.get(user_id).available_slots == config.get("DEFAULT_SLOTS") + assert gatekeeper.registered_users[ + user_id + ].subscription_expiry == gatekeeper.block_processor.get_block_count() + config.get( + "DEFAULT_SUBSCRIPTION_DURATION" + ) -def test_add_update_user_wrong_pk(gatekeeper): +def test_add_update_user_wrong_id(gatekeeper): # Passing a wrong pk defaults to the errors in check_user_pk. We can try with one. - wrong_pk = get_random_value_hex(32) + wrong_id = get_random_value_hex(32) - with pytest.raises(ValueError): - gatekeeper.add_update_user(wrong_pk) + with pytest.raises(InvalidParameter): + gatekeeper.add_update_user(wrong_id) -def test_add_update_user_wrong_pk_prefix(gatekeeper): +def test_add_update_user_wrong_id_prefix(gatekeeper): # Prefixes must be 02 or 03, anything else should fail - wrong_pk = "04" + get_random_value_hex(32) + wrong_id = "04" + get_random_value_hex(32) - with pytest.raises(ValueError): - gatekeeper.add_update_user(wrong_pk) + with pytest.raises(InvalidParameter): + gatekeeper.add_update_user(wrong_id) def test_identify_user(gatekeeper): @@ -60,13 +79,13 @@ def test_identify_user(gatekeeper): # Let's first register a user sk, pk = generate_keypair() - compressed_pk = Cryptographer.get_compressed_pk(pk) - gatekeeper.add_update_user(compressed_pk) + user_id = Cryptographer.get_compressed_pk(pk) + gatekeeper.add_update_user(user_id) message = "Hey, it's me" signature = Cryptographer.sign(message.encode(), sk) - assert gatekeeper.identify_user(message.encode(), signature) == compressed_pk + assert gatekeeper.authenticate_user(message.encode(), signature) == user_id def test_identify_user_non_registered(gatekeeper): @@ -76,8 +95,8 @@ def test_identify_user_non_registered(gatekeeper): message = "Hey, it's me" signature = Cryptographer.sign(message.encode(), sk) - with pytest.raises(IdentificationFailure): - gatekeeper.identify_user(message.encode(), signature) + with pytest.raises(AuthenticationFailure): + gatekeeper.authenticate_user(message.encode(), signature) def test_identify_user_invalid_signature(gatekeeper): @@ -85,8 +104,8 @@ def test_identify_user_invalid_signature(gatekeeper): message = "Hey, it's me" signature = get_random_value_hex(72) - with pytest.raises(IdentificationFailure): - gatekeeper.identify_user(message.encode(), signature) + with pytest.raises(AuthenticationFailure): + gatekeeper.authenticate_user(message.encode(), signature) def test_identify_user_wrong(gatekeeper): @@ -97,41 +116,74 @@ def test_identify_user_wrong(gatekeeper): signature = Cryptographer.sign(message.encode(), sk) # Non-byte message and str sig - with pytest.raises(IdentificationFailure): - gatekeeper.identify_user(message, signature) + with pytest.raises(AuthenticationFailure): + gatekeeper.authenticate_user(message, signature) # byte message and non-str sig - with pytest.raises(IdentificationFailure): - gatekeeper.identify_user(message.encode(), signature.encode()) + with pytest.raises(AuthenticationFailure): + gatekeeper.authenticate_user(message.encode(), signature.encode()) # non-byte message and non-str sig - with pytest.raises(IdentificationFailure): - gatekeeper.identify_user(message, signature.encode()) + with pytest.raises(AuthenticationFailure): + gatekeeper.authenticate_user(message, signature.encode()) -def test_fill_slots(gatekeeper): - # Free slots will decrease the slot count of a user as long as he has enough slots, otherwise raise NotEnoughSlots - user_pk = "02" + get_random_value_hex(32) - gatekeeper.add_update_user(user_pk) +def test_add_update_appointment(gatekeeper): + # add_update_appointment should decrease the slot count if a new appointment is added + # let's add a new user + sk, pk = generate_keypair() + user_id = Cryptographer.get_compressed_pk(pk) + gatekeeper.add_update_user(user_id) - gatekeeper.fill_slots(user_pk, config.get("DEFAULT_SLOTS") - 1) - assert gatekeeper.registered_users.get(user_pk).get("available_slots") == 1 + # And now update add a new appointment + appointment, _ = generate_dummy_appointment() + appointment_uuid = get_random_value_hex(16) + remaining_slots = gatekeeper.add_update_appointment(user_id, appointment_uuid, appointment) + # This is a standard size appointment, so it should have reduced the slots by one + assert appointment_uuid in gatekeeper.registered_users[user_id].appointments + assert remaining_slots == config.get("DEFAULT_SLOTS") - 1 + + # Updates can leave the count as is, decrease it, or increase it, depending on the appointment size (modulo + # ENCRYPTED_BLOB_MAX_SIZE_HEX) + + # Appointments of the same size leave it as is + appointment_same_size, _ = generate_dummy_appointment() + remaining_slots = gatekeeper.add_update_appointment(user_id, appointment_uuid, appointment) + assert appointment_uuid in gatekeeper.registered_users[user_id].appointments + assert remaining_slots == config.get("DEFAULT_SLOTS") - 1 + + # Bigger appointments decrease it + appointment_x2_size = appointment_same_size + appointment_x2_size.encrypted_blob = "A" * (ENCRYPTED_BLOB_MAX_SIZE_HEX + 1) + remaining_slots = gatekeeper.add_update_appointment(user_id, appointment_uuid, appointment_x2_size) + assert appointment_uuid in gatekeeper.registered_users[user_id].appointments + assert remaining_slots == config.get("DEFAULT_SLOTS") - 2 + + # Smaller appointments increase it + remaining_slots = gatekeeper.add_update_appointment(user_id, appointment_uuid, appointment) + assert remaining_slots == config.get("DEFAULT_SLOTS") - 1 + + # If the appointment needs more slots than there's free, it should fail + gatekeeper.registered_users[user_id].available_slots = 1 + appointment_uuid = get_random_value_hex(16) with pytest.raises(NotEnoughSlots): - gatekeeper.fill_slots(user_pk, 2) - - # NotEnoughSlots is also raised if the user does not exist - with pytest.raises(NotEnoughSlots): - gatekeeper.fill_slots(get_random_value_hex(33), 2) + gatekeeper.add_update_appointment(user_id, appointment_uuid, appointment_x2_size) -def test_free_slots(gatekeeper): - # Free slots simply adds slots to the user as long as it exists. - user_pk = "03" + get_random_value_hex(32) - gatekeeper.add_update_user(user_pk) - gatekeeper.free_slots(user_pk, 42) +def test_get_expired_appointments(gatekeeper): + # get_expired_appointments returns a list of appointment uuids expiring at a given block - assert gatekeeper.registered_users.get(user_pk).get("available_slots") == config.get("DEFAULT_SLOTS") + 42 + appointment = {} + # Let's simulate adding some users with dummy expiry times + gatekeeper.registered_users = {} + for i in reversed(range(100)): + uuid = get_random_value_hex(16) + user_appointments = [get_random_value_hex(16)] + # Add a single appointment to the user + gatekeeper.registered_users[uuid] = UserInfo(100, i, user_appointments) + appointment[i] = user_appointments - # Just making sure it does not crash for non-registered user - assert gatekeeper.free_slots(get_random_value_hex(33), 10) is None + # Now let's check that reversed + for i in range(100): + assert gatekeeper.get_expired_appointments(i + gatekeeper.expiry_delta) == appointment[i] diff --git a/test/teos/unit/test_inspector.py b/test/teos/unit/test_inspector.py index b3993f5..51e2828 100644 --- a/test/teos/unit/test_inspector.py +++ b/test/teos/unit/test_inspector.py @@ -1,20 +1,15 @@ import pytest from binascii import unhexlify -import teos.errors as errors -from teos import LOG_PREFIX +import common.errors as errors from teos.block_processor import BlockProcessor from teos.inspector import Inspector, InspectionFailed +from teos.extended_appointment import ExtendedAppointment -import common.cryptographer -from common.logger import Logger -from common.appointment import Appointment from common.constants import LOCATOR_LEN_BYTES, LOCATOR_LEN_HEX from test.teos.unit.conftest import get_random_value_hex, bitcoind_connect_params, get_config -common.cryptographer.logger = Logger(actor="Cryptographer", log_name_prefix=LOG_PREFIX) - NO_HEX_STRINGS = [ "R" * LOCATOR_LEN_HEX, get_random_value_hex(LOCATOR_LEN_BYTES - 1) + "PP", @@ -100,101 +95,6 @@ def test_check_locator(): raise e -def test_check_start_time(): - # Time is defined in block height - current_time = 100 - - # Right format and right value (start time in the future) - start_time = 101 - assert inspector.check_start_time(start_time, current_time) is None - - # Start time too small (either same block or block in the past) - start_times = [100, 99, 98, -1] - for start_time in start_times: - with pytest.raises(InspectionFailed): - try: - inspector.check_start_time(start_time, current_time) - - except InspectionFailed as e: - assert e.erno == errors.APPOINTMENT_FIELD_TOO_SMALL - raise e - - # Empty field - start_time = None - with pytest.raises(InspectionFailed): - try: - inspector.check_start_time(start_time, current_time) - - except InspectionFailed as e: - assert e.erno == errors.APPOINTMENT_EMPTY_FIELD - raise e - - # Wrong data type - start_times = WRONG_TYPES - for start_time in start_times: - with pytest.raises(InspectionFailed): - try: - inspector.check_start_time(start_time, current_time) - - except InspectionFailed as e: - assert e.erno == errors.APPOINTMENT_WRONG_FIELD_TYPE - raise e - - -def test_check_end_time(): - # Time is defined in block height - current_time = 100 - start_time = 120 - - # Right format and right value (start time before end and end in the future) - end_time = 121 - assert inspector.check_end_time(end_time, start_time, current_time) is None - - # End time too small (start time after end time) - end_times = [120, 119, 118, -1] - for end_time in end_times: - with pytest.raises(InspectionFailed): - try: - inspector.check_end_time(end_time, start_time, current_time) - - except InspectionFailed as e: - assert e.erno == errors.APPOINTMENT_FIELD_TOO_SMALL - raise e - - # End time too small (either same height as current block or in the past) - current_time = 130 - end_times = [130, 129, 128, -1] - for end_time in end_times: - with pytest.raises(InspectionFailed): - try: - inspector.check_end_time(end_time, start_time, current_time) - - except InspectionFailed as e: - assert e.erno == errors.APPOINTMENT_FIELD_TOO_SMALL - raise e - - # Empty field - end_time = None - with pytest.raises(InspectionFailed): - try: - inspector.check_end_time(end_time, start_time, current_time) - - except InspectionFailed as e: - assert e.erno == errors.APPOINTMENT_EMPTY_FIELD - raise e - - # Wrong data type - end_times = WRONG_TYPES - for end_time in end_times: - with pytest.raises(InspectionFailed): - try: - inspector.check_end_time(end_time, start_time, current_time) - - except InspectionFailed as e: - assert e.erno == errors.APPOINTMENT_WRONG_FIELD_TYPE - raise e - - def test_check_to_self_delay(): # Right value, right format to_self_delays = [MIN_TO_SELF_DELAY, MIN_TO_SELF_DELAY + 1, MIN_TO_SELF_DELAY + 1000] @@ -239,10 +139,6 @@ def test_check_blob(): encrypted_blob = get_random_value_hex(120) assert inspector.check_blob(encrypted_blob) is None - # # Wrong content - # # FIXME: There is not proper defined format for this yet. It should be restricted by size at least, and check it - # # is multiple of the block size defined by the encryption function. - # Wrong type encrypted_blobs = WRONG_TYPES_NO_STR for encrypted_blob in encrypted_blobs: @@ -284,23 +180,15 @@ def test_inspect(run_bitcoind): to_self_delay = MIN_TO_SELF_DELAY encrypted_blob = get_random_value_hex(64) - appointment_data = { - "locator": locator, - "start_time": start_time, - "end_time": end_time, - "to_self_delay": to_self_delay, - "encrypted_blob": encrypted_blob, - } + appointment_data = {"locator": locator, "to_self_delay": to_self_delay, "encrypted_blob": encrypted_blob} appointment = inspector.inspect(appointment_data) assert ( - type(appointment) == Appointment + type(appointment) == ExtendedAppointment and appointment.locator == locator - and appointment.start_time == start_time - and appointment.end_time == end_time and appointment.to_self_delay == to_self_delay - and appointment.encrypted_blob.data == encrypted_blob + and appointment.encrypted_blob == encrypted_blob ) diff --git a/test/teos/unit/test_responder.py b/test/teos/unit/test_responder.py index 7c4d53d..bc8f566 100644 --- a/test/teos/unit/test_responder.py +++ b/test/teos/unit/test_responder.py @@ -9,23 +9,31 @@ from threading import Thread from teos.carrier import Carrier from teos.tools import bitcoin_cli from teos.chain_monitor import ChainMonitor +from teos.block_processor import BlockProcessor +from teos.gatekeeper import Gatekeeper, UserInfo from teos.appointments_dbm import AppointmentsDBM -from teos.responder import Responder, TransactionTracker +from teos.responder import Responder, TransactionTracker, CONFIRMATIONS_BEFORE_RETRY from common.constants import LOCATOR_LEN_HEX from bitcoind_mock.transaction import create_dummy_transaction, create_tx_from_hex from test.teos.unit.conftest import ( generate_block, generate_blocks, + generate_block_w_delay, + generate_blocks_w_delay, get_random_value_hex, bitcoind_connect_params, bitcoind_feed_params, + get_config, ) +config = get_config() + + @pytest.fixture(scope="module") -def responder(db_manager, carrier, block_processor): - responder = Responder(db_manager, carrier, block_processor) +def responder(db_manager, gatekeeper, carrier, block_processor): + responder = Responder(db_manager, gatekeeper, carrier, block_processor) chain_monitor = ChainMonitor(Queue(), responder.block_queue, block_processor, bitcoind_feed_params) chain_monitor.monitor_chain() @@ -66,31 +74,86 @@ def create_dummy_tracker_data(random_txid=False, penalty_rawtx=None): if random_txid is True: penalty_txid = get_random_value_hex(32) - appointment_end = bitcoin_cli(bitcoind_connect_params).getblockcount() + 2 locator = dispute_txid[:LOCATOR_LEN_HEX] + user_id = get_random_value_hex(16) - return locator, dispute_txid, penalty_txid, penalty_rawtx, appointment_end + return locator, dispute_txid, penalty_txid, penalty_rawtx, user_id def create_dummy_tracker(random_txid=False, penalty_rawtx=None): - locator, dispute_txid, penalty_txid, penalty_rawtx, appointment_end = create_dummy_tracker_data( - random_txid, penalty_rawtx - ) - return TransactionTracker(locator, dispute_txid, penalty_txid, penalty_rawtx, appointment_end) + locator, dispute_txid, penalty_txid, penalty_rawtx, user_id = create_dummy_tracker_data(random_txid, penalty_rawtx) + return TransactionTracker(locator, dispute_txid, penalty_txid, penalty_rawtx, user_id) def test_tracker_init(run_bitcoind): - locator, dispute_txid, penalty_txid, penalty_rawtx, appointment_end = create_dummy_tracker_data() - tracker = TransactionTracker(locator, dispute_txid, penalty_txid, penalty_rawtx, appointment_end) + locator, dispute_txid, penalty_txid, penalty_rawtx, user_id = create_dummy_tracker_data() + tracker = TransactionTracker(locator, dispute_txid, penalty_txid, penalty_rawtx, user_id) assert ( - tracker.dispute_txid == dispute_txid + tracker.locator == locator + and tracker.dispute_txid == dispute_txid and tracker.penalty_txid == penalty_txid and tracker.penalty_rawtx == penalty_rawtx - and tracker.appointment_end == appointment_end + and tracker.user_id == user_id ) +def test_tracker_to_dict(): + tracker = create_dummy_tracker() + tracker_dict = tracker.to_dict() + + assert ( + tracker.locator == tracker_dict["locator"] + and tracker.penalty_rawtx == tracker_dict["penalty_rawtx"] + and tracker.user_id == tracker_dict["user_id"] + ) + + +def test_tracker_from_dict(): + tracker_dict = create_dummy_tracker().to_dict() + new_tracker = TransactionTracker.from_dict(tracker_dict) + + assert tracker_dict == new_tracker.to_dict() + + +def test_tracker_from_dict_invalid_data(): + tracker_dict = create_dummy_tracker().to_dict() + + for value in ["dispute_txid", "penalty_txid", "penalty_rawtx", "user_id"]: + tracker_dict_copy = deepcopy(tracker_dict) + tracker_dict_copy[value] = None + + try: + TransactionTracker.from_dict(tracker_dict_copy) + assert False + + except ValueError: + assert True + + +def test_tracker_get_summary(): + tracker = create_dummy_tracker() + assert tracker.get_summary() == { + "locator": tracker.locator, + "user_id": tracker.user_id, + "penalty_txid": tracker.penalty_txid, + } + + +def test_init_responder(temp_db_manager, gatekeeper, carrier, block_processor): + responder = Responder(temp_db_manager, gatekeeper, carrier, block_processor) + assert isinstance(responder.trackers, dict) and len(responder.trackers) == 0 + assert isinstance(responder.tx_tracker_map, dict) and len(responder.tx_tracker_map) == 0 + assert isinstance(responder.unconfirmed_txs, list) and len(responder.unconfirmed_txs) == 0 + assert isinstance(responder.missed_confirmations, dict) and len(responder.missed_confirmations) == 0 + assert isinstance(responder.block_queue, Queue) and responder.block_queue.empty() + assert isinstance(responder.db_manager, AppointmentsDBM) + assert isinstance(responder.gatekeeper, Gatekeeper) + assert isinstance(responder.carrier, Carrier) + assert isinstance(responder.block_processor, BlockProcessor) + assert responder.last_known_block is None or isinstance(responder.last_known_block, str) + + def test_on_sync(run_bitcoind, responder, block_processor): # We're on sync if we're 1 or less blocks behind the tip chain_tip = block_processor.get_best_block_hash() @@ -108,50 +171,8 @@ def test_on_sync_fail(responder, block_processor): assert responder.on_sync(chain_tip) is False -def test_tracker_to_dict(): - tracker = create_dummy_tracker() - tracker_dict = tracker.to_dict() - - assert ( - tracker.locator == tracker_dict["locator"] - and tracker.penalty_rawtx == tracker_dict["penalty_rawtx"] - and tracker.appointment_end == tracker_dict["appointment_end"] - ) - - -def test_tracker_from_dict(): - tracker_dict = create_dummy_tracker().to_dict() - new_tracker = TransactionTracker.from_dict(tracker_dict) - - assert tracker_dict == new_tracker.to_dict() - - -def test_tracker_from_dict_invalid_data(): - tracker_dict = create_dummy_tracker().to_dict() - - for value in ["dispute_txid", "penalty_txid", "penalty_rawtx", "appointment_end"]: - tracker_dict_copy = deepcopy(tracker_dict) - tracker_dict_copy[value] = None - - try: - TransactionTracker.from_dict(tracker_dict_copy) - assert False - - except ValueError: - assert True - - -def test_init_responder(temp_db_manager, carrier, block_processor): - responder = Responder(temp_db_manager, carrier, block_processor) - assert isinstance(responder.trackers, dict) and len(responder.trackers) == 0 - assert isinstance(responder.tx_tracker_map, dict) and len(responder.tx_tracker_map) == 0 - assert isinstance(responder.unconfirmed_txs, list) and len(responder.unconfirmed_txs) == 0 - assert isinstance(responder.missed_confirmations, dict) and len(responder.missed_confirmations) == 0 - assert responder.block_queue.empty() - - -def test_handle_breach(db_manager, carrier, block_processor): - responder = Responder(db_manager, carrier, block_processor) +def test_handle_breach(db_manager, gatekeeper, carrier, block_processor): + responder = Responder(db_manager, gatekeeper, carrier, block_processor) uuid = uuid4().hex tracker = create_dummy_tracker() @@ -163,17 +184,17 @@ def test_handle_breach(db_manager, carrier, block_processor): tracker.dispute_txid, tracker.penalty_txid, tracker.penalty_rawtx, - tracker.appointment_end, + tracker.user_id, block_hash=get_random_value_hex(32), ) assert receipt.delivered is True -def test_handle_breach_bad_response(db_manager, block_processor): +def test_handle_breach_bad_response(db_manager, gatekeeper, block_processor): # We need a new carrier here, otherwise the transaction will be flagged as previously sent and receipt.delivered # will be True - responder = Responder(db_manager, Carrier(bitcoind_connect_params), block_processor) + responder = Responder(db_manager, gatekeeper, Carrier(bitcoind_connect_params), block_processor) uuid = uuid4().hex tracker = create_dummy_tracker() @@ -188,7 +209,7 @@ def test_handle_breach_bad_response(db_manager, block_processor): tracker.dispute_txid, tracker.penalty_txid, tracker.penalty_rawtx, - tracker.appointment_end, + tracker.user_id, block_hash=get_random_value_hex(32), ) @@ -199,9 +220,7 @@ def test_add_tracker(responder): for _ in range(20): uuid = uuid4().hex confirmations = 0 - locator, dispute_txid, penalty_txid, penalty_rawtx, appointment_end = create_dummy_tracker_data( - random_txid=True - ) + locator, dispute_txid, penalty_txid, penalty_rawtx, user_id = create_dummy_tracker_data(random_txid=True) # Check the tracker is not within the responder trackers before adding it assert uuid not in responder.trackers @@ -209,7 +228,7 @@ def test_add_tracker(responder): assert penalty_txid not in responder.unconfirmed_txs # And that it is afterwards - responder.add_tracker(uuid, locator, dispute_txid, penalty_txid, penalty_rawtx, appointment_end, confirmations) + responder.add_tracker(uuid, locator, dispute_txid, penalty_txid, penalty_rawtx, user_id, confirmations) assert uuid in responder.trackers assert penalty_txid in responder.tx_tracker_map assert penalty_txid in responder.unconfirmed_txs @@ -219,18 +238,18 @@ def test_add_tracker(responder): assert ( tracker.get("penalty_txid") == penalty_txid and tracker.get("locator") == locator - and tracker.get("appointment_end") == appointment_end + and tracker.get("user_id") == user_id ) def test_add_tracker_same_penalty_txid(responder): confirmations = 0 - locator, dispute_txid, penalty_txid, penalty_rawtx, appointment_end = create_dummy_tracker_data(random_txid=True) + locator, dispute_txid, penalty_txid, penalty_rawtx, user_id = create_dummy_tracker_data(random_txid=True) uuid_1 = uuid4().hex uuid_2 = uuid4().hex - responder.add_tracker(uuid_1, locator, dispute_txid, penalty_txid, penalty_rawtx, appointment_end, confirmations) - responder.add_tracker(uuid_2, locator, dispute_txid, penalty_txid, penalty_rawtx, appointment_end, confirmations) + responder.add_tracker(uuid_1, locator, dispute_txid, penalty_txid, penalty_rawtx, user_id, confirmations) + responder.add_tracker(uuid_2, locator, dispute_txid, penalty_txid, penalty_rawtx, user_id, confirmations) # Check that both trackers have been added assert uuid_1 in responder.trackers and uuid_2 in responder.trackers @@ -243,7 +262,7 @@ def test_add_tracker_same_penalty_txid(responder): assert ( tracker.get("penalty_txid") == penalty_txid and tracker.get("locator") == locator - and tracker.get("appointment_end") == appointment_end + and tracker.get("user_id") == user_id ) @@ -251,35 +270,45 @@ def test_add_tracker_already_confirmed(responder): for i in range(20): uuid = uuid4().hex confirmations = i + 1 - locator, dispute_txid, penalty_txid, penalty_rawtx, appointment_end = create_dummy_tracker_data( + locator, dispute_txid, penalty_txid, penalty_rawtx, user_id = create_dummy_tracker_data( penalty_rawtx=create_dummy_transaction().hex() ) - responder.add_tracker(uuid, locator, dispute_txid, penalty_txid, penalty_rawtx, appointment_end, confirmations) + responder.add_tracker(uuid, locator, dispute_txid, penalty_txid, penalty_rawtx, user_id, confirmations) assert penalty_txid not in responder.unconfirmed_txs + assert ( + responder.trackers[uuid].get("penalty_txid") == penalty_txid + and responder.trackers[uuid].get("locator") == locator + and responder.trackers[uuid].get("user_id") == user_id + ) -def test_do_watch(temp_db_manager, carrier, block_processor): +def test_do_watch(temp_db_manager, gatekeeper, carrier, block_processor): # Create a fresh responder to simplify the test - responder = Responder(temp_db_manager, carrier, block_processor) + responder = Responder(temp_db_manager, gatekeeper, carrier, block_processor) chain_monitor = ChainMonitor(Queue(), responder.block_queue, block_processor, bitcoind_feed_params) chain_monitor.monitor_chain() trackers = [create_dummy_tracker(penalty_rawtx=create_dummy_transaction().hex()) for _ in range(20)] + subscription_expiry = responder.block_processor.get_block_count() + 110 # Let's set up the trackers first for tracker in trackers: uuid = uuid4().hex - responder.trackers[uuid] = { - "locator": tracker.locator, - "penalty_txid": tracker.penalty_txid, - "appointment_end": tracker.appointment_end, - } + # Simulate user registration so trackers can properly expire + responder.gatekeeper.registered_users[tracker.user_id] = UserInfo( + available_slots=10, subscription_expiry=subscription_expiry + ) + + # Add data to the Responder + responder.trackers[uuid] = tracker.get_summary() responder.tx_tracker_map[tracker.penalty_txid] = [uuid] responder.missed_confirmations[tracker.penalty_txid] = 0 responder.unconfirmed_txs.append(tracker.penalty_txid) + # Assuming the appointment only took a single slot + responder.gatekeeper.registered_users[tracker.user_id].appointments[uuid] = 1 # We also need to store the info in the db responder.db_manager.create_triggered_appointment_flag(uuid) @@ -295,37 +324,40 @@ def test_do_watch(temp_db_manager, carrier, block_processor): broadcast_txs.append(tracker.penalty_txid) # Mine a block - generate_block() + generate_block_w_delay() # The transactions we sent shouldn't be in the unconfirmed transaction list anymore assert not set(broadcast_txs).issubset(responder.unconfirmed_txs) - # TODO: test that reorgs can be detected once data persistence is merged (new version of the simulator) + # CONFIRMATIONS_BEFORE_RETRY+1 blocks after, the responder should rebroadcast the unconfirmed txs (15 remaining) + generate_blocks_w_delay(CONFIRMATIONS_BEFORE_RETRY + 1) + assert len(responder.unconfirmed_txs) == 0 + assert len(responder.trackers) == 20 - # Generating 5 additional blocks should complete the 5 trackers - generate_blocks(5) + # Generating 100 - CONFIRMATIONS_BEFORE_RETRY -2 additional blocks should complete the first 5 trackers + generate_blocks_w_delay(100 - CONFIRMATIONS_BEFORE_RETRY - 2) + assert len(responder.unconfirmed_txs) == 0 + assert len(responder.trackers) == 15 + # Check they are not in the Gatekeeper either + for tracker in trackers[:5]: + assert len(responder.gatekeeper.registered_users[tracker.user_id].appointments) == 0 - assert not set(broadcast_txs).issubset(responder.tx_tracker_map) - - # Do the rest - broadcast_txs = [] + # CONFIRMATIONS_BEFORE_RETRY additional blocks should complete the rest + generate_blocks_w_delay(CONFIRMATIONS_BEFORE_RETRY) + assert len(responder.unconfirmed_txs) == 0 + assert len(responder.trackers) == 0 + # Check they are not in the Gatekeeper either for tracker in trackers[5:]: - bitcoin_cli(bitcoind_connect_params).sendrawtransaction(tracker.penalty_rawtx) - broadcast_txs.append(tracker.penalty_txid) - - # Mine a block - generate_blocks(6) - - assert len(responder.tx_tracker_map) == 0 + assert len(responder.gatekeeper.registered_users[tracker.user_id].appointments) == 0 -def test_check_confirmations(db_manager, carrier, block_processor): - responder = Responder(db_manager, carrier, block_processor) +def test_check_confirmations(db_manager, gatekeeper, carrier, block_processor): + responder = Responder(db_manager, gatekeeper, carrier, block_processor) chain_monitor = ChainMonitor(Queue(), responder.block_queue, block_processor, bitcoind_feed_params) chain_monitor.monitor_chain() # check_confirmations checks, given a list of transaction for a block, what of the known penalty transaction have - # been confirmed. To test this we need to create a list of transactions and the state of the responder + # been confirmed. To test this we need to create a list of transactions and the state of the Responder txs = [get_random_value_hex(32) for _ in range(20)] # The responder has a list of unconfirmed transaction, let make that some of them are the ones we've received @@ -352,7 +384,6 @@ def test_check_confirmations(db_manager, carrier, block_processor): assert responder.missed_confirmations[tx] == 1 -# TODO: Check this properly, a bug pass unnoticed! def test_get_txs_to_rebroadcast(responder): # Let's create a few fake txids and assign at least 6 missing confirmations to each txs_missing_too_many_conf = {get_random_value_hex(32): 6 + i for i in range(10)} @@ -376,68 +407,131 @@ def test_get_txs_to_rebroadcast(responder): assert txs_to_rebroadcast == list(txs_missing_too_many_conf.keys()) -def test_get_completed_trackers(db_manager, carrier, block_processor): - initial_height = bitcoin_cli(bitcoind_connect_params).getblockcount() - - responder = Responder(db_manager, carrier, block_processor) +def test_get_completed_trackers(db_manager, gatekeeper, carrier, block_processor): + responder = Responder(db_manager, gatekeeper, carrier, block_processor) chain_monitor = ChainMonitor(Queue(), responder.block_queue, block_processor, bitcoind_feed_params) chain_monitor.monitor_chain() - # A complete tracker is a tracker that has reached the appointment end with enough confs (> MIN_CONFIRMATIONS) - # We'll create three type of transactions: end reached + enough conf, end reached + no enough conf, end not reached - trackers_end_conf = { + # A complete tracker is a tracker which penalty transaction has been irrevocably resolved (i.e. has reached 100 + # confirmations) + # We'll create 3 type of txs: irrevocably resolved, confirmed but not irrevocably resolved, and unconfirmed + trackers_ir_resolved = { uuid4().hex: create_dummy_tracker(penalty_rawtx=create_dummy_transaction().hex()) for _ in range(10) } - trackers_end_no_conf = {} + trackers_confirmed = { + uuid4().hex: create_dummy_tracker(penalty_rawtx=create_dummy_transaction().hex()) for _ in range(10) + } + + trackers_unconfirmed = {} for _ in range(10): tracker = create_dummy_tracker(penalty_rawtx=create_dummy_transaction().hex()) responder.unconfirmed_txs.append(tracker.penalty_txid) - trackers_end_no_conf[uuid4().hex] = tracker - - trackers_no_end = {} - for _ in range(10): - tracker = create_dummy_tracker(penalty_rawtx=create_dummy_transaction().hex()) - tracker.appointment_end += 10 - trackers_no_end[uuid4().hex] = tracker + trackers_unconfirmed[uuid4().hex] = tracker all_trackers = {} - all_trackers.update(trackers_end_conf) - all_trackers.update(trackers_end_no_conf) - all_trackers.update(trackers_no_end) + all_trackers.update(trackers_ir_resolved) + all_trackers.update(trackers_confirmed) + all_trackers.update(trackers_unconfirmed) - # Let's add all to the responder + # Let's add all to the Responder for uuid, tracker in all_trackers.items(): - responder.trackers[uuid] = { - "locator": tracker.locator, - "penalty_txid": tracker.penalty_txid, - "appointment_end": tracker.appointment_end, - } + responder.trackers[uuid] = tracker.get_summary() - for uuid, tracker in all_trackers.items(): + for uuid, tracker in trackers_ir_resolved.items(): bitcoin_cli(bitcoind_connect_params).sendrawtransaction(tracker.penalty_rawtx) - # The dummy appointments have a end_appointment time of current + 2, but trackers need at least 6 confs by default - generate_blocks(6) + generate_block_w_delay() - # And now let's check - completed_trackers = responder.get_completed_trackers(initial_height + 6) - completed_trackers_ids = [tracker_id for tracker_id, confirmations in completed_trackers.items()] - ended_trackers_keys = list(trackers_end_conf.keys()) - assert set(completed_trackers_ids) == set(ended_trackers_keys) + for uuid, tracker in trackers_confirmed.items(): + bitcoin_cli(bitcoind_connect_params).sendrawtransaction(tracker.penalty_rawtx) - # Generating 6 additional blocks should also confirm trackers_no_end - generate_blocks(6) + # ir_resolved have 100 confirmations and confirmed have 99 + generate_blocks_w_delay(99) - completed_trackers = responder.get_completed_trackers(initial_height + 12) - completed_trackers_ids = [tracker_id for tracker_id, confirmations in completed_trackers.items()] - ended_trackers_keys.extend(list(trackers_no_end.keys())) + # Let's check + completed_trackers = responder.get_completed_trackers() + ended_trackers_keys = list(trackers_ir_resolved.keys()) + assert set(completed_trackers) == set(ended_trackers_keys) - assert set(completed_trackers_ids) == set(ended_trackers_keys) + # Generating 1 additional blocks should also include confirmed + generate_block_w_delay() + + completed_trackers = responder.get_completed_trackers() + ended_trackers_keys.extend(list(trackers_confirmed.keys())) + assert set(completed_trackers) == set(ended_trackers_keys) -def test_rebroadcast(db_manager, carrier, block_processor): - responder = Responder(db_manager, carrier, block_processor) +def test_get_expired_trackers(responder): + # expired trackers are those who's subscription has reached the expiry block and have not been confirmed. + # confirmed trackers that have reached their expiry will be kept until completed + current_block = responder.block_processor.get_block_count() + + # Lets first register the a couple of users + user1_id = get_random_value_hex(16) + responder.gatekeeper.registered_users[user1_id] = UserInfo( + available_slots=10, subscription_expiry=current_block + 15 + ) + user2_id = get_random_value_hex(16) + responder.gatekeeper.registered_users[user2_id] = UserInfo( + available_slots=10, subscription_expiry=current_block + 16 + ) + + # And create some trackers and add them to the corresponding user in the Gatekeeper + expired_unconfirmed_trackers_15 = {} + expired_unconfirmed_trackers_16 = {} + expired_confirmed_trackers_15 = {} + for _ in range(10): + uuid = uuid4().hex + dummy_tracker = create_dummy_tracker(penalty_rawtx=create_dummy_transaction().hex()) + dummy_tracker.user_id = user1_id + expired_unconfirmed_trackers_15[uuid] = dummy_tracker + responder.unconfirmed_txs.append(dummy_tracker.penalty_txid) + # Assume the appointment only took a single slot + responder.gatekeeper.registered_users[dummy_tracker.user_id].appointments[uuid] = 1 + + for _ in range(10): + uuid = uuid4().hex + dummy_tracker = create_dummy_tracker(penalty_rawtx=create_dummy_transaction().hex()) + dummy_tracker.user_id = user1_id + expired_confirmed_trackers_15[uuid] = dummy_tracker + # Assume the appointment only took a single slot + responder.gatekeeper.registered_users[dummy_tracker.user_id].appointments[uuid] = 1 + + for _ in range(10): + uuid = uuid4().hex + dummy_tracker = create_dummy_tracker(penalty_rawtx=create_dummy_transaction().hex()) + dummy_tracker.user_id = user2_id + expired_unconfirmed_trackers_16[uuid] = dummy_tracker + responder.unconfirmed_txs.append(dummy_tracker.penalty_txid) + # Assume the appointment only took a single slot + responder.gatekeeper.registered_users[dummy_tracker.user_id].appointments[uuid] = 1 + + all_trackers = {} + all_trackers.update(expired_confirmed_trackers_15) + all_trackers.update(expired_unconfirmed_trackers_15) + all_trackers.update(expired_unconfirmed_trackers_16) + + # Add everything to the Responder + for uuid, tracker in all_trackers.items(): + responder.trackers[uuid] = tracker.get_summary() + + # Currently nothing should be expired + assert responder.get_expired_trackers(current_block) == [] + + # 15 blocks (+ EXPIRY_DELTA) afterwards only user1's confirmed trackers should be expired + assert responder.get_expired_trackers(current_block + 15 + config.get("EXPIRY_DELTA")) == list( + expired_unconfirmed_trackers_15.keys() + ) + + # 1 (+ EXPIRY_DELTA) block after that user2's should be expired + assert responder.get_expired_trackers(current_block + 16 + config.get("EXPIRY_DELTA")) == list( + expired_unconfirmed_trackers_16.keys() + ) + + +def test_rebroadcast(db_manager, gatekeeper, carrier, block_processor): + responder = Responder(db_manager, gatekeeper, carrier, block_processor) chain_monitor = ChainMonitor(Queue(), responder.block_queue, block_processor, bitcoind_feed_params) chain_monitor.monitor_chain() @@ -446,17 +540,13 @@ def test_rebroadcast(db_manager, carrier, block_processor): # Rebroadcast calls add_response with retry=True. The tracker data is already in trackers. for i in range(20): uuid = uuid4().hex - locator, dispute_txid, penalty_txid, penalty_rawtx, appointment_end = create_dummy_tracker_data( + locator, dispute_txid, penalty_txid, penalty_rawtx, user_id = create_dummy_tracker_data( penalty_rawtx=create_dummy_transaction().hex() ) - tracker = TransactionTracker(locator, dispute_txid, penalty_txid, penalty_rawtx, appointment_end) + tracker = TransactionTracker(locator, dispute_txid, penalty_txid, penalty_rawtx, user_id) - responder.trackers[uuid] = { - "locator": locator, - "penalty_txid": penalty_txid, - "appointment_end": appointment_end, - } + responder.trackers[uuid] = {"locator": locator, "penalty_txid": penalty_txid, "user_id": user_id} # We need to add it to the db too responder.db_manager.create_triggered_appointment_flag(uuid) diff --git a/test/teos/unit/test_tools.py b/test/teos/unit/test_tools.py index 9a68a19..024ec45 100644 --- a/test/teos/unit/test_tools.py +++ b/test/teos/unit/test_tools.py @@ -13,13 +13,7 @@ def test_can_connect_to_bitcoind(): assert can_connect_to_bitcoind(bitcoind_connect_params) is True -# def test_can_connect_to_bitcoind_bitcoin_not_running(): -# # Kill the simulator thread and test the check fails -# bitcoind_process.kill() -# assert can_connect_to_bitcoind() is False - - -def test_bitcoin_cli(): +def test_bitcoin_cli(run_bitcoind): try: bitcoin_cli(bitcoind_connect_params).help() assert True diff --git a/test/teos/unit/test_users_dbm.py b/test/teos/unit/test_users_dbm.py index 5066561..605c6e4 100644 --- a/test/teos/unit/test_users_dbm.py +++ b/test/teos/unit/test_users_dbm.py @@ -1,15 +1,15 @@ -from teos.appointments_dbm import AppointmentsDBM +from teos.users_dbm import UsersDBM +from teos.gatekeeper import UserInfo from test.teos.unit.conftest import get_random_value_hex - stored_users = {} def open_create_db(db_path): try: - db_manager = AppointmentsDBM(db_path) + db_manager = UsersDBM(db_path) return db_manager @@ -19,27 +19,27 @@ def open_create_db(db_path): def test_store_user(user_db_manager): # Store user should work as long as the user_pk is properly formatted and data is a dictionary - user_pk = "02" + get_random_value_hex(32) - user_data = {"available_slots": 42} - stored_users[user_pk] = user_data - assert user_db_manager.store_user(user_pk, user_data) is True + user_id = "02" + get_random_value_hex(32) + user_info = UserInfo(available_slots=42, subscription_expiry=100) + stored_users[user_id] = user_info.to_dict() + assert user_db_manager.store_user(user_id, user_info.to_dict()) is True # Wrong pks should return False on adding - user_pk = "04" + get_random_value_hex(32) - user_data = {"available_slots": 42} - assert user_db_manager.store_user(user_pk, user_data) is False + user_id = "04" + get_random_value_hex(32) + user_info = UserInfo(available_slots=42, subscription_expiry=100) + assert user_db_manager.store_user(user_id, user_info.to_dict()) is False # Same for wrong types - assert user_db_manager.store_user(42, user_data) is False + assert user_db_manager.store_user(42, user_info.to_dict()) is False # And for wrong type user data - assert user_db_manager.store_user(user_pk, 42) is False + assert user_db_manager.store_user(user_id, 42) is False def test_load_user(user_db_manager): # Loading a user we have stored should work - for user_pk, user_data in stored_users.items(): - assert user_db_manager.load_user(user_pk) == user_data + for user_id, user_data in stored_users.items(): + assert user_db_manager.load_user(user_id) == user_data # Random keys should fail assert user_db_manager.load_user(get_random_value_hex(33)) is None @@ -50,11 +50,11 @@ def test_load_user(user_db_manager): def test_delete_user(user_db_manager): # Deleting an existing user should work - for user_pk, user_data in stored_users.items(): - assert user_db_manager.delete_user(user_pk) is True + for user_id, user_data in stored_users.items(): + assert user_db_manager.delete_user(user_id) is True - for user_pk, user_data in stored_users.items(): - assert user_db_manager.load_user(user_pk) is None + for user_id, user_data in stored_users.items(): + assert user_db_manager.load_user(user_id) is None # But deleting a non existing one should not fail assert user_db_manager.delete_user(get_random_value_hex(32)) is True @@ -70,10 +70,10 @@ def test_load_all_users(user_db_manager): # Adding some and checking we get them all for i in range(10): - user_pk = "02" + get_random_value_hex(32) - user_data = {"available_slots": i} - user_db_manager.store_user(user_pk, user_data) - stored_users[user_pk] = user_data + user_id = "02" + get_random_value_hex(32) + user_info = UserInfo(available_slots=42, subscription_expiry=100) + user_db_manager.store_user(user_id, user_info.to_dict()) + stored_users[user_id] = user_info.to_dict() all_users = user_db_manager.load_all_users() diff --git a/test/teos/unit/test_watcher.py b/test/teos/unit/test_watcher.py index 77ab810..25f84e5 100644 --- a/test/teos/unit/test_watcher.py +++ b/test/teos/unit/test_watcher.py @@ -4,22 +4,21 @@ from shutil import rmtree from threading import Thread from coincurve import PrivateKey -from teos import LOG_PREFIX from teos.carrier import Carrier -from teos.watcher import Watcher from teos.tools import bitcoin_cli from teos.responder import Responder +from teos.gatekeeper import UserInfo from teos.chain_monitor import ChainMonitor from teos.appointments_dbm import AppointmentsDBM from teos.block_processor import BlockProcessor +from teos.watcher import Watcher, AppointmentLimitReached +from teos.gatekeeper import Gatekeeper, AuthenticationFailure, NotEnoughSlots -import common.cryptographer -from common.logger import Logger from common.tools import compute_locator from common.cryptographer import Cryptographer from test.teos.unit.conftest import ( - generate_blocks, + generate_blocks_w_delay, generate_dummy_appointment, get_random_value_hex, generate_keypair, @@ -28,12 +27,7 @@ from test.teos.unit.conftest import ( bitcoind_connect_params, ) -common.cryptographer.logger = Logger(actor="Cryptographer", log_name_prefix=LOG_PREFIX) - - APPOINTMENTS = 5 -START_TIME_OFFSET = 1 -END_TIME_OFFSET = 1 TEST_SET_SIZE = 200 config = get_config() @@ -56,14 +50,12 @@ def temp_db_manager(): @pytest.fixture(scope="module") -def watcher(db_manager): +def watcher(db_manager, gatekeeper): block_processor = BlockProcessor(bitcoind_connect_params) carrier = Carrier(bitcoind_connect_params) - responder = Responder(db_manager, carrier, block_processor) - watcher = Watcher( - db_manager, block_processor, responder, signing_key.to_der(), MAX_APPOINTMENTS, config.get("EXPIRY_DELTA") - ) + responder = Responder(db_manager, gatekeeper, carrier, block_processor) + watcher = Watcher(db_manager, gatekeeper, block_processor, responder, signing_key.to_der(), MAX_APPOINTMENTS) chain_monitor = ChainMonitor( watcher.block_queue, watcher.responder.block_queue, block_processor, bitcoind_feed_params @@ -89,9 +81,7 @@ def create_appointments(n): dispute_txs = [] for i in range(n): - appointment, dispute_tx = generate_dummy_appointment( - start_time_offset=START_TIME_OFFSET, end_time_offset=END_TIME_OFFSET - ) + appointment, dispute_tx = generate_dummy_appointment() uuid = uuid4().hex appointments[uuid] = appointment @@ -105,85 +95,107 @@ def test_init(run_bitcoind, watcher): assert isinstance(watcher.appointments, dict) and len(watcher.appointments) == 0 assert isinstance(watcher.locator_uuid_map, dict) and len(watcher.locator_uuid_map) == 0 assert watcher.block_queue.empty() + assert isinstance(watcher.db_manager, AppointmentsDBM) + assert isinstance(watcher.gatekeeper, Gatekeeper) assert isinstance(watcher.block_processor, BlockProcessor) assert isinstance(watcher.responder, Responder) assert isinstance(watcher.max_appointments, int) - assert isinstance(watcher.expiry_delta, int) assert isinstance(watcher.signing_key, PrivateKey) -def test_get_appointment_summary(watcher): - # get_appointment_summary returns an appointment summary if found, else None. - random_uuid = get_random_value_hex(16) - appointment_summary = {"locator": get_random_value_hex(16), "end_time": 10, "size": 200} - watcher.appointments[random_uuid] = appointment_summary - assert watcher.get_appointment_summary(random_uuid) == appointment_summary +def test_add_appointment_non_registered(watcher): + # Appointments from non-registered users should fail + user_sk, user_pk = generate_keypair() - # Requesting a non-existing appointment - assert watcher.get_appointment_summary(get_random_value_hex(16)) is None + appointment, dispute_tx = generate_dummy_appointment() + appointment_signature = Cryptographer.sign(appointment.serialize(), user_sk) + + with pytest.raises(AuthenticationFailure, match="User not found"): + watcher.add_appointment(appointment, appointment_signature) + + +def test_add_appointment_no_slots(watcher): + # Appointments from register users with no available slots should aso fail + user_sk, user_pk = generate_keypair() + user_id = Cryptographer.get_compressed_pk(user_pk) + watcher.gatekeeper.registered_users[user_id] = UserInfo(available_slots=0, subscription_expiry=10) + + appointment, dispute_tx = generate_dummy_appointment() + appointment_signature = Cryptographer.sign(appointment.serialize(), user_sk) + + with pytest.raises(NotEnoughSlots): + watcher.add_appointment(appointment, appointment_signature) def test_add_appointment(watcher): - # We should be able to add appointments up to the limit - for _ in range(10): - appointment, dispute_tx = generate_dummy_appointment( - start_time_offset=START_TIME_OFFSET, end_time_offset=END_TIME_OFFSET - ) - user_pk = get_random_value_hex(33) + # Simulate the user is registered + user_sk, user_pk = generate_keypair() + available_slots = 100 + user_id = Cryptographer.get_compressed_pk(user_pk) + watcher.gatekeeper.registered_users[user_id] = UserInfo(available_slots=available_slots, subscription_expiry=10) - added_appointment, sig = watcher.add_appointment(appointment, user_pk) + appointment, dispute_tx = generate_dummy_appointment() + appointment_signature = Cryptographer.sign(appointment.serialize(), user_sk) - assert added_appointment is True - assert Cryptographer.verify_rpk( - watcher.signing_key.public_key, Cryptographer.recover_pk(appointment.serialize(), sig) - ) + response = watcher.add_appointment(appointment, appointment_signature) + assert response.get("locator") == appointment.locator + assert Cryptographer.get_compressed_pk(watcher.signing_key.public_key) == Cryptographer.get_compressed_pk( + Cryptographer.recover_pk(appointment.serialize(), response.get("signature")) + ) + assert response.get("available_slots") == available_slots - 1 - # Check that we can also add an already added appointment (same locator) - added_appointment, sig = watcher.add_appointment(appointment, user_pk) + # Check that we can also add an already added appointment (same locator) + response = watcher.add_appointment(appointment, appointment_signature) + assert response.get("locator") == appointment.locator + assert Cryptographer.get_compressed_pk(watcher.signing_key.public_key) == Cryptographer.get_compressed_pk( + Cryptographer.recover_pk(appointment.serialize(), response.get("signature")) + ) + # The slot count should not have been reduced and only one copy is kept. + assert response.get("available_slots") == available_slots - 1 + assert len(watcher.locator_uuid_map[appointment.locator]) == 1 - assert added_appointment is True - assert Cryptographer.verify_rpk( - watcher.signing_key.public_key, Cryptographer.recover_pk(appointment.serialize(), sig) - ) + # If two appointments with the same locator come from different users, they are kept. + another_user_sk, another_user_pk = generate_keypair() + another_user_id = Cryptographer.get_compressed_pk(another_user_pk) + watcher.gatekeeper.registered_users[another_user_id] = UserInfo( + available_slots=available_slots, subscription_expiry=10 + ) - # If two appointments with the same locator from the same user are added, they are overwritten, but if they come - # from different users, they are kept. - assert len(watcher.locator_uuid_map[appointment.locator]) == 1 - - different_user_pk = get_random_value_hex(33) - added_appointment, sig = watcher.add_appointment(appointment, different_user_pk) - assert added_appointment is True - assert Cryptographer.verify_rpk( - watcher.signing_key.public_key, Cryptographer.recover_pk(appointment.serialize(), sig) - ) - assert len(watcher.locator_uuid_map[appointment.locator]) == 2 + appointment_signature = Cryptographer.sign(appointment.serialize(), another_user_sk) + response = watcher.add_appointment(appointment, appointment_signature) + assert response.get("locator") == appointment.locator + assert Cryptographer.get_compressed_pk(watcher.signing_key.public_key) == Cryptographer.get_compressed_pk( + Cryptographer.recover_pk(appointment.serialize(), response.get("signature")) + ) + assert response.get("available_slots") == available_slots - 1 + assert len(watcher.locator_uuid_map[appointment.locator]) == 2 def test_add_too_many_appointments(watcher): - # Any appointment on top of those should fail + # Simulate the user is registered + user_sk, user_pk = generate_keypair() + available_slots = 100 + user_id = Cryptographer.get_compressed_pk(user_pk) + watcher.gatekeeper.registered_users[user_id] = UserInfo(available_slots=available_slots, subscription_expiry=10) + + # Appointments on top of the limit should be rejected watcher.appointments = dict() - for _ in range(MAX_APPOINTMENTS): - appointment, dispute_tx = generate_dummy_appointment( - start_time_offset=START_TIME_OFFSET, end_time_offset=END_TIME_OFFSET + for i in range(MAX_APPOINTMENTS): + appointment, dispute_tx = generate_dummy_appointment() + appointment_signature = Cryptographer.sign(appointment.serialize(), user_sk) + + response = watcher.add_appointment(appointment, appointment_signature) + assert response.get("locator") == appointment.locator + assert Cryptographer.get_compressed_pk(watcher.signing_key.public_key) == Cryptographer.get_compressed_pk( + Cryptographer.recover_pk(appointment.serialize(), response.get("signature")) ) - user_pk = get_random_value_hex(33) + assert response.get("available_slots") == available_slots - (i + 1) - added_appointment, sig = watcher.add_appointment(appointment, user_pk) - - assert added_appointment is True - assert Cryptographer.verify_rpk( - watcher.signing_key.public_key, Cryptographer.recover_pk(appointment.serialize(), sig) - ) - - appointment, dispute_tx = generate_dummy_appointment( - start_time_offset=START_TIME_OFFSET, end_time_offset=END_TIME_OFFSET - ) - user_pk = get_random_value_hex(33) - added_appointment, sig = watcher.add_appointment(appointment, user_pk) - - assert added_appointment is False - assert sig is None + with pytest.raises(AppointmentLimitReached): + appointment, dispute_tx = generate_dummy_appointment() + appointment_signature = Cryptographer.sign(appointment.serialize(), user_sk) + watcher.add_appointment(appointment, appointment_signature) def test_do_watch(watcher, temp_db_manager): @@ -195,9 +207,19 @@ def test_do_watch(watcher, temp_db_manager): # Set the data into the Watcher and in the db watcher.locator_uuid_map = locator_uuid_map watcher.appointments = {} + watcher.gatekeeper.registered_users = {} + # Simulate a register (times out in 10 bocks) + user_id = get_random_value_hex(16) + watcher.gatekeeper.registered_users[user_id] = UserInfo( + available_slots=100, subscription_expiry=watcher.block_processor.get_block_count() + 10 + ) + + # Add the appointments for uuid, appointment in appointments.items(): - watcher.appointments[uuid] = {"locator": appointment.locator, "end_time": appointment.end_time, "size": 200} + watcher.appointments[uuid] = {"locator": appointment.locator, "user_id": user_id} + # Assume the appointment only takes one slot + watcher.gatekeeper.registered_users[user_id].appointments[uuid] = 1 watcher.db_manager.store_watcher_appointment(uuid, appointment.to_dict()) watcher.db_manager.create_append_locator_map(appointment.locator, uuid) @@ -208,17 +230,21 @@ def test_do_watch(watcher, temp_db_manager): for dispute_tx in dispute_txs[:2]: bitcoin_cli(bitcoind_connect_params).sendrawtransaction(dispute_tx) - # After generating enough blocks, the number of appointments should have reduced by two - generate_blocks(START_TIME_OFFSET + END_TIME_OFFSET) + # After generating a block, the appointment count should have been reduced by 2 (two breaches) + generate_blocks_w_delay(1) assert len(watcher.appointments) == APPOINTMENTS - 2 - # The rest of appointments will timeout after the end (2) + EXPIRY_DELTA + # The rest of appointments will timeout after the subscription times-out (9 more blocks) + EXPIRY_DELTA # Wait for an additional block to be safe - generate_blocks(config.get("EXPIRY_DELTA") + START_TIME_OFFSET + END_TIME_OFFSET) - + generate_blocks_w_delay(10 + config.get("EXPIRY_DELTA")) assert len(watcher.appointments) == 0 + # Check that they are not in the Gatekeeper either, only the two that passed to the Responder should remain + assert len(watcher.gatekeeper.registered_users[user_id].appointments) == 2 + + # FIXME: We should also add cases where the transactions are invalid. bitcoind_mock needs to be extended for this. + def test_get_breaches(watcher, txids, locator_uuid_map): watcher.locator_uuid_map = locator_uuid_map @@ -239,7 +265,7 @@ def test_get_breaches_random_data(watcher, locator_uuid_map): assert len(potential_breaches) == 0 -def test_filter_valid_breaches_random_data(watcher): +def test_filter_breaches_random_data(watcher): appointments = {} locator_uuid_map = {} breaches = {} @@ -247,7 +273,7 @@ def test_filter_valid_breaches_random_data(watcher): for i in range(TEST_SET_SIZE): dummy_appointment, _ = generate_dummy_appointment() uuid = uuid4().hex - appointments[uuid] = {"locator": dummy_appointment.locator, "end_time": dummy_appointment.end_time} + appointments[uuid] = {"locator": dummy_appointment.locator, "user_id": dummy_appointment.user_id} watcher.db_manager.store_watcher_appointment(uuid, dummy_appointment.to_dict()) watcher.db_manager.create_append_locator_map(dummy_appointment.locator, uuid) @@ -260,7 +286,7 @@ def test_filter_valid_breaches_random_data(watcher): watcher.locator_uuid_map = locator_uuid_map watcher.appointments = appointments - valid_breaches, invalid_breaches = watcher.filter_valid_breaches(breaches) + valid_breaches, invalid_breaches = watcher.filter_breaches(breaches) # We have "triggered" TEST_SET_SIZE/2 breaches, all of them invalid. assert len(valid_breaches) == 0 and len(invalid_breaches) == TEST_SET_SIZE / 2 @@ -278,7 +304,7 @@ def test_filter_valid_breaches(watcher): ) dummy_appointment, _ = generate_dummy_appointment() - dummy_appointment.encrypted_blob.data = encrypted_blob + dummy_appointment.encrypted_blob = encrypted_blob dummy_appointment.locator = compute_locator(dispute_txid) uuid = uuid4().hex @@ -287,13 +313,13 @@ def test_filter_valid_breaches(watcher): breaches = {dummy_appointment.locator: dispute_txid} for uuid, appointment in appointments.items(): - watcher.appointments[uuid] = {"locator": appointment.locator, "end_time": appointment.end_time} + watcher.appointments[uuid] = {"locator": appointment.locator, "user_id": appointment.user_id} watcher.db_manager.store_watcher_appointment(uuid, dummy_appointment.to_dict()) watcher.db_manager.create_append_locator_map(dummy_appointment.locator, uuid) watcher.locator_uuid_map = locator_uuid_map - valid_breaches, invalid_breaches = watcher.filter_valid_breaches(breaches) + valid_breaches, invalid_breaches = watcher.filter_breaches(breaches) # We have "triggered" a single breach and it was valid. assert len(invalid_breaches) == 0 and len(valid_breaches) == 1 diff --git a/watchtower-plugin/arg_parser.py b/watchtower-plugin/arg_parser.py new file mode 100644 index 0000000..6236c92 --- /dev/null +++ b/watchtower-plugin/arg_parser.py @@ -0,0 +1,116 @@ +import re + +from common.tools import is_compressed_pk, is_locator, is_256b_hex_str +from common.exceptions import InvalidParameter + + +def parse_register_arguments(tower_id, host, port, config): + """ + Parses the arguments of the register command and checks that they are correct. + + Args: + tower_id (:obj:`str`): the identifier of the tower to connect to (a compressed public key). + host (:obj:`str`): the ip or hostname to connect to, optional. + host (:obj:`int`): the port to connect to, optional. + config: (:obj:`dict`): the configuration dictionary. + + Returns: + :obj:`tuple`: the tower id and tower network address. + + Raises: + :obj:`common.exceptions.InvalidParameter`: if any of the parameters is wrong or missing. + """ + + if not isinstance(tower_id, str): + raise InvalidParameter(f"tower id must be a compressed public key (33-byte hex value) not {str(tower_id)}") + + # tower_id is of the form tower_id@[ip][:][port] + if "@" in tower_id: + if not (host and port): + tower_id, tower_netaddr = tower_id.split("@") + + if not tower_netaddr: + raise InvalidParameter("no tower endpoint was provided") + + # Only host was specified or colons where specified but not port + if ":" not in tower_netaddr or tower_netaddr.endswith(":"): + tower_netaddr = f"{tower_netaddr}:{config.get('DEFAULT_PORT')}" + + else: + raise InvalidParameter("cannot specify host as both xxx@yyy and separate arguments") + + # host was specified, but no port, defaulting + elif host: + tower_netaddr = f"{host}:{config.get('DEFAULT_PORT')}" + + # host and port specified + elif host and port: + tower_netaddr = f"{host}:{port}" + + else: + raise InvalidParameter("tower host is missing") + + if not is_compressed_pk(tower_id): + raise InvalidParameter("tower id must be a compressed public key (33-byte hex value)") + + return tower_id, tower_netaddr + + +def parse_get_appointment_arguments(tower_id, locator): + """ + Parses the arguments of the get_appointment command and checks that they are correct. + + Args: + tower_id (:obj:`str`): the identifier of the tower to connect to (a compressed public key). + locator (:obj:`str`): the locator of the appointment to query the tower about. + + Returns: + :obj:`tuple`: the tower id and appointment locator. + + Raises: + :obj:`common.exceptions.InvalidParameter`: if any of the parameters is wrong or missing. + """ + + if not is_compressed_pk(tower_id): + raise InvalidParameter("tower id must be a compressed public key (33-byte hex value)") + + if not is_locator(locator): + raise InvalidParameter("The provided locator is not valid", locator=locator) + + return tower_id, locator + + +def parse_add_appointment_arguments(kwargs): + """ + Parses the arguments of the add_appointment command and checks that they are correct. + + The expected arguments are a commitment transaction id (32-byte hex string) and the penalty transaction. + + Args: + kwargs (:obj:`dict`): a dictionary of arguments. + + Returns: + :obj:`tuple`: the commitment transaction id and the penalty transaction. + + Raises: + :obj:`common.exceptions.InvalidParameter`: if any of the parameters is wrong or missing. + """ + + # Arguments to add_appointment come from c-lightning and they have been sanitised. Checking this just in case. + commitment_txid = kwargs.get("commitment_txid") + penalty_tx = kwargs.get("penalty_tx") + + if commitment_txid is None: + raise InvalidParameter("missing required parameter: commitment_txid") + + if penalty_tx is None: + raise InvalidParameter("missing required parameter: penalty_tx") + + if not is_256b_hex_str(commitment_txid): + raise InvalidParameter("commitment_txid has invalid format") + + # Checking the basic stuff for the penalty transaction for now + if type(penalty_tx) is not str or re.search(r"^[0-9A-Fa-f]+$", penalty_tx) is None: + raise InvalidParameter("penalty_tx has invalid format") + + return commitment_txid, penalty_tx diff --git a/watchtower-plugin/exceptions.py b/watchtower-plugin/exceptions.py new file mode 100644 index 0000000..fd0d395 --- /dev/null +++ b/watchtower-plugin/exceptions.py @@ -0,0 +1,9 @@ +from common.exceptions import BasicException + + +class TowerConnectionError(BasicException): + """Raised when the tower responds with an error""" + + +class TowerResponseError(BasicException): + """Raised when the tower responds with an error""" diff --git a/watchtower-plugin/keys.py b/watchtower-plugin/keys.py new file mode 100644 index 0000000..587f8e8 --- /dev/null +++ b/watchtower-plugin/keys.py @@ -0,0 +1,82 @@ +import os.path +from pathlib import Path +from coincurve import PrivateKey + +from common.exceptions import InvalidKey +from common.cryptographer import Cryptographer + + +def save_key(sk, filename): + """ + Saves the secret key on disk. + + Args: + sk (:obj:`EllipticCurvePrivateKey`): a private key file to be saved on disk. + filename (:obj:`str`): the name that will be given to the key file. + """ + + with open(filename, "wb") as der_out: + der_out.write(sk.to_der()) + + +def generate_keys(data_dir): + """ + Generates a key pair for the client. + + Args: + data_dir (:obj:`str`): path to data directory where the keys will be stored. + + Returns: + :obj:`tuple`: a tuple containing a ``PrivateKey`` and a ``str`` representing the client sk and compressed pk + respectively. + + Raises: + :obj:`FileExistsError`: if the key pair already exists in the given directory. + """ + + # Create the output folder it it does not exist (and all the parents if they don't either) + Path(data_dir).mkdir(parents=True, exist_ok=True) + sk_file_name = os.path.join(data_dir, "sk.der") + + if os.path.exists(sk_file_name): + raise FileExistsError("The client key pair already exists") + + sk = PrivateKey() + pk = sk.public_key + save_key(sk, sk_file_name) + + return sk, Cryptographer.get_compressed_pk(pk) + + +def load_keys(data_dir): + """ + Loads a the client key pair. + + Args: + data_dir (:obj:`str`): path to data directory where the keys are stored. + + Returns: + :obj:`tuple`: a tuple containing a ``PrivateKey`` and a ``str`` representing the client sk and compressed pk + respectively. + + Raises: + :obj:`InvalidKey `: if any of the keys is invalid or cannot be loaded. + """ + + if not isinstance(data_dir, str): + raise ValueError("Invalid data_dir. Please check your settings") + + sk_file_path = os.path.join(data_dir, "sk.der") + + cli_sk_der = Cryptographer.load_key_file(sk_file_path) + cli_sk = Cryptographer.load_private_key_der(cli_sk_der) + + if cli_sk is None: + raise InvalidKey("Client private key is invalid or cannot be parsed") + + compressed_cli_pk = Cryptographer.get_compressed_pk(cli_sk.public_key) + + if compressed_cli_pk is None: + raise InvalidKey("Client public key cannot be loaded") + + return cli_sk, compressed_cli_pk diff --git a/watchtower-plugin/net/http.py b/watchtower-plugin/net/http.py new file mode 100644 index 0000000..ea31833 --- /dev/null +++ b/watchtower-plugin/net/http.py @@ -0,0 +1,142 @@ +import json +import requests +from requests import ConnectionError, ConnectTimeout +from requests.exceptions import MissingSchema, InvalidSchema, InvalidURL + +from common import errors +from common import constants +from common.appointment import Appointment +from common.exceptions import SignatureError +from common.cryptographer import Cryptographer + +from exceptions import TowerConnectionError, TowerResponseError + + +def add_appointment(plugin, tower_id, tower, appointment_dict, signature): + try: + plugin.log(f"Sending appointment {appointment_dict.get('locator')} to {tower_id}") + response = send_appointment(tower_id, tower, appointment_dict, signature) + plugin.log(f"Appointment accepted and signed by {tower_id}") + plugin.log(f"Remaining slots: {response.get('available_slots')}") + + # # TODO: Not storing the whole appointments for now. The node can recreate all the data if needed. + # # DISCUSS: It may be worth checking that the available slots match instead of blindly trusting. + return response.get("signature"), response.get("available_slots") + + except SignatureError as e: + plugin.log(f"{tower_id} is misbehaving, not using it any longer") + raise e + + except TowerConnectionError as e: + plugin.log(f"{tower_id} cannot be reached") + + raise e + + except TowerResponseError as e: + data = e.kwargs.get("data") + status_code = e.kwargs.get("status_code") + + if data and status_code == constants.HTTP_BAD_REQUEST: + if data.get("error_code") == errors.APPOINTMENT_INVALID_SIGNATURE_OR_INSUFFICIENT_SLOTS: + message = f"There is a subscription issue with {tower_id}" + raise TowerResponseError(message, status="subscription error") + + elif data.get("error_code") >= errors.INVALID_REQUEST_FORMAT: + message = f"Appointment sent to {tower_id} is invalid" + raise TowerResponseError(message, status="reachable", invalid_appointment=True) + + elif status_code == constants.HTTP_SERVICE_UNAVAILABLE: + # Flag appointment for retry + message = f"{tower_id} is temporarily unavailable" + + raise TowerResponseError(message, status="temporarily unreachable") + + # Log unexpected behaviour without raising + plugin.log(str(e), level="warn") + + +def send_appointment(tower_id, tower, appointment_dict, signature): + data = {"appointment": appointment_dict, "signature": signature} + + add_appointment_endpoint = f"{tower.netaddr}/add_appointment" + response = process_post_response(post_request(data, add_appointment_endpoint, tower_id)) + + tower_signature = response.get("signature") + # Check that the server signed the appointment as it should. + if not tower_signature: + raise SignatureError("The response does not contain the signature of the appointment", signature=None) + + rpk = Cryptographer.recover_pk(Appointment.from_dict(appointment_dict).serialize(), tower_signature) + recovered_id = Cryptographer.get_compressed_pk(rpk) + if tower_id != recovered_id: + raise SignatureError( + "The returned appointment's signature is invalid", + tower_id=tower_id, + recovered_id=recovered_id, + signature=tower_signature, + ) + + return response + + +def post_request(data, endpoint, tower_id): + """ + Sends a post request to the tower. + + Args: + data (:obj:`dict`): a dictionary containing the data to be posted. + endpoint (:obj:`str`): the endpoint to send the post request. + tower_id (:obj:`str`): the identifier of the tower to connect to (a compressed public key). + + Returns: + :obj:`dict`: a json-encoded dictionary with the server response if the data can be posted. + + Raises: + :obj:`ConnectionError`: if the client cannot connect to the tower. + """ + + try: + return requests.post(url=endpoint, json=data, timeout=5) + + except ConnectTimeout: + message = f"Cannot connect to {tower_id}. Connection timeout" + + except ConnectionError: + message = f"Cannot connect to {tower_id}. Tower cannot be reached" + + except (InvalidSchema, MissingSchema, InvalidURL): + message = f"Invalid URL. No schema, or invalid schema, found (url={endpoint}, tower_id={tower_id})" + + raise TowerConnectionError(message) + + +def process_post_response(response): + """ + Processes the server response to a post request. + + Args: + response (:obj:`requests.models.Response`): a ``Response`` object obtained from the request. + + Returns: + :obj:`dict`: a dictionary containing the tower's response data if the response type is + ``HTTP_OK``. + + Raises: + :obj:`TowerResponseError `: if the tower responded with an error, or the + response was invalid. + """ + + try: + response_json = response.json() + + except (json.JSONDecodeError, AttributeError): + raise TowerResponseError( + "The server returned a non-JSON response", status_code=response.status_code, reason=response.reason + ) + + if response.status_code not in [constants.HTTP_OK, constants.HTTP_NOT_FOUND]: + raise TowerResponseError( + "The server returned an error", status_code=response.status_code, reason=response.reason, data=response_json + ) + + return response_json diff --git a/watchtower-plugin/requirements.txt b/watchtower-plugin/requirements.txt new file mode 100644 index 0000000..84ba49d --- /dev/null +++ b/watchtower-plugin/requirements.txt @@ -0,0 +1,7 @@ +pyln-client +requests +coincurve +cryptography==2.8 +pyzbase32 +plyvel +backoff \ No newline at end of file diff --git a/watchtower-plugin/retrier.py b/watchtower-plugin/retrier.py new file mode 100644 index 0000000..f8c4751 --- /dev/null +++ b/watchtower-plugin/retrier.py @@ -0,0 +1,148 @@ +import backoff +from threading import Thread + +from common.exceptions import SignatureError + +from net.http import add_appointment +from exceptions import TowerConnectionError, TowerResponseError + + +MAX_RETRIES = None + + +def check_retry(status): + """ + Checks is the job needs to be retried. Jobs are retried if max_retries is not reached and the tower status is + temporarily unreachable. + + Args: + status (:obj:`str`): the tower status. + + Returns: + :obj:`bool`: True is the status is "temporarily unreachable", False otherwise. + """ + return status == "temporarily unreachable" + + +def on_backoff(details): + """ + Function called when backing off after a retry. Logs data regarding the retry. + Args: + details: the retry details (check backoff library for more info). + """ + plugin = details.get("args")[1] + tower_id = details.get("args")[2] + plugin.log(f"Retry {details.get('tries')} failed for tower {tower_id}, backing off") + + +def on_giveup(details): + """ + Function called when giving up after the last retry. Logs data regarding the retry and flags the tower as + unreachable. + + Args: + details: the retry details (check backoff library for more info). + """ + plugin = details.get("args")[1] + tower_id = details.get("args")[2] + + plugin.log(f"Max retries reached, abandoning tower {tower_id}") + + tower_update = {"status": "unreachable"} + plugin.wt_client.update_tower_state(tower_id, tower_update) + + +def set_max_retries(max_retries): + """Workaround to set max retries from Retrier to the backoff.on_predicate decorator""" + global MAX_RETRIES + MAX_RETRIES = max_retries + + +def max_retries(): + """Workaround to set max retries from Retrier to the backoff.on_predicate decorator""" + return MAX_RETRIES + + +class Retrier: + """ + The Retrier is in charge of the retry process for appointments that were sent to towers that were temporarily + unreachable. + + Args: + max_retries (:obj:`int`): the maximum number of times that a tower will be retried. + temp_unreachable_towers (:obj:`Queue`): a queue of temporarily unreachable towers populated by the plugin on + failing to deliver an appointment. + """ + + def __init__(self, max_retries, temp_unreachable_towers): + self.temp_unreachable_towers = temp_unreachable_towers + set_max_retries(max_retries) + + def manage_retry(self, plugin): + """ + Listens to the temporarily unreachable towers queue and creates a thread to manage each tower it gets. + + Args: + plugin (:obj:`Plugin`): the plugin object. + """ + + while True: + tower_id = self.temp_unreachable_towers.get() + tower = plugin.wt_client.towers[tower_id] + + Thread(target=self.do_retry, args=[plugin, tower_id, tower], daemon=True).start() + + @backoff.on_predicate(backoff.expo, check_retry, max_tries=max_retries, on_backoff=on_backoff, on_giveup=on_giveup) + def do_retry(self, plugin, tower_id, tower): + """ + Retries to send a list of pending appointments to a temporarily unreachable tower. This function is managed by + manage_retries and run in a different thread per tower. + + For every pending appointment the worker thread tries to send the data to the tower. If the tower keeps being + unreachable, the job is retries up to MAX_RETRIES. If MAX_RETRIES is reached, the worker thread gives up and the + tower is flagged as unreachable. + + Args: + plugin (:obj:`Plugin`): the plugin object. + tower_id (:obj:`str`): the id of the tower managed by the thread. + tower: (:obj:`TowerSummary`): the tower data. + + Returns: + :obj:`str`: the tower status if it is not reachable. + """ + + for appointment_dict, signature in plugin.wt_client.towers[tower_id].pending_appointments: + tower_update = {} + try: + tower_signature, available_slots = add_appointment(plugin, tower_id, tower, appointment_dict, signature) + tower_update["status"] = "reachable" + tower_update["appointment"] = (appointment_dict.get("locator"), tower_signature) + tower_update["available_slots"] = available_slots + + except SignatureError as e: + tower_update["status"] = "misbehaving" + tower_update["misbehaving_proof"] = { + "appointment": appointment_dict, + "signature": e.kwargs.get("signature"), + "recovered_id": e.kwargs.get("recovered_id"), + } + + except TowerConnectionError: + tower_update["status"] = "temporarily unreachable" + + except TowerResponseError as e: + tower_update["status"] = e.kwargs.get("status") + + if e.kwargs.get("invalid_appointment"): + tower_update["invalid_appointment"] = (appointment_dict, signature) + + if tower_update["status"] in ["reachable", "misbehaving"]: + tower_update["pending_appointment"] = ([appointment_dict, signature], "remove") + + if tower_update["status"] != "temporarily unreachable": + # Update memory and TowersDB + plugin.wt_client.update_tower_state(tower_id, tower_update) + + # Continue looping if reachable, return for either retry or stop otherwise + if tower_update["status"] != "reachable": + return tower_update.get("status") diff --git a/watchtower-plugin/template.conf b/watchtower-plugin/template.conf new file mode 100644 index 0000000..85c8305 --- /dev/null +++ b/watchtower-plugin/template.conf @@ -0,0 +1,4 @@ +[teos] +api_port = 9814 +max_retries = 30 + diff --git a/watchtower-plugin/test_watchtower.py b/watchtower-plugin/test_watchtower.py new file mode 100644 index 0000000..aa15664 --- /dev/null +++ b/watchtower-plugin/test_watchtower.py @@ -0,0 +1,423 @@ +import random +import configparser +from time import sleep +from coincurve import PrivateKey +from threading import Thread +from flask import Flask, request, jsonify +from pyln.testing.fixtures import * # noqa: F401,F403 + +from common import errors +from common import constants +from common.appointment import Appointment +from common.cryptographer import Cryptographer + +plugin_path = os.path.join(os.path.dirname(__file__), "watchtower.py") + +tower_netaddr = "localhost" +tower_port = "1234" +tower_sk = PrivateKey() +tower_id = Cryptographer.get_compressed_pk(tower_sk.public_key) + +mocked_return = None + + +class TowerMock: + def __init__(self, tower_sk): + self.sk = tower_sk + self.users = {} + self.app = Flask(__name__) + + # Adds all the routes to the functions listed above. + routes = { + "/register": (self.register, ["POST"]), + "/add_appointment": (self.add_appointment, ["POST"]), + "/get_appointment": (self.get_appointment, ["POST"]), + } + + for url, params in routes.items(): + self.app.add_url_rule(url, view_func=params[0], methods=params[1]) + + # Setting Flask log to ERROR only so it does not mess with our logging. Also disabling flask initial messages + logging.getLogger("werkzeug").setLevel(logging.ERROR) + os.environ["WERKZEUG_RUN_MAIN"] = "true" + + # Thread(target=app.run, kwargs={"host": tower_netaddr, "port": tower_port}, daemon=True).start() + + def register(self): + user_id = request.get_json().get("public_key") + + if user_id not in self.users: + self.users[user_id] = {"available_slots": 100, "subscription_expiry": 4320} + else: + self.users[user_id]["available_slots"] += 100 + self.users[user_id]["subscription_expiry"] = 4320 + + rcode = constants.HTTP_OK + response = { + "public_key": user_id, + "available_slots": self.users[user_id].get("available_slots"), + "subscription_expiry": self.users[user_id].get("subscription_expiry"), + } + + return response, rcode + + def add_appointment(self): + appointment = Appointment.from_dict(request.get_json().get("appointment")) + user_id = Cryptographer.get_compressed_pk( + Cryptographer.recover_pk(appointment.serialize(), request.get_json().get("signature")) + ) + + if mocked_return == "success": + response, rtype = add_appointment_success(appointment, self.users[user_id], self.sk) + elif mocked_return == "reject_no_slots": + response, rtype = add_appointment_reject_no_slots() + elif mocked_return == "reject_invalid": + response, rtype = add_appointment_reject_invalid() + elif mocked_return == "misbehaving_tower": + response, rtype = add_appointment_misbehaving_tower(appointment, self.users[user_id], self.sk) + else: + response, rtype = add_appointment_service_unavailable() + + return jsonify(response), rtype + + def get_appointment(self): + locator = request.get_json().get("locator") + message = f"get appointment {locator}" + user_id = Cryptographer.get_compressed_pk( + Cryptographer.recover_pk(message.encode(), request.get_json().get("signature")) + ) + + if ( + user_id in self.users + and "appointments" in self.users[user_id] + and locator in self.users[user_id]["appointments"] + ): + rcode = constants.HTTP_OK + response = self.users[user_id]["appointments"][locator] + response["status"] = "being_watched" + + else: + rcode = constants.HTTP_NOT_FOUND + response = {"locator": locator, "status": "not_found"} + + return jsonify(response), rcode + + +def add_appointment_success(appointment, user, tower_sk): + rcode = constants.HTTP_OK + response = { + "locator": appointment.locator, + "signature": Cryptographer.sign(appointment.serialize(), tower_sk), + "available_slots": user.get("available_slots") - 1, + "subscription_expiry": user.get("subscription_expiry"), + } + + user["available_slots"] = response.get("available_slots") + if user.get("appointments"): + user["appointments"][appointment.locator] = appointment.to_dict() + else: + user["appointments"] = {appointment.locator: appointment.to_dict()} + + return response, rcode + + +def add_appointment_reject_no_slots(): + # This covers non-registered users and users with no available slots + + rcode = constants.HTTP_BAD_REQUEST + response = { + "error": "appointment rejected. Invalid signature or user does not have enough slots available", + "error_code": errors.APPOINTMENT_INVALID_SIGNATURE_OR_INSUFFICIENT_SLOTS, + } + + return response, rcode + + +def add_appointment_reject_invalid(): + # This covers malformed appointments (e.g. no json) and appointments with invalid data + # Pick whatever reason, should not matter + + rcode = constants.HTTP_BAD_REQUEST + response = {"error": "appointment rejected", "error_code": errors.APPOINTMENT_EMPTY_FIELD} + + return response, rcode + + +def add_appointment_service_unavailable(): + # This covers any reason why the service may be unavailable (e.g. tower run out of free slots) + + rcode = constants.HTTP_SERVICE_UNAVAILABLE + response = {"error": "appointment rejected"} + + return response, rcode + + +def add_appointment_misbehaving_tower(appointment, user, tower_sk): + # This covers a tower signing with invalid keys + wrong_sk = PrivateKey.from_hex(get_random_value_hex(32)) + + wrong_sig = Cryptographer.sign(appointment.serialize(), wrong_sk) + response, rcode = add_appointment_success(appointment, user, tower_sk) + user["appointments"][appointment.locator]["signature"] = wrong_sig + response["signature"] = wrong_sig + + return response, rcode + + +def get_random_value_hex(nbytes): + pseudo_random_value = random.getrandbits(8 * nbytes) + prv_hex = "{:x}".format(pseudo_random_value) + return prv_hex.zfill(2 * nbytes) + + +@pytest.fixture(scope="session", autouse=True) +def init_tower(): + os.environ["TOWERS_DATA_DIR"] = "/tmp/watchtower" + config = configparser.ConfigParser() + config["general"] = {"max_retries": "5"} + + os.makedirs(os.environ["TOWERS_DATA_DIR"], exist_ok=True) + + with open(os.path.join(os.environ["TOWERS_DATA_DIR"], "watchtower.conf"), "w") as configfile: + config.write(configfile) + + yield + + shutil.rmtree(os.environ["TOWERS_DATA_DIR"]) + + +@pytest.fixture(scope="session", autouse=True) +def prng_seed(): + random.seed(0) + + +@pytest.fixture(scope="session", autouse=True) +def run_tower(): + tower = TowerMock(tower_sk) + Thread(target=tower.app.run, kwargs={"host": tower_netaddr, "port": tower_port}, daemon=True).start() + + +def test_helpme_starts(node_factory): + l1 = node_factory.get_node() + # Test dynamically + l1.rpc.plugin_start(plugin_path) + l1.rpc.plugin_stop(plugin_path) + l1.rpc.plugin_start(plugin_path) + l1.stop() + # Then statically + l1.daemon.opts["plugin"] = plugin_path + l1.start() + + +def test_watchtower(node_factory): + """ Tests sending data to a single tower with short connection issue""" + + global mocked_return + # FIXME: node_factory is a function scope fixture, so I cannot reuse it while splitting the tests logically + l1, l2 = node_factory.line_graph(2, opts=[{"may_fail": True, "allow_broken_log": True}, {"plugin": plugin_path}]) + + # Register a new tower + l2.rpc.registertower("{}@{}:{}".format(tower_id, tower_netaddr, tower_port)) + + # Make sure the tower in our list of towers + tower_ids = [tower.get("id") for tower in l2.rpc.listtowers().get("towers")] + assert tower_id in tower_ids + + # There are no appointments in the tower at the moment + assert not l2.rpc.gettowerinfo(tower_id).get("appointments") + + # Force a new commitment + mocked_return = "success" + l1.rpc.pay(l2.rpc.invoice(25000000, "lbl1", "desc")["bolt11"]) + + # Check that the tower got it (list is not empty anymore) + # FIXME: it would be great to check the ids, I haven't found a way to check the list of commitments though. + # simply signing the last tx won't work since every payment creates two updates. + appointments = l2.rpc.gettowerinfo(tower_id).get("appointments") + assert appointments + assert not l2.rpc.gettowerinfo(tower_id).get("pending_appointments") + + # Disconnect the tower and see how appointments get backed up + mocked_return = "service_unavailable" + l1.rpc.pay(l2.rpc.invoice(25000000, "lbl2", "desc")["bolt11"]) + pending_appointments = [ + data.get("appointment").get("locator") for data in l2.rpc.gettowerinfo(tower_id).get("pending_appointments") + ] + assert pending_appointments + assert l2.rpc.gettowerinfo(tower_id).get("pending_appointments") + + # The fail has triggered the retry strategy. By "turning it back on" we should get the pending appointments trough + mocked_return = "success" + + # Give it some time to switch + while l2.rpc.gettowerinfo(tower_id).get("pending_appointments"): + sleep(0.1) + + # The previously pending appointments are now part of the sent appointments + assert set(pending_appointments).issubset(l2.rpc.gettowerinfo(tower_id).get("appointments").keys()) + + +def test_watchtower_retry_offline(node_factory): + """Tests sending data to a tower that gets offline for a while. Forces retry using ``retrytower``""" + + global mocked_return + l1, l2 = node_factory.line_graph(2, opts=[{"may_fail": True, "allow_broken_log": True}, {"plugin": plugin_path}]) + + # Send some appointments with to tower "offline" + mocked_return = "service_unavailable" + + # There are no pending appointment atm + assert not l2.rpc.gettowerinfo(tower_id).get("pending_appointments") + + l1.rpc.pay(l2.rpc.invoice(25000000, "lbl3", "desc")["bolt11"]) + pending_appointments = [ + data.get("appointment").get("locator") for data in l2.rpc.gettowerinfo(tower_id).get("pending_appointments") + ] + assert pending_appointments + + # Wait until the auto-retry gives up and force a retry manually + while l2.rpc.gettowerinfo(tower_id).get("status") == "temporarily unreachable": + sleep(0.1) + l2.rpc.retrytower(tower_id) + + # After retrying with an offline tower the pending appointments are the exact same + assert pending_appointments == [ + data.get("appointment").get("locator") for data in l2.rpc.gettowerinfo(tower_id).get("pending_appointments") + ] + + # Now we can "turn the tower back on" and force a retry + mocked_return = "success" + l2.rpc.retrytower(tower_id) + + # Give it some time to send everything + while l2.rpc.gettowerinfo(tower_id).get("pending_appointments"): + sleep(0.1) + + # Check that all went trough + assert set(pending_appointments).issubset(l2.rpc.gettowerinfo(tower_id).get("appointments").keys()) + + +def test_watchtower_no_slots(node_factory): + """Tests sending data to tower for a user that has no available slots""" + + global mocked_return + l1, l2 = node_factory.line_graph(2, opts=[{"may_fail": True, "allow_broken_log": True}, {"plugin": plugin_path}]) + + # Simulates there are no available slots when trying to send an appointment to the tower + mocked_return = "reject_no_slots" + + # There are no pending appointments atm + assert not l2.rpc.gettowerinfo(tower_id).get("pending_appointments") + + # Make a payment and the appointment should be left as pending + l1.rpc.pay(l2.rpc.invoice(25000000, "lbl3", "desc")["bolt11"]) + pending_appointments = [ + data.get("appointment").get("locator") for data in l2.rpc.gettowerinfo(tower_id).get("pending_appointments") + ] + assert pending_appointments + + # Retrying should work but appointment won't go trough + assert "Retrying tower" in l2.rpc.retrytower(tower_id) + assert pending_appointments == [ + data.get("appointment").get("locator") for data in l2.rpc.gettowerinfo(tower_id).get("pending_appointments") + ] + + # Adding slots + retrying should work + mocked_return = "success" + l2.rpc.retrytower(tower_id) + while l2.rpc.gettowerinfo(tower_id).get("pending_appointments"): + sleep(0.1) + + assert set(pending_appointments).issubset(l2.rpc.gettowerinfo(tower_id).get("appointments").keys()) + + +def test_watchtower_invalid_appointment(node_factory): + """Tests sending an invalid appointment to a tower""" + + global mocked_return + l1, l2 = node_factory.line_graph(2, opts=[{"may_fail": True, "allow_broken_log": True}, {"plugin": plugin_path}]) + + # Simulates sending an appointment with invalid data to the tower + mocked_return = "reject_invalid" + + # There are no invalid appointment atm + assert not l2.rpc.gettowerinfo(tower_id).get("invalid_appointments") + + # Make a payment and the appointment should be flagged as invalid + l1.rpc.pay(l2.rpc.invoice(25000000, "lbl4", "desc")["bolt11"]) + + # The appointments have been saved as invalid + assert l2.rpc.gettowerinfo(tower_id).get("invalid_appointments") + + +def test_watchtower_multiple_towers(node_factory): + """ Test sending data to multiple towers at the same time""" + global mocked_return + + # Create the new tower + another_tower_netaddr = "localhost" + another_tower_port = "5678" + another_tower_sk = PrivateKey() + another_tower_id = Cryptographer.get_compressed_pk(another_tower_sk.public_key) + + another_tower = TowerMock(another_tower_sk) + Thread( + target=another_tower.app.run, kwargs={"host": another_tower_netaddr, "port": another_tower_port}, daemon=True + ).start() + + l1, l2 = node_factory.line_graph(2, opts=[{"may_fail": True, "allow_broken_log": True}, {"plugin": plugin_path}]) + + # Register a new tower + l2.rpc.registertower("{}@{}:{}".format(another_tower_id, another_tower_netaddr, another_tower_port)) + + # Make sure the tower in our list of towers + tower_ids = [tower.get("id") for tower in l2.rpc.listtowers().get("towers")] + assert another_tower_id in tower_ids + + # Force a new commitment + mocked_return = "success" + l1.rpc.pay(l2.rpc.invoice(25000000, "lbl6", "desc")["bolt11"]) + + # Check that both towers got it + another_tower_appointments = l2.rpc.gettowerinfo(another_tower_id).get("appointments") + assert another_tower_appointments + assert not l2.rpc.gettowerinfo(another_tower_id).get("pending_appointments") + assert set(another_tower_appointments).issubset(l2.rpc.gettowerinfo(tower_id).get("appointments")) + + +def test_watchtower_misbehaving(node_factory): + """Tests sending an appointment to a misbehaving tower""" + + global mocked_return + l1, l2 = node_factory.line_graph(2, opts=[{"may_fail": True, "allow_broken_log": True}, {"plugin": plugin_path}]) + + # Simulates a tower that replies with an invalid signature + mocked_return = "misbehaving_tower" + + # There is no proof of misbehaviour atm + assert not l2.rpc.gettowerinfo(tower_id).get("misbehaving_proof") + + # Make a payment and the appointment make it to the tower, but the response will contain an invalid signature + l1.rpc.pay(l2.rpc.invoice(25000000, "lbl5", "desc")["bolt11"]) + + # The tower should have stored the proof of misbehaviour + tower_info = l2.rpc.gettowerinfo(tower_id) + assert tower_info.get("status") == "misbehaving" + assert tower_info.get("misbehaving_proof") + + +def test_get_appointment(node_factory): + l1, l2 = node_factory.line_graph(2, opts=[{"may_fail": True, "allow_broken_log": True}, {"plugin": plugin_path}]) + + local_appointments = l2.rpc.gettowerinfo(tower_id).get("appointments") + # Get should get a reply for every local appointment + for locator in local_appointments: + response = l2.rpc.getappointment(tower_id, locator) + assert response.get("locator") == locator + assert response.get("status") == "being_watched" + + # Made up appointments should return a 404 + rand_locator = get_random_value_hex(16) + response = l2.rpc.getappointment(tower_id, rand_locator) + assert response.get("locator") == rand_locator + assert response.get("status") == "not_found" diff --git a/watchtower-plugin/tower_info.py b/watchtower-plugin/tower_info.py new file mode 100644 index 0000000..ef7f321 --- /dev/null +++ b/watchtower-plugin/tower_info.py @@ -0,0 +1,123 @@ +class TowerInfo: + """ + TowerInfo represents all the data the plugin hold about a tower. + + Args: + netaddr (:obj:`str`): the tower network address. + available_slots (:obj:`int`): the amount of available appointment slots in the tower. + status (:obj:`str`): the tower status. The tower can be in the following status: + reachable: if the tower can be reached. + temporarily unreachable: if the tower cannot be reached but the issue is transitory. + unreachable: if the tower cannot be reached and the issue has persisted long enough, or it is permanent. + subscription error: if there has been a problem with the subscription (e.g: run out of slots). + misbehaving: if the tower has been caught misbehaving (e.g: an invalid signature has been received). + + Attributes: + appointments (:obj:`dict`): a collection of accepted appointments. + pending_appointments (:obj:`list`): a collection of pending appointments. Appointments are pending when the + tower is unreachable or the subscription has expired / run out of slots. + invalid_appointments (:obj:`list`): a collection of invalid appointments. Appointments are invalid if the tower + rejects them for not following the proper format. + misbehaving_proof (:obj:`dict`): a proof of misbehaviour from the tower. The tower is abandoned if so. + """ + + def __init__(self, netaddr, available_slots, status="reachable"): + self.netaddr = netaddr + self.available_slots = available_slots + self.status = status + + self.appointments = {} + self.pending_appointments = [] + self.invalid_appointments = [] + self.misbehaving_proof = None + + @classmethod + def from_dict(cls, tower_data): + """ + Builds a TowerInfo object from a dictionary. + + Args: + tower_data (:obj:`dict`): a dictionary containing all the TowerInfo fields. + + Returns: + :obj:`TowerInfo`: A TowerInfo object built with the provided data. + + Raises: + :obj:`ValueError`: If any of the expected fields is missing in the dictionary. + """ + + netaddr = tower_data.get("netaddr") + available_slots = tower_data.get("available_slots") + status = tower_data.get("status") + appointments = tower_data.get("appointments") + pending_appointments = tower_data.get("pending_appointments") + invalid_appointments = tower_data.get("invalid_appointments") + misbehaving_proof = tower_data.get("misbehaving_proof") + + if any( + v is None + for v in [netaddr, available_slots, status, appointments, pending_appointments, invalid_appointments] + ): + raise ValueError("Wrong appointment data, some fields are missing") + + tower = cls(netaddr, available_slots, status) + tower.appointments = appointments + tower.pending_appointments = pending_appointments + tower.invalid_appointments = invalid_appointments + tower.misbehaving_proof = misbehaving_proof + + return tower + + def to_dict(self): + """ + Builds a dictionary from a TowerInfo object. + + Returns: + :obj:`dict`: The TowerInfo object as a dictionary. + """ + return self.__dict__ + + def get_summary(self): + """ + Gets a summary of the TowerInfo object. + + The plugin only stores the minimal information in memory, the rest is dumped into the DB. Data kept in memory + is stored in TowerSummary objects. + + Returns: + :obj:`dict`: The summary of the TowerInfo object. + """ + return TowerSummary(self) + + +class TowerSummary: + """ + A smaller representation of the TowerInfo data to be kept in memory. + + Args: + tower_info(:obj:`TowerInfo`): A TowerInfo object. + + Attributes: + netaddr (:obj:`str`): the tower network address. + status (:obj:`str`): the status of the tower. + available_slots (:obj:`int`): the amount of available appointment slots in the tower. + pending_appointments (:obj:`list`): the collection of pending appointments. + invalid_appointments (:obj:`list`): the collection of invalid appointments. + """ + + def __init__(self, tower_info): + self.netaddr = tower_info.netaddr + self.status = tower_info.status + self.available_slots = tower_info.available_slots + self.pending_appointments = tower_info.pending_appointments + self.invalid_appointments = tower_info.invalid_appointments + + def to_dict(self): + """ + Builds a dictionary from a TowerSummary object. + + Returns: + :obj:`dict`: The TowerSummary object as a dictionary. + """ + + return self.__dict__ diff --git a/watchtower-plugin/towers_dbm.py b/watchtower-plugin/towers_dbm.py new file mode 100644 index 0000000..6888f17 --- /dev/null +++ b/watchtower-plugin/towers_dbm.py @@ -0,0 +1,117 @@ +import json + +from common.db_manager import DBManager +from common.tools import is_compressed_pk + + +class TowersDBM(DBManager): + """ + The :class:`TowersDBM` is in charge of interacting with the towers database (``LevelDB``). + Keys and values are stored as bytes in the database but processed as strings by the manager. + + Args: + db_path (:obj:`str`): the path (relative or absolute) to the system folder containing the database. A fresh + database will be created if the specified path does not contain one. + + Raises: + :obj:`ValueError`: If the provided ``db_path`` is not a string. + :obj:`plyvel.Error`: If the db is currently unavailable (being used by another process). + """ + + def __init__(self, db_path, plugin): + if not isinstance(db_path, str): + raise ValueError("db_path must be a valid path/name") + + super().__init__(db_path) + self.plugin = plugin + + def store_tower_record(self, tower_id, tower_data): + """ + Stores a tower record to the database. ``tower_id`` is used as identifier. + + Args: + tower_id (:obj:`str`): a 33-byte hex-encoded string identifying the tower. + tower_data (:obj:`dict`): the tower associated data, as a dictionary. + + Returns: + :obj:`bool`: True if the tower record was stored in the database, False otherwise. + """ + + if is_compressed_pk(tower_id): + try: + self.create_entry(tower_id, json.dumps(tower_data.to_dict())) + self.plugin.log(f"Adding tower to Tower's db (id={tower_id})") + return True + + except (json.JSONDecodeError, TypeError): + self.plugin.log( + f"Could't add tower to db. Wrong tower data format (tower_id={tower_id}, tower_data={tower_data.to_dict()})" + ) + return False + + else: + self.plugin.log( + f"Could't add user to db. Wrong pk format (tower_id={tower_id}, tower_data={tower_data.to_dict()})" + ) + return False + + def load_tower_record(self, tower_id): + """ + Loads a tower record from the database using the ``tower_id`` as identifier. + + Args: + + tower_id (:obj:`str`): a 33-byte hex-encoded string identifying the tower. + + Returns: + :obj:`dict`: A dictionary containing the tower data if the ``key`` is found. + + Returns ``None`` otherwise. + """ + + try: + data = self.load_entry(tower_id) + data = json.loads(data) + except (TypeError, json.decoder.JSONDecodeError): + data = None + + return data + + def delete_tower_record(self, tower_id): + """ + Deletes a tower record from the database. + + Args: + tower_id (:obj:`str`): a 33-byte hex-encoded string identifying the tower. + + Returns: + :obj:`bool`: True if the tower was deleted from the database or it was non-existent, False otherwise. + """ + + try: + self.delete_entry(tower_id) + self.plugin.log(f"Deleting tower from Tower's db (id={tower_id})") + return True + + except TypeError: + self.plugin.log(f"Cannot delete user from db, user key has wrong type (id={tower_id})") + return False + + def load_all_tower_records(self): + """ + Loads all tower records from the database. + + Returns: + :obj:`dict`: A dictionary containing all tower records indexed by ``tower_id``. + + Returns an empty dictionary if no data is found. + """ + + data = {} + + for k, v in self.db.iterator(): + # Get uuid and appointment_data from the db + tower_id = k.decode("utf-8") + data[tower_id] = json.loads(v) + + return data diff --git a/watchtower-plugin/watchtower.py b/watchtower-plugin/watchtower.py new file mode 100755 index 0000000..a6671fb --- /dev/null +++ b/watchtower-plugin/watchtower.py @@ -0,0 +1,400 @@ +#!/usr/bin/env python3 +import os +import plyvel +from queue import Queue +from pyln.client import Plugin +from threading import Thread, Lock + +from common.tools import compute_locator +from common.appointment import Appointment +from common.config_loader import ConfigLoader +from common.cryptographer import Cryptographer +from common.exceptions import InvalidParameter, SignatureError, EncryptionError + +import arg_parser +from retrier import Retrier +from tower_info import TowerInfo +from towers_dbm import TowersDBM +from keys import generate_keys, load_keys +from exceptions import TowerConnectionError, TowerResponseError +from net.http import post_request, process_post_response, add_appointment + + +DATA_DIR = os.getenv("TOWERS_DATA_DIR", os.path.expanduser("~/.watchtower/")) +CONF_FILE_NAME = "watchtower.conf" + +DEFAULT_CONF = { + "DEFAULT_PORT": {"value": 9814, "type": int}, + "MAX_RETRIES": {"value": 30, "type": int}, + "APPOINTMENTS_FOLDER_NAME": {"value": "appointment_receipts", "type": str, "path": True}, + "TOWERS_DB": {"value": "towers", "type": str, "path": True}, + "PRIVATE_KEY": {"value": "sk.der", "type": str, "path": True}, +} + + +plugin = Plugin() + + +class WTClient: + """ + Holds all the data regarding the watchtower client. + + Fires an additional tread to take care of retries. + + Args: + sk (:obj:`PrivateKey): the user private key. Used to sign appointment sent to the towers. + user_id (:obj:`PrivateKey): the identifier of the user (compressed public key). + config (:obj:`dict`): the client configuration loaded on a dictionary. + + Attributes: + towers (:obj:`dict`): a collection of registered towers. Indexed by tower_id, populated with :obj:`TowerSummary` + objects. + db_manager (:obj:`towers_dbm.TowersDBM`): a manager to interact with the towers database. + retrier (:obj:`retrier.Retrier`): a ``Retrier`` in charge of retrying sending jobs to temporarily unreachable + towers. + lock (:obj:`Lock`): a thread lock. + """ + + def __init__(self, sk, user_id, config): + self.sk = sk + self.user_id = user_id + self.towers = {} + self.db_manager = TowersDBM(config.get("TOWERS_DB"), plugin) + self.retrier = Retrier(config.get("MAX_RETRIES"), Queue()) + self.config = config + self.lock = Lock() + + # Populate the towers dict with data from the db + for tower_id, tower_info in self.db_manager.load_all_tower_records().items(): + self.towers[tower_id] = TowerInfo.from_dict(tower_info).get_summary() + + Thread(target=self.retrier.manage_retry, args=[plugin], daemon=True).start() + + def update_tower_state(self, tower_id, tower_update): + """ + Updates the state of a tower both in memory and disk. + + Access if restricted thought a lock to prevent race conditions. + + Args: + tower_id (:obj:`str`): the identifier of the tower to be updated. + tower_update (:obj:`dict`): a dictionary containing the data to be added / removed. + """ + + self.lock.acquire() + tower_info = TowerInfo.from_dict(self.db_manager.load_tower_record(tower_id)) + + if "status" in tower_update: + tower_info.status = tower_update.get("status") + if "appointment" in tower_update: + locator, signature = tower_update.get("appointment") + tower_info.appointments[locator] = signature + tower_info.available_slots = tower_update.get("available_slots") + if "pending_appointment" in tower_update: + data, action = tower_update.get("pending_appointment") + if action == "add": + tower_info.pending_appointments.append(list(data)) + else: + tower_info.pending_appointments.remove(list(data)) + if "invalid_appointment" in tower_update: + tower_info.invalid_appointments.append(list(tower_update.get("invalid_appointment"))) + + if "misbehaving_proof" in tower_update: + tower_info.misbehaving_proof = tower_update.get("misbehaving_proof") + + self.towers[tower_id] = tower_info.get_summary() + self.db_manager.store_tower_record(tower_id, tower_info) + self.lock.release() + + +@plugin.init() +def init(options, configuration, plugin): + """Initializes the plugin""" + + try: + user_sk, user_id = generate_keys(DATA_DIR) + plugin.log(f"Generating a new key pair for the watchtower client. Keys stored at {DATA_DIR}") + + except FileExistsError: + plugin.log("A key file for the watchtower client already exists. Loading it") + user_sk, user_id = load_keys(DATA_DIR) + + plugin.log(f"Plugin watchtower client initialized. User id = {user_id}") + config_loader = ConfigLoader(DATA_DIR, CONF_FILE_NAME, DEFAULT_CONF, {}) + + try: + plugin.wt_client = WTClient(user_sk, user_id, config_loader.build_config()) + except plyvel.IOError: + error = "Cannot load towers db. Resource temporarily unavailable" + plugin.log("Cannot load towers db. Resource temporarily unavailable") + raise IOError(error) + + +@plugin.method("registertower", desc="Register your public key (user id) with the tower.") +def register(plugin, tower_id, host=None, port=None): + """ + Registers the user to the tower. + + Args: + plugin (:obj:`Plugin`): this plugin. + tower_id (:obj:`str`): the identifier of the tower to connect to (a compressed public key). + host (:obj:`str`): the ip or hostname to connect to, optional. + port (:obj:`int`): the port to connect to, optional. + + Accepted tower_id formats: + - tower_id@host:port + - tower_id host port + - tower_id@host (will default port to DEFAULT_PORT) + - tower_id host (will default port to DEFAULT_PORT) + + Returns: + :obj:`dict`: a dictionary containing the subscription data. + """ + + try: + tower_id, tower_netaddr = arg_parser.parse_register_arguments(tower_id, host, port, plugin.wt_client.config) + + # Defaulting to http hosts for now + if not tower_netaddr.startswith("http"): + tower_netaddr = "http://" + tower_netaddr + + # Send request to the server. + register_endpoint = f"{tower_netaddr}/register" + data = {"public_key": plugin.wt_client.user_id} + + plugin.log(f"Registering in the Eye of Satoshi (tower_id={tower_id})") + + response = process_post_response(post_request(data, register_endpoint, tower_id)) + plugin.log(f"Registration succeeded. Available slots: {response.get('available_slots')}") + + # Save data + tower_info = TowerInfo(tower_netaddr, response.get("available_slots")) + plugin.wt_client.lock.acquire() + plugin.wt_client.towers[tower_id] = tower_info.get_summary() + plugin.wt_client.db_manager.store_tower_record(tower_id, tower_info) + plugin.wt_client.lock.release() + + return response + + except (InvalidParameter, TowerConnectionError, TowerResponseError) as e: + plugin.log(str(e), level="warn") + return e.to_json() + + +@plugin.method("getappointment", desc="Gets appointment data from the tower given the tower id and the locator.") +def get_appointment(plugin, tower_id, locator): + """ + Gets information about an appointment from the tower. + + Args: + plugin (:obj:`Plugin`): this plugin. + tower_id (:obj:`str`): the identifier of the tower to query. + locator (:obj:`str`): the appointment locator. + + Returns: + :obj:`dict`: a dictionary containing the appointment data. + """ + + # FIXME: All responses from the tower should be signed. + try: + tower_id, locator = arg_parser.parse_get_appointment_arguments(tower_id, locator) + + if tower_id not in plugin.wt_client.towers: + raise InvalidParameter("tower_id is not within the registered towers", tower_id=tower_id) + + message = f"get appointment {locator}" + signature = Cryptographer.sign(message.encode(), plugin.wt_client.sk) + data = {"locator": locator, "signature": signature} + + # Send request to the server. + tower_netaddr = plugin.wt_client.towers[tower_id].netaddr + get_appointment_endpoint = f"{tower_netaddr}/get_appointment" + plugin.log(f"Requesting appointment from {tower_id}") + + response = process_post_response(post_request(data, get_appointment_endpoint, tower_id)) + return response + + except (InvalidParameter, TowerConnectionError, TowerResponseError) as e: + plugin.log(str(e), level="warn") + return e.to_json() + + +@plugin.method("listtowers", desc="List all registered towers.") +def list_towers(plugin): + """ + Lists all the registered towers. The given information comes from memory, so it is summarized. + + Args: + plugin (:obj:`Plugin`): this plugin. + + Returns: + :obj:`dict`: a dictionary containing the registered towers data. + """ + + towers_info = {"towers": []} + for tower_id, tower in plugin.wt_client.towers.items(): + values = {k: v for k, v in tower.to_dict().items() if k not in ["pending_appointments", "invalid_appointments"]} + pending_appointments = [appointment.get("locator") for appointment, signature in tower.pending_appointments] + invalid_appointments = [appointment.get("locator") for appointment, signature in tower.invalid_appointments] + values["pending_appointments"] = pending_appointments + values["invalid_appointments"] = invalid_appointments + towers_info["towers"].append({"id": tower_id, **values}) + + return towers_info + + +@plugin.method("gettowerinfo", desc="List all towers registered towers.") +def get_tower_info(plugin, tower_id): + """ + Gets information about a given tower. Data comes from disk (DB), so all stored data is provided. + + Args: + plugin (:obj:`Plugin`): this plugin. + tower_id: (:obj:`str`): the identifier of the queried tower. + + Returns: + :obj:`dict`: a dictionary containing all data about the queried tower. + """ + + tower_info = TowerInfo.from_dict(plugin.wt_client.db_manager.load_tower_record(tower_id)) + pending_appointments = [ + {"appointment": appointment, "signature": signature} + for appointment, signature in tower_info.pending_appointments + ] + invalid_appointments = [ + {"appointment": appointment, "tower_signature": signature} + for appointment, signature in tower_info.invalid_appointments + ] + tower_info.pending_appointments = pending_appointments + tower_info.invalid_appointments = invalid_appointments + return {"id": tower_id, **tower_info.to_dict()} + + +@plugin.method("retrytower", desc="Retry to send pending appointment to an unreachable tower.") +def retry_tower(plugin, tower_id): + """ + Triggers a manual retry of a tower, tries to send all pending appointments to to it. + + Only works if the tower is unreachable or there's been a subscription error. + + Args: + plugin (:obj:`Plugin`): this plugin. + tower_id: (:obj:`str`): the identifier of the tower to be retried. + + Returns: + + """ + response = None + plugin.wt_client.lock.acquire() + tower = plugin.wt_client.towers.get(tower_id) + + if not tower: + response = {"error": f"{tower_id} is not a registered tower"} + + # FIXME: it may be worth only allowing unreachable and forcing a retry on register_tower if the state is + # subscription error. + if tower.status not in ["unreachable", "subscription error"]: + response = { + "error": f"Cannot retry tower. Expected tower status 'unreachable' or 'subscription error'. Received {tower.status}" + } + if not tower.pending_appointments: + response = {"error": f"{tower_id} does not have pending appointments"} + + if not response: + response = f"Retrying tower {tower_id}" + plugin.log(response) + plugin.wt_client.towers[tower_id].status = "temporarily unreachable" + plugin.wt_client.retrier.temp_unreachable_towers.put(tower_id) + + plugin.wt_client.lock.release() + return response + + +@plugin.hook("commitment_revocation") +def on_commitment_revocation(plugin, **kwargs): + """ + Sends an appointment to all registered towers for every net commitment transaction. + + kwargs should contain the commitment identifier (commitment_txid) and the penalty transaction (penalty_tx) + + Args: + plugin (:obj:`Plugin`): this plugin. + """ + + try: + commitment_txid, penalty_tx = arg_parser.parse_add_appointment_arguments(kwargs) + appointment = Appointment( + locator=compute_locator(commitment_txid), + to_self_delay=20, # does not matter for now, any value 20-2^32-1 would do + encrypted_blob=Cryptographer.encrypt(penalty_tx, commitment_txid), + ) + signature = Cryptographer.sign(appointment.serialize(), plugin.wt_client.sk) + + except (InvalidParameter, EncryptionError, SignatureError) as e: + plugin.log(str(e), level="warn") + return {"result": "continue"} + + # Send appointment to the towers. + # FIXME: sending the appointment to all registered towers atm. Some management would be nice. + for tower_id, tower in plugin.wt_client.towers.items(): + tower_update = {} + + if tower.status == "misbehaving": + continue + + try: + if tower.status == "reachable": + tower_signature, available_slots = add_appointment( + plugin, tower_id, tower, appointment.to_dict(), signature + ) + tower_update["appointment"] = (appointment.locator, tower_signature) + tower_update["available_slots"] = available_slots + + else: + if tower.status in ["temporarily unreachable", "unreachable"]: + plugin.log(f"{tower_id} is {tower.status}. Adding {appointment.locator} to pending") + elif tower.status == "subscription error": + plugin.log(f"There is a subscription issue with {tower_id}. Adding appointment to pending") + + tower_update["pending_appointment"] = (appointment.to_dict(), signature), "add" + + except SignatureError as e: + tower_update["status"] = "misbehaving" + tower_update["misbehaving_proof"] = { + "appointment": appointment.to_dict(), + "signature": e.kwargs.get("signature"), + "recovered_id": e.kwargs.get("recovered_id"), + } + + except TowerConnectionError: + # All TowerConnectionError are transitory. Connections are tried on register so URLs cannot be malformed. + # Flag appointment for retry + tower_update["status"] = "temporarily unreachable" + plugin.log(f"Adding {appointment.locator} to pending") + tower_update["pending_appointment"] = (appointment.to_dict(), signature), "add" + tower_update["retry"] = True + + except TowerResponseError as e: + tower_update["status"] = e.kwargs.get("status") + + if tower_update["status"] in ["temporarily unreachable", "subscription error"]: + plugin.log(f"Adding {appointment.locator} to pending") + tower_update["pending_appointment"] = (appointment.to_dict(), signature), "add" + + if tower_update["status"] == "temporarily unreachable": + tower_update["retry"] = True + + if e.kwargs.get("invalid_appointment"): + tower_update["invalid_appointment"] = (appointment.to_dict(), signature) + + finally: + # Update memory and TowersDB + plugin.wt_client.update_tower_state(tower_id, tower_update) + + if tower_update.get("retry"): + plugin.wt_client.retrier.temp_unreachable_towers.put(tower_id) + + return {"result": "continue"} + + +plugin.run()