watcher - simplifies locator_cache and check_breach

locator_cache.blocks was storing a dictionary with both the locator and txid pair, when only the locators were actually necessary. Reworks the code a bit to only use locators.
check_breach was returning some values that could be worked out from the unaltered inputs.

Also fixes some comments and docs.
This commit is contained in:
Sergi Delgado Segura
2020-06-03 17:21:21 +02:00
parent 37d1bd9b12
commit 6ea3e8e3ff
2 changed files with 54 additions and 72 deletions

View File

@@ -22,7 +22,7 @@ class AppointmentLimitReached(BasicException):
class AppointmentAlreadyTriggered(BasicException): class AppointmentAlreadyTriggered(BasicException):
"""Raised an appointment is sent to the Watcher but that same data has already been sent to the Responder""" """Raised when an appointment is sent to the Watcher but that same data has already been sent to the Responder"""
class LocatorCache: class LocatorCache:
@@ -35,11 +35,11 @@ class LocatorCache:
blocks_in_cache (:obj:`int`): the numbers of blocks to keep in the cache. blocks_in_cache (:obj:`int`): the numbers of blocks to keep in the cache.
Attributes: Attributes:
cache (:obj:`dict`): a dictionary of ``locator:dispute_txid`` pair that received appointments are checked cache (:obj:`dict`): a dictionary of ``locator:dispute_txid`` pairs that received appointments are checked
against. against.
blocks (:obj:`OrderedDict`): An ordered dictionary of the last ``blocks_in_cache`` blocks (block_hash:locators). blocks (:obj:`OrderedDict`): An ordered dictionary of the last ``blocks_in_cache`` blocks (block_hash:locators).
Used to keep track of what data belongs to what block, so data can be pruned accordingly. Also needed to Used to keep track of what data belongs to what block, so data can be pruned accordingly. Also needed to
rebuilt the cache in case of a reorgs. rebuild the cache in case of reorgs.
cache_size (:obj:`int`): the size of the cache in blocks. cache_size (:obj:`int`): the size of the cache in blocks.
""" """
@@ -50,10 +50,10 @@ class LocatorCache:
def init(self, last_known_block, block_processor): def init(self, last_known_block, block_processor):
""" """
Sets the initial state of the block cache. Sets the initial state of the locator cache.
Args: Args:
last_known_block (:obj:`str`): the last known block of the ``Watcher``. last_known_block (:obj:`str`): the last known block by the ``Watcher``.
block_processor (:obj:`teos.block_processor.BlockProcessor`): a ``BlockProcessor`` instance. block_processor (:obj:`teos.block_processor.BlockProcessor`): a ``BlockProcessor`` instance.
""" """
@@ -62,7 +62,7 @@ class LocatorCache:
target_block_hash = last_known_block target_block_hash = last_known_block
for _ in range(self.cache_size): for _ in range(self.cache_size):
# In some setups, like regtest, it could be the case that there are no enough previous blocks. # In some setups, like regtest, it could be the case that there are no enough previous blocks.
# In those cases we pull as many as we can. # In those cases we pull as many as we can (up to ``cache_size``).
if target_block_hash: if target_block_hash:
target_block = block_processor.get_block(target_block_hash) target_block = block_processor.get_block(target_block_hash)
if not target_block: if not target_block:
@@ -70,9 +70,9 @@ class LocatorCache:
else: else:
break break
locators = {compute_locator(txid): txid for txid in target_block.get("tx")} locator_txid_map = {compute_locator(txid): txid for txid in target_block.get("tx")}
self.cache.update(locators) self.cache.update(locator_txid_map)
self.blocks[target_block_hash] = locators self.blocks[target_block_hash] = list(locator_txid_map.keys())
target_block_hash = target_block.get("previousblockhash") target_block_hash = target_block.get("previousblockhash")
self.blocks = OrderedDict(reversed((list(self.blocks.items())))) self.blocks = OrderedDict(reversed((list(self.blocks.items()))))
@@ -80,18 +80,15 @@ class LocatorCache:
def fix_cache(self, last_known_block, block_processor): def fix_cache(self, last_known_block, block_processor):
tmp_cache = LocatorCache(self.cache_size) tmp_cache = LocatorCache(self.cache_size)
# We assume there are no reorgs back to genesis. If so, this would raise some log warnings. And the cache will
# be filled with less than ``cache_size`` blocks.`
target_block_hash = last_known_block target_block_hash = last_known_block
for _ in range(self.cache_size): for _ in range(tmp_cache.cache_size):
target_block = block_processor.get_block(target_block_hash) target_block = block_processor.get_block(target_block_hash)
if target_block: if target_block:
if target_block_hash in self.blocks: locator_txid_map = {compute_locator(txid): txid for txid in target_block.get("tx")}
tmp_cache.cache.update(self.blocks[target_block_hash]) tmp_cache.cache.update(locator_txid_map)
tmp_cache.blocks[target_block_hash] = self.blocks[target_block_hash] tmp_cache.blocks[target_block_hash] = list(locator_txid_map.keys())
else:
locators = {compute_locator(txid): txid for txid in target_block.get("tx")}
tmp_cache.cache.update(locators)
tmp_cache.blocks[target_block_hash] = locators
target_block_hash = target_block.get("previousblockhash") target_block_hash = target_block.get("previousblockhash")
self.blocks = OrderedDict(reversed((list(tmp_cache.blocks.items())))) self.blocks = OrderedDict(reversed((list(tmp_cache.blocks.items()))))
@@ -103,8 +100,8 @@ class LocatorCache:
def remove_older_block(self): def remove_older_block(self):
""" Removes the older block from the cache """ """ Removes the older block from the cache """
block_hash, locator_map = self.blocks.popitem(last=False) block_hash, locators = self.blocks.popitem(last=False)
for locator, txid in locator_map.items(): for locator in locators:
del self.cache[locator] del self.cache[locator]
logger.debug("Block removed from cache", block_hash=block_hash) logger.debug("Block removed from cache", block_hash=block_hash)
@@ -153,7 +150,7 @@ class Watcher:
signing_key (:mod:`PrivateKey`): a private key used to sign accepted appointments. signing_key (:mod:`PrivateKey`): a private key used to sign accepted appointments.
max_appointments (:obj:`int`): the maximum amount of appointments accepted by the ``Watcher`` at the same time. max_appointments (:obj:`int`): the maximum amount of appointments accepted by the ``Watcher`` at the same time.
last_known_block (:obj:`str`): the last block known by the ``Watcher``. last_known_block (:obj:`str`): the last block known by the ``Watcher``.
last_known_block (:obj:`LocatorCache`): a cache of locators from the last ``blocks_in_cache`` blocks. locator_cache (:obj:`LocatorCache`): a cache of locators for the last ``blocks_in_cache`` blocks.
Raises: Raises:
:obj:`InvalidKey <common.exceptions.InvalidKey>`: if teos sk cannot be loaded. :obj:`InvalidKey <common.exceptions.InvalidKey>`: if teos sk cannot be loaded.
@@ -237,19 +234,14 @@ class Watcher:
# Appointments that were triggered in blocks hold in the cache # Appointments that were triggered in blocks hold in the cache
if appointment.locator in self.locator_cache.cache: if appointment.locator in self.locator_cache.cache:
try: try:
breach = self.check_breach(uuid, appointment, self.locator_cache.cache[appointment.locator]) dispute_txid = self.locator_cache.cache[appointment.locator]
penalty_txid, penalty_rawtx = self.check_breach(uuid, appointment, dispute_txid)
receipt = self.responder.handle_breach( receipt = self.responder.handle_breach(
uuid, uuid, appointment.locator, dispute_txid, penalty_txid, penalty_rawtx, user_id, self.last_known_block
breach["locator"],
breach["dispute_txid"],
breach["penalty_txid"],
breach["penalty_rawtx"],
user_id,
self.last_known_block,
) )
# At this point the appointment is accepted but data is only kept if it goes through the Responder # At this point the appointment is accepted but data is only kept if it goes through the Responder.
# otherwise it is dropped. # Otherwise it is dropped.
if receipt.delivered: if receipt.delivered:
self.db_manager.store_watcher_appointment(uuid, appointment.to_dict()) self.db_manager.store_watcher_appointment(uuid, appointment.to_dict())
self.db_manager.create_append_locator_map(appointment.locator, uuid) self.db_manager.create_append_locator_map(appointment.locator, uuid)
@@ -261,7 +253,7 @@ class Watcher:
# could be used to discourage user misbehaviour. # could be used to discourage user misbehaviour.
pass pass
# Regular appointments that have not been triggered (or not recently at least) # Regular appointments that have not been triggered (or, at least, not recently)
else: else:
self.appointments[uuid] = appointment.get_summary() self.appointments[uuid] = appointment.get_summary()
@@ -321,12 +313,12 @@ class Watcher:
txids = block.get("tx") txids = block.get("tx")
# Compute the locator for every transaction in the block and add them to the cache # Compute the locator for every transaction in the block and add them to the cache
locators_txid_map = {compute_locator(txid): txid for txid in txids} locator_txid_map = {compute_locator(txid): txid for txid in txids}
self.locator_cache.cache.update(locators_txid_map) self.locator_cache.cache.update(locator_txid_map)
self.locator_cache.blocks[block_hash] = locators_txid_map self.locator_cache.blocks[block_hash] = list(locator_txid_map.keys())
logger.debug("Block added to cache", block_hash=block_hash) logger.debug("Block added to cache", block_hash=block_hash)
if len(self.appointments) > 0 and locators_txid_map: if len(self.appointments) > 0 and locator_txid_map:
expired_appointments = self.gatekeeper.get_expired_appointments(block["height"]) expired_appointments = self.gatekeeper.get_expired_appointments(block["height"])
# Make sure we only try to delete what is on the Watcher (some appointments may have been triggered) # Make sure we only try to delete what is on the Watcher (some appointments may have been triggered)
expired_appointments = list(set(expired_appointments).intersection(self.appointments.keys())) expired_appointments = list(set(expired_appointments).intersection(self.appointments.keys()))
@@ -340,7 +332,7 @@ class Watcher:
expired_appointments, self.appointments, self.locator_uuid_map, self.db_manager expired_appointments, self.appointments, self.locator_uuid_map, self.db_manager
) )
valid_breaches, invalid_breaches = self.filter_breaches(self.get_breaches(locators_txid_map)) valid_breaches, invalid_breaches = self.filter_breaches(self.get_breaches(locator_txid_map))
triggered_flags = [] triggered_flags = []
appointments_to_delete = [] appointments_to_delete = []
@@ -399,12 +391,12 @@ class Watcher:
self.last_known_block = block.get("hash") self.last_known_block = block.get("hash")
self.block_queue.task_done() self.block_queue.task_done()
def get_breaches(self, locators_txid_map): def get_breaches(self, locator_txid_map):
""" """
Gets a dictionary of channel breaches given a map of locator:dispute_txid. Gets a dictionary of channel breaches given a map of locator:dispute_txid.
Args: Args:
locators_txid_map (:obj:`dict`): the dictionary of locators (locator:txid) derived from a list of locator_txid_map (:obj:`dict`): the dictionary of locators (locator:txid) derived from a list of
transaction ids. transaction ids.
Returns: Returns:
@@ -413,8 +405,8 @@ class Watcher:
""" """
# Check is any of the tx_ids in the received block is an actual match # Check is any of the tx_ids in the received block is an actual match
intersection = set(self.locator_uuid_map.keys()).intersection(locators_txid_map.keys()) intersection = set(self.locator_uuid_map.keys()).intersection(locator_txid_map.keys())
breaches = {locator: locators_txid_map[locator] for locator in intersection} breaches = {locator: locator_txid_map[locator] for locator in intersection}
if len(breaches) > 0: if len(breaches) > 0:
logger.info("List of breaches", breaches=breaches) logger.info("List of breaches", breaches=breaches)
@@ -434,8 +426,7 @@ class Watcher:
dispute_txid (:obj:`str`): the id of the transaction that triggered the breach. dispute_txid (:obj:`str`): the id of the transaction that triggered the breach.
Returns: Returns:
:obj:`dic`: The breach data in a dictionary (locator, dispute_txid, penalty_txid, penalty_rawtx), if the :obj:`tuple`: A tuple containing the penalty txid and the raw penalty tx.
breach is correct.
Raises: Raises:
:obj:`EncryptionError`: If the encrypted blob from the provided appointment cannot be decrypted with the :obj:`EncryptionError`: If the encrypted blob from the provided appointment cannot be decrypted with the
@@ -452,21 +443,14 @@ class Watcher:
raise e raise e
except InvalidTransactionFormat as e: except InvalidTransactionFormat as e:
logger.info("The breach contained an invalid transaction") logger.info("The breach contained an invalid transaction", uuid=uuid)
raise e raise e
valid_breach = {
"locator": appointment.locator,
"dispute_txid": dispute_txid,
"penalty_txid": penalty_tx.get("txid"),
"penalty_rawtx": penalty_rawtx,
}
logger.info( logger.info(
"Breach found for locator", locator=appointment.locator, uuid=uuid, penalty_txid=penalty_tx.get("txid") "Breach found for locator", locator=appointment.locator, uuid=uuid, penalty_txid=penalty_tx.get("txid")
) )
return valid_breach return penalty_tx.get("txid"), penalty_rawtx
def filter_breaches(self, breaches): def filter_breaches(self, breaches):
""" """
@@ -482,7 +466,7 @@ class Watcher:
:obj:`dict`: A dictionary containing all the breaches flagged either as valid or invalid. :obj:`dict`: A dictionary containing all the breaches flagged either as valid or invalid.
The structure is as follows: The structure is as follows:
``{locator, dispute_txid, penalty_txid, penalty_rawtx, valid_breach}`` ``{locator, dispute_txid, penalty_txid, penalty_rawtx}``
""" """
valid_breaches = {} valid_breaches = {}
@@ -496,22 +480,24 @@ class Watcher:
appointment = ExtendedAppointment.from_dict(self.db_manager.load_watcher_appointment(uuid)) appointment = ExtendedAppointment.from_dict(self.db_manager.load_watcher_appointment(uuid))
if appointment.encrypted_blob in decrypted_blobs: if appointment.encrypted_blob in decrypted_blobs:
penalty_tx, penalty_rawtx = decrypted_blobs[appointment.encrypted_blob] penalty_txid, penalty_rawtx = decrypted_blobs[appointment.encrypted_blob]
valid_breaches[uuid] = { valid_breaches[uuid] = {
"locator": appointment.locator, "locator": appointment.locator,
"dispute_txid": dispute_txid, "dispute_txid": dispute_txid,
"penalty_txid": penalty_tx.get("txid"), "penalty_txid": penalty_txid,
"penalty_rawtx": penalty_rawtx, "penalty_rawtx": penalty_rawtx,
} }
else: else:
try: try:
valid_breach = self.check_breach(uuid, appointment, dispute_txid) penalty_txid, penalty_rawtx = self.check_breach(uuid, appointment, dispute_txid)
valid_breaches[uuid] = valid_breach valid_breaches[uuid] = {
decrypted_blobs[appointment.encrypted_blob] = ( "locator": appointment.locator,
valid_breach["penalty_txid"], "dispute_txid": dispute_txid,
valid_breach["penalty_rawtx"], "penalty_txid": penalty_txid,
) "penalty_rawtx": penalty_rawtx,
}
decrypted_blobs[appointment.encrypted_blob] = (penalty_txid, penalty_rawtx)
except (EncryptionError, InvalidTransactionFormat): except (EncryptionError, InvalidTransactionFormat):
invalid_breaches.append(uuid) invalid_breaches.append(uuid)

View File

@@ -148,9 +148,9 @@ def test_fix_cache(block_processor):
# Now let's fake a reorg of less than ``cache_size``. We'll go two blocks into the past. # Now let's fake a reorg of less than ``cache_size``. We'll go two blocks into the past.
current_tip = block_processor.get_best_block_hash() current_tip = block_processor.get_best_block_hash()
current_tip_locators = list(locator_cache.blocks[current_tip].keys()) current_tip_locators = locator_cache.blocks[current_tip]
current_tip_parent = block_processor.get_block(current_tip).get("previousblockhash") current_tip_parent = block_processor.get_block(current_tip).get("previousblockhash")
current_tip_parent_locators = list(locator_cache.blocks[current_tip_parent].keys()) current_tip_parent_locators = locator_cache.blocks[current_tip_parent]
fake_tip = block_processor.get_block(current_tip_parent).get("previousblockhash") fake_tip = block_processor.get_block(current_tip_parent).get("previousblockhash")
locator_cache.fix_cache(fake_tip, block_processor) locator_cache.fix_cache(fake_tip, block_processor)
@@ -171,9 +171,9 @@ def test_fix_cache(block_processor):
new_cache.fix_cache(block_processor.get_best_block_hash(), block_processor) new_cache.fix_cache(block_processor.get_best_block_hash(), block_processor)
# None of the data from the old cache is in the new cache # None of the data from the old cache is in the new cache
for block_hash, data in locator_cache.blocks.items(): for block_hash, locators in locator_cache.blocks.items():
assert block_hash not in new_cache.blocks assert block_hash not in new_cache.blocks
for locator, txid in data.items(): for locator in locators:
assert locator not in new_cache.cache assert locator not in new_cache.cache
# The data in the new cache corresponds to the last ``cache_size`` blocks. # The data in the new cache corresponds to the last ``cache_size`` blocks.
@@ -181,7 +181,7 @@ def test_fix_cache(block_processor):
for i in range(block_count, block_count - locator_cache.cache_size, -1): for i in range(block_count, block_count - locator_cache.cache_size, -1):
block_hash = bitcoin_cli(bitcoind_connect_params).getblockhash(i - 1) block_hash = bitcoin_cli(bitcoind_connect_params).getblockhash(i - 1)
assert block_hash in new_cache.blocks assert block_hash in new_cache.blocks
for locator, _ in new_cache.blocks[block_hash].items(): for locator in new_cache.blocks[block_hash]:
assert locator in new_cache.cache assert locator in new_cache.cache
@@ -533,12 +533,8 @@ def test_check_breach(watcher):
appointment, dispute_tx = generate_dummy_appointment() appointment, dispute_tx = generate_dummy_appointment()
dispute_txid = watcher.block_processor.decode_raw_transaction(dispute_tx).get("txid") dispute_txid = watcher.block_processor.decode_raw_transaction(dispute_tx).get("txid")
valid_breach = watcher.check_breach(uuid, appointment, dispute_txid) penalty_txid, penalty_rawtx = watcher.check_breach(uuid, appointment, dispute_txid)
assert ( assert Cryptographer.encrypt(penalty_rawtx, dispute_txid) == appointment.encrypted_blob
valid_breach
and valid_breach.get("locator") == appointment.locator
and valid_breach.get("dispute_txid") == dispute_txid
)
def test_check_breach_random_data(watcher): def test_check_breach_random_data(watcher):