diff --git a/teos/watcher.py b/teos/watcher.py index 3d7120d..a2ee8a2 100644 --- a/teos/watcher.py +++ b/teos/watcher.py @@ -22,7 +22,7 @@ class AppointmentLimitReached(BasicException): class AppointmentAlreadyTriggered(BasicException): - """Raised an appointment is sent to the Watcher but that same data has already been sent to the Responder""" + """Raised when an appointment is sent to the Watcher but that same data has already been sent to the Responder""" class LocatorCache: @@ -35,11 +35,11 @@ class LocatorCache: blocks_in_cache (:obj:`int`): the numbers of blocks to keep in the cache. Attributes: - cache (:obj:`dict`): a dictionary of ``locator:dispute_txid`` pair that received appointments are checked + cache (:obj:`dict`): a dictionary of ``locator:dispute_txid`` pairs that received appointments are checked against. blocks (:obj:`OrderedDict`): An ordered dictionary of the last ``blocks_in_cache`` blocks (block_hash:locators). Used to keep track of what data belongs to what block, so data can be pruned accordingly. Also needed to - rebuilt the cache in case of a reorgs. + rebuild the cache in case of reorgs. cache_size (:obj:`int`): the size of the cache in blocks. """ @@ -50,10 +50,10 @@ class LocatorCache: def init(self, last_known_block, block_processor): """ - Sets the initial state of the block cache. + Sets the initial state of the locator cache. Args: - last_known_block (:obj:`str`): the last known block of the ``Watcher``. + last_known_block (:obj:`str`): the last known block by the ``Watcher``. block_processor (:obj:`teos.block_processor.BlockProcessor`): a ``BlockProcessor`` instance. """ @@ -62,7 +62,7 @@ class LocatorCache: target_block_hash = last_known_block for _ in range(self.cache_size): # In some setups, like regtest, it could be the case that there are no enough previous blocks. - # In those cases we pull as many as we can. + # In those cases we pull as many as we can (up to ``cache_size``). if target_block_hash: target_block = block_processor.get_block(target_block_hash) if not target_block: @@ -70,9 +70,9 @@ class LocatorCache: else: break - locators = {compute_locator(txid): txid for txid in target_block.get("tx")} - self.cache.update(locators) - self.blocks[target_block_hash] = locators + locator_txid_map = {compute_locator(txid): txid for txid in target_block.get("tx")} + self.cache.update(locator_txid_map) + self.blocks[target_block_hash] = list(locator_txid_map.keys()) target_block_hash = target_block.get("previousblockhash") self.blocks = OrderedDict(reversed((list(self.blocks.items())))) @@ -80,18 +80,15 @@ class LocatorCache: def fix_cache(self, last_known_block, block_processor): tmp_cache = LocatorCache(self.cache_size) + # We assume there are no reorgs back to genesis. If so, this would raise some log warnings. And the cache will + # be filled with less than ``cache_size`` blocks.` target_block_hash = last_known_block - for _ in range(self.cache_size): + for _ in range(tmp_cache.cache_size): target_block = block_processor.get_block(target_block_hash) if target_block: - if target_block_hash in self.blocks: - tmp_cache.cache.update(self.blocks[target_block_hash]) - tmp_cache.blocks[target_block_hash] = self.blocks[target_block_hash] - else: - locators = {compute_locator(txid): txid for txid in target_block.get("tx")} - tmp_cache.cache.update(locators) - tmp_cache.blocks[target_block_hash] = locators - + locator_txid_map = {compute_locator(txid): txid for txid in target_block.get("tx")} + tmp_cache.cache.update(locator_txid_map) + tmp_cache.blocks[target_block_hash] = list(locator_txid_map.keys()) target_block_hash = target_block.get("previousblockhash") self.blocks = OrderedDict(reversed((list(tmp_cache.blocks.items())))) @@ -103,8 +100,8 @@ class LocatorCache: def remove_older_block(self): """ Removes the older block from the cache """ - block_hash, locator_map = self.blocks.popitem(last=False) - for locator, txid in locator_map.items(): + block_hash, locators = self.blocks.popitem(last=False) + for locator in locators: del self.cache[locator] logger.debug("Block removed from cache", block_hash=block_hash) @@ -153,7 +150,7 @@ class Watcher: signing_key (:mod:`PrivateKey`): a private key used to sign accepted appointments. max_appointments (:obj:`int`): the maximum amount of appointments accepted by the ``Watcher`` at the same time. last_known_block (:obj:`str`): the last block known by the ``Watcher``. - last_known_block (:obj:`LocatorCache`): a cache of locators from the last ``blocks_in_cache`` blocks. + locator_cache (:obj:`LocatorCache`): a cache of locators for the last ``blocks_in_cache`` blocks. Raises: :obj:`InvalidKey `: if teos sk cannot be loaded. @@ -237,19 +234,14 @@ class Watcher: # Appointments that were triggered in blocks hold in the cache if appointment.locator in self.locator_cache.cache: try: - breach = self.check_breach(uuid, appointment, self.locator_cache.cache[appointment.locator]) + dispute_txid = self.locator_cache.cache[appointment.locator] + penalty_txid, penalty_rawtx = self.check_breach(uuid, appointment, dispute_txid) receipt = self.responder.handle_breach( - uuid, - breach["locator"], - breach["dispute_txid"], - breach["penalty_txid"], - breach["penalty_rawtx"], - user_id, - self.last_known_block, + uuid, appointment.locator, dispute_txid, penalty_txid, penalty_rawtx, user_id, self.last_known_block ) - # At this point the appointment is accepted but data is only kept if it goes through the Responder - # otherwise it is dropped. + # At this point the appointment is accepted but data is only kept if it goes through the Responder. + # Otherwise it is dropped. if receipt.delivered: self.db_manager.store_watcher_appointment(uuid, appointment.to_dict()) self.db_manager.create_append_locator_map(appointment.locator, uuid) @@ -261,7 +253,7 @@ class Watcher: # could be used to discourage user misbehaviour. pass - # Regular appointments that have not been triggered (or not recently at least) + # Regular appointments that have not been triggered (or, at least, not recently) else: self.appointments[uuid] = appointment.get_summary() @@ -321,12 +313,12 @@ class Watcher: txids = block.get("tx") # Compute the locator for every transaction in the block and add them to the cache - locators_txid_map = {compute_locator(txid): txid for txid in txids} - self.locator_cache.cache.update(locators_txid_map) - self.locator_cache.blocks[block_hash] = locators_txid_map + locator_txid_map = {compute_locator(txid): txid for txid in txids} + self.locator_cache.cache.update(locator_txid_map) + self.locator_cache.blocks[block_hash] = list(locator_txid_map.keys()) logger.debug("Block added to cache", block_hash=block_hash) - if len(self.appointments) > 0 and locators_txid_map: + if len(self.appointments) > 0 and locator_txid_map: expired_appointments = self.gatekeeper.get_expired_appointments(block["height"]) # Make sure we only try to delete what is on the Watcher (some appointments may have been triggered) expired_appointments = list(set(expired_appointments).intersection(self.appointments.keys())) @@ -340,7 +332,7 @@ class Watcher: expired_appointments, self.appointments, self.locator_uuid_map, self.db_manager ) - valid_breaches, invalid_breaches = self.filter_breaches(self.get_breaches(locators_txid_map)) + valid_breaches, invalid_breaches = self.filter_breaches(self.get_breaches(locator_txid_map)) triggered_flags = [] appointments_to_delete = [] @@ -399,12 +391,12 @@ class Watcher: self.last_known_block = block.get("hash") self.block_queue.task_done() - def get_breaches(self, locators_txid_map): + def get_breaches(self, locator_txid_map): """ Gets a dictionary of channel breaches given a map of locator:dispute_txid. Args: - locators_txid_map (:obj:`dict`): the dictionary of locators (locator:txid) derived from a list of + locator_txid_map (:obj:`dict`): the dictionary of locators (locator:txid) derived from a list of transaction ids. Returns: @@ -413,8 +405,8 @@ class Watcher: """ # Check is any of the tx_ids in the received block is an actual match - intersection = set(self.locator_uuid_map.keys()).intersection(locators_txid_map.keys()) - breaches = {locator: locators_txid_map[locator] for locator in intersection} + intersection = set(self.locator_uuid_map.keys()).intersection(locator_txid_map.keys()) + breaches = {locator: locator_txid_map[locator] for locator in intersection} if len(breaches) > 0: logger.info("List of breaches", breaches=breaches) @@ -434,8 +426,7 @@ class Watcher: dispute_txid (:obj:`str`): the id of the transaction that triggered the breach. Returns: - :obj:`dic`: The breach data in a dictionary (locator, dispute_txid, penalty_txid, penalty_rawtx), if the - breach is correct. + :obj:`tuple`: A tuple containing the penalty txid and the raw penalty tx. Raises: :obj:`EncryptionError`: If the encrypted blob from the provided appointment cannot be decrypted with the @@ -452,21 +443,14 @@ class Watcher: raise e except InvalidTransactionFormat as e: - logger.info("The breach contained an invalid transaction") + logger.info("The breach contained an invalid transaction", uuid=uuid) raise e - valid_breach = { - "locator": appointment.locator, - "dispute_txid": dispute_txid, - "penalty_txid": penalty_tx.get("txid"), - "penalty_rawtx": penalty_rawtx, - } - logger.info( "Breach found for locator", locator=appointment.locator, uuid=uuid, penalty_txid=penalty_tx.get("txid") ) - return valid_breach + return penalty_tx.get("txid"), penalty_rawtx def filter_breaches(self, breaches): """ @@ -482,7 +466,7 @@ class Watcher: :obj:`dict`: A dictionary containing all the breaches flagged either as valid or invalid. The structure is as follows: - ``{locator, dispute_txid, penalty_txid, penalty_rawtx, valid_breach}`` + ``{locator, dispute_txid, penalty_txid, penalty_rawtx}`` """ valid_breaches = {} @@ -496,22 +480,24 @@ class Watcher: appointment = ExtendedAppointment.from_dict(self.db_manager.load_watcher_appointment(uuid)) if appointment.encrypted_blob in decrypted_blobs: - penalty_tx, penalty_rawtx = decrypted_blobs[appointment.encrypted_blob] + penalty_txid, penalty_rawtx = decrypted_blobs[appointment.encrypted_blob] valid_breaches[uuid] = { "locator": appointment.locator, "dispute_txid": dispute_txid, - "penalty_txid": penalty_tx.get("txid"), + "penalty_txid": penalty_txid, "penalty_rawtx": penalty_rawtx, } else: try: - valid_breach = self.check_breach(uuid, appointment, dispute_txid) - valid_breaches[uuid] = valid_breach - decrypted_blobs[appointment.encrypted_blob] = ( - valid_breach["penalty_txid"], - valid_breach["penalty_rawtx"], - ) + penalty_txid, penalty_rawtx = self.check_breach(uuid, appointment, dispute_txid) + valid_breaches[uuid] = { + "locator": appointment.locator, + "dispute_txid": dispute_txid, + "penalty_txid": penalty_txid, + "penalty_rawtx": penalty_rawtx, + } + decrypted_blobs[appointment.encrypted_blob] = (penalty_txid, penalty_rawtx) except (EncryptionError, InvalidTransactionFormat): invalid_breaches.append(uuid) diff --git a/test/teos/unit/test_watcher.py b/test/teos/unit/test_watcher.py index 44e75d2..1c54cc9 100644 --- a/test/teos/unit/test_watcher.py +++ b/test/teos/unit/test_watcher.py @@ -148,9 +148,9 @@ def test_fix_cache(block_processor): # Now let's fake a reorg of less than ``cache_size``. We'll go two blocks into the past. current_tip = block_processor.get_best_block_hash() - current_tip_locators = list(locator_cache.blocks[current_tip].keys()) + current_tip_locators = locator_cache.blocks[current_tip] current_tip_parent = block_processor.get_block(current_tip).get("previousblockhash") - current_tip_parent_locators = list(locator_cache.blocks[current_tip_parent].keys()) + current_tip_parent_locators = locator_cache.blocks[current_tip_parent] fake_tip = block_processor.get_block(current_tip_parent).get("previousblockhash") locator_cache.fix_cache(fake_tip, block_processor) @@ -171,9 +171,9 @@ def test_fix_cache(block_processor): new_cache.fix_cache(block_processor.get_best_block_hash(), block_processor) # None of the data from the old cache is in the new cache - for block_hash, data in locator_cache.blocks.items(): + for block_hash, locators in locator_cache.blocks.items(): assert block_hash not in new_cache.blocks - for locator, txid in data.items(): + for locator in locators: assert locator not in new_cache.cache # The data in the new cache corresponds to the last ``cache_size`` blocks. @@ -181,7 +181,7 @@ def test_fix_cache(block_processor): for i in range(block_count, block_count - locator_cache.cache_size, -1): block_hash = bitcoin_cli(bitcoind_connect_params).getblockhash(i - 1) assert block_hash in new_cache.blocks - for locator, _ in new_cache.blocks[block_hash].items(): + for locator in new_cache.blocks[block_hash]: assert locator in new_cache.cache @@ -533,12 +533,8 @@ def test_check_breach(watcher): appointment, dispute_tx = generate_dummy_appointment() dispute_txid = watcher.block_processor.decode_raw_transaction(dispute_tx).get("txid") - valid_breach = watcher.check_breach(uuid, appointment, dispute_txid) - assert ( - valid_breach - and valid_breach.get("locator") == appointment.locator - and valid_breach.get("dispute_txid") == dispute_txid - ) + penalty_txid, penalty_rawtx = watcher.check_breach(uuid, appointment, dispute_txid) + assert Cryptographer.encrypt(penalty_rawtx, dispute_txid) == appointment.encrypted_blob def test_check_breach_random_data(watcher):