watcher - simplifies locator_cache and check_breach

locator_cache.blocks was storing, for every block, a dictionary of locator:txid pairs when only the locators were actually necessary (the txids are already held in locator_cache.cache). Reworks the code a bit to only use locators.
check_breach was returning values that could be worked out from its unaltered inputs, so it now returns only the penalty txid and the raw penalty transaction.

Also fixes some comments and docs.
Author: Sergi Delgado Segura
Date:   2020-06-03 17:21:21 +02:00
Commit: 6ea3e8e3ff (parent 37d1bd9b12)
2 changed files with 54 additions and 72 deletions


@@ -22,7 +22,7 @@ class AppointmentLimitReached(BasicException):
 class AppointmentAlreadyTriggered(BasicException):
-    """Raised an appointment is sent to the Watcher but that same data has already been sent to the Responder"""
+    """Raised when an appointment is sent to the Watcher but that same data has already been sent to the Responder"""
 class LocatorCache:
@@ -35,11 +35,11 @@ class LocatorCache:
         blocks_in_cache (:obj:`int`): the numbers of blocks to keep in the cache.
     Attributes:
-        cache (:obj:`dict`): a dictionary of ``locator:dispute_txid`` pair that received appointments are checked
+        cache (:obj:`dict`): a dictionary of ``locator:dispute_txid`` pairs that received appointments are checked
             against.
         blocks (:obj:`OrderedDict`): An ordered dictionary of the last ``blocks_in_cache`` blocks (block_hash:locators).
             Used to keep track of what data belongs to what block, so data can be pruned accordingly. Also needed to
-            rebuilt the cache in case of a reorgs.
+            rebuild the cache in case of reorgs.
         cache_size (:obj:`int`): the size of the cache in blocks.
     """
@@ -50,10 +50,10 @@ class LocatorCache:
     def init(self, last_known_block, block_processor):
         """
-        Sets the initial state of the block cache.
+        Sets the initial state of the locator cache.
         Args:
-            last_known_block (:obj:`str`): the last known block of the ``Watcher``.
+            last_known_block (:obj:`str`): the last known block by the ``Watcher``.
            block_processor (:obj:`teos.block_processor.BlockProcessor`): a ``BlockProcessor`` instance.
         """
@@ -62,7 +62,7 @@ class LocatorCache:
         target_block_hash = last_known_block
         for _ in range(self.cache_size):
             # In some setups, like regtest, it could be the case that there are no enough previous blocks.
-            # In those cases we pull as many as we can.
+            # In those cases we pull as many as we can (up to ``cache_size``).
             if target_block_hash:
                 target_block = block_processor.get_block(target_block_hash)
                 if not target_block:
@@ -70,9 +70,9 @@ class LocatorCache:
             else:
                 break
-            locators = {compute_locator(txid): txid for txid in target_block.get("tx")}
-            self.cache.update(locators)
-            self.blocks[target_block_hash] = locators
+            locator_txid_map = {compute_locator(txid): txid for txid in target_block.get("tx")}
+            self.cache.update(locator_txid_map)
+            self.blocks[target_block_hash] = list(locator_txid_map.keys())
             target_block_hash = target_block.get("previousblockhash")
         self.blocks = OrderedDict(reversed((list(self.blocks.items()))))
@@ -80,18 +80,15 @@ class LocatorCache:
     def fix_cache(self, last_known_block, block_processor):
         tmp_cache = LocatorCache(self.cache_size)
         # We assume there are no reorgs back to genesis. If so, this would raise some log warnings. And the cache will
         # be filled with less than ``cache_size`` blocks.
         target_block_hash = last_known_block
-        for _ in range(self.cache_size):
+        for _ in range(tmp_cache.cache_size):
             target_block = block_processor.get_block(target_block_hash)
             if target_block:
-                if target_block_hash in self.blocks:
-                    tmp_cache.cache.update(self.blocks[target_block_hash])
-                    tmp_cache.blocks[target_block_hash] = self.blocks[target_block_hash]
-                else:
-                    locators = {compute_locator(txid): txid for txid in target_block.get("tx")}
-                    tmp_cache.cache.update(locators)
-                    tmp_cache.blocks[target_block_hash] = locators
+                locator_txid_map = {compute_locator(txid): txid for txid in target_block.get("tx")}
+                tmp_cache.cache.update(locator_txid_map)
+                tmp_cache.blocks[target_block_hash] = list(locator_txid_map.keys())
                 target_block_hash = target_block.get("previousblockhash")
         self.blocks = OrderedDict(reversed((list(tmp_cache.blocks.items()))))
@@ -103,8 +100,8 @@ class LocatorCache:
     def remove_older_block(self):
         """ Removes the older block from the cache """
-        block_hash, locator_map = self.blocks.popitem(last=False)
-        for locator, txid in locator_map.items():
+        block_hash, locators = self.blocks.popitem(last=False)
+        for locator in locators:
             del self.cache[locator]
         logger.debug("Block removed from cache", block_hash=block_hash)
@@ -153,7 +150,7 @@ class Watcher:
         signing_key (:mod:`PrivateKey`): a private key used to sign accepted appointments.
         max_appointments (:obj:`int`): the maximum amount of appointments accepted by the ``Watcher`` at the same time.
         last_known_block (:obj:`str`): the last block known by the ``Watcher``.
-        last_known_block (:obj:`LocatorCache`): a cache of locators from the last ``blocks_in_cache`` blocks.
+        locator_cache (:obj:`LocatorCache`): a cache of locators for the last ``blocks_in_cache`` blocks.
     Raises:
         :obj:`InvalidKey <common.exceptions.InvalidKey>`: if teos sk cannot be loaded.
@@ -237,19 +234,14 @@ class Watcher:
         # Appointments that were triggered in blocks hold in the cache
         if appointment.locator in self.locator_cache.cache:
             try:
-                breach = self.check_breach(uuid, appointment, self.locator_cache.cache[appointment.locator])
+                dispute_txid = self.locator_cache.cache[appointment.locator]
+                penalty_txid, penalty_rawtx = self.check_breach(uuid, appointment, dispute_txid)
                 receipt = self.responder.handle_breach(
-                    uuid,
-                    breach["locator"],
-                    breach["dispute_txid"],
-                    breach["penalty_txid"],
-                    breach["penalty_rawtx"],
-                    user_id,
-                    self.last_known_block,
+                    uuid, appointment.locator, dispute_txid, penalty_txid, penalty_rawtx, user_id, self.last_known_block
                 )
-                # At this point the appointment is accepted but data is only kept if it goes through the Responder
-                # otherwise it is dropped.
+                # At this point the appointment is accepted but data is only kept if it goes through the Responder.
+                # Otherwise it is dropped.
                 if receipt.delivered:
                     self.db_manager.store_watcher_appointment(uuid, appointment.to_dict())
                     self.db_manager.create_append_locator_map(appointment.locator, uuid)
@@ -261,7 +253,7 @@ class Watcher:
                 # could be used to discourage user misbehaviour.
                 pass
-        # Regular appointments that have not been triggered (or not recently at least)
+        # Regular appointments that have not been triggered (or, at least, not recently)
         else:
             self.appointments[uuid] = appointment.get_summary()
@@ -321,12 +313,12 @@ class Watcher:
             txids = block.get("tx")
             # Compute the locator for every transaction in the block and add them to the cache
-            locators_txid_map = {compute_locator(txid): txid for txid in txids}
-            self.locator_cache.cache.update(locators_txid_map)
-            self.locator_cache.blocks[block_hash] = locators_txid_map
+            locator_txid_map = {compute_locator(txid): txid for txid in txids}
+            self.locator_cache.cache.update(locator_txid_map)
+            self.locator_cache.blocks[block_hash] = list(locator_txid_map.keys())
             logger.debug("Block added to cache", block_hash=block_hash)
-            if len(self.appointments) > 0 and locators_txid_map:
+            if len(self.appointments) > 0 and locator_txid_map:
                 expired_appointments = self.gatekeeper.get_expired_appointments(block["height"])
                 # Make sure we only try to delete what is on the Watcher (some appointments may have been triggered)
                 expired_appointments = list(set(expired_appointments).intersection(self.appointments.keys()))
@@ -340,7 +332,7 @@ class Watcher:
                     expired_appointments, self.appointments, self.locator_uuid_map, self.db_manager
                 )
-                valid_breaches, invalid_breaches = self.filter_breaches(self.get_breaches(locators_txid_map))
+                valid_breaches, invalid_breaches = self.filter_breaches(self.get_breaches(locator_txid_map))
                 triggered_flags = []
                 appointments_to_delete = []
@@ -399,12 +391,12 @@ class Watcher:
             self.last_known_block = block.get("hash")
             self.block_queue.task_done()
-    def get_breaches(self, locators_txid_map):
+    def get_breaches(self, locator_txid_map):
         """
         Gets a dictionary of channel breaches given a map of locator:dispute_txid.
         Args:
-            locators_txid_map (:obj:`dict`): the dictionary of locators (locator:txid) derived from a list of
+            locator_txid_map (:obj:`dict`): the dictionary of locators (locator:txid) derived from a list of
                 transaction ids.
         Returns:
@@ -413,8 +405,8 @@ class Watcher:
         """
         # Check is any of the tx_ids in the received block is an actual match
-        intersection = set(self.locator_uuid_map.keys()).intersection(locators_txid_map.keys())
-        breaches = {locator: locators_txid_map[locator] for locator in intersection}
+        intersection = set(self.locator_uuid_map.keys()).intersection(locator_txid_map.keys())
+        breaches = {locator: locator_txid_map[locator] for locator in intersection}
         if len(breaches) > 0:
             logger.info("List of breaches", breaches=breaches)
@@ -434,8 +426,7 @@ class Watcher:
            dispute_txid (:obj:`str`): the id of the transaction that triggered the breach.
         Returns:
-            :obj:`dic`: The breach data in a dictionary (locator, dispute_txid, penalty_txid, penalty_rawtx), if the
-                breach is correct.
+            :obj:`tuple`: A tuple containing the penalty txid and the raw penalty tx.
         Raises:
            :obj:`EncryptionError`: If the encrypted blob from the provided appointment cannot be decrypted with the
@@ -452,21 +443,14 @@ class Watcher:
             raise e
         except InvalidTransactionFormat as e:
-            logger.info("The breach contained an invalid transaction")
+            logger.info("The breach contained an invalid transaction", uuid=uuid)
             raise e
-        valid_breach = {
-            "locator": appointment.locator,
-            "dispute_txid": dispute_txid,
-            "penalty_txid": penalty_tx.get("txid"),
-            "penalty_rawtx": penalty_rawtx,
-        }
         logger.info(
             "Breach found for locator", locator=appointment.locator, uuid=uuid, penalty_txid=penalty_tx.get("txid")
         )
-        return valid_breach
+        return penalty_tx.get("txid"), penalty_rawtx
     def filter_breaches(self, breaches):
         """
@@ -482,7 +466,7 @@ class Watcher:
             :obj:`dict`: A dictionary containing all the breaches flagged either as valid or invalid.
             The structure is as follows:
-            ``{locator, dispute_txid, penalty_txid, penalty_rawtx, valid_breach}``
+            ``{locator, dispute_txid, penalty_txid, penalty_rawtx}``
         """
         valid_breaches = {}
@@ -496,22 +480,24 @@ class Watcher:
             appointment = ExtendedAppointment.from_dict(self.db_manager.load_watcher_appointment(uuid))
             if appointment.encrypted_blob in decrypted_blobs:
-                penalty_tx, penalty_rawtx = decrypted_blobs[appointment.encrypted_blob]
+                penalty_txid, penalty_rawtx = decrypted_blobs[appointment.encrypted_blob]
                 valid_breaches[uuid] = {
                     "locator": appointment.locator,
                     "dispute_txid": dispute_txid,
-                    "penalty_txid": penalty_tx.get("txid"),
+                    "penalty_txid": penalty_txid,
                     "penalty_rawtx": penalty_rawtx,
                 }
             else:
                 try:
-                    valid_breach = self.check_breach(uuid, appointment, dispute_txid)
-                    valid_breaches[uuid] = valid_breach
-                    decrypted_blobs[appointment.encrypted_blob] = (
-                        valid_breach["penalty_txid"],
-                        valid_breach["penalty_rawtx"],
-                    )
+                    penalty_txid, penalty_rawtx = self.check_breach(uuid, appointment, dispute_txid)
+                    valid_breaches[uuid] = {
+                        "locator": appointment.locator,
+                        "dispute_txid": dispute_txid,
+                        "penalty_txid": penalty_txid,
+                        "penalty_rawtx": penalty_rawtx,
+                    }
+                    decrypted_blobs[appointment.encrypted_blob] = (penalty_txid, penalty_rawtx)
                 except (EncryptionError, InvalidTransactionFormat):
                     invalid_breaches.append(uuid)


@@ -148,9 +148,9 @@ def test_fix_cache(block_processor):
     # Now let's fake a reorg of less than ``cache_size``. We'll go two blocks into the past.
     current_tip = block_processor.get_best_block_hash()
-    current_tip_locators = list(locator_cache.blocks[current_tip].keys())
+    current_tip_locators = locator_cache.blocks[current_tip]
     current_tip_parent = block_processor.get_block(current_tip).get("previousblockhash")
-    current_tip_parent_locators = list(locator_cache.blocks[current_tip_parent].keys())
+    current_tip_parent_locators = locator_cache.blocks[current_tip_parent]
     fake_tip = block_processor.get_block(current_tip_parent).get("previousblockhash")
     locator_cache.fix_cache(fake_tip, block_processor)
@@ -171,9 +171,9 @@ def test_fix_cache(block_processor):
     new_cache.fix_cache(block_processor.get_best_block_hash(), block_processor)
     # None of the data from the old cache is in the new cache
-    for block_hash, data in locator_cache.blocks.items():
+    for block_hash, locators in locator_cache.blocks.items():
         assert block_hash not in new_cache.blocks
-        for locator, txid in data.items():
+        for locator in locators:
             assert locator not in new_cache.cache
     # The data in the new cache corresponds to the last ``cache_size`` blocks.
@@ -181,7 +181,7 @@ def test_fix_cache(block_processor):
     for i in range(block_count, block_count - locator_cache.cache_size, -1):
         block_hash = bitcoin_cli(bitcoind_connect_params).getblockhash(i - 1)
         assert block_hash in new_cache.blocks
-        for locator, _ in new_cache.blocks[block_hash].items():
+        for locator in new_cache.blocks[block_hash]:
             assert locator in new_cache.cache
@@ -533,12 +533,8 @@ def test_check_breach(watcher):
     appointment, dispute_tx = generate_dummy_appointment()
     dispute_txid = watcher.block_processor.decode_raw_transaction(dispute_tx).get("txid")
-    valid_breach = watcher.check_breach(uuid, appointment, dispute_txid)
-    assert (
-        valid_breach
-        and valid_breach.get("locator") == appointment.locator
-        and valid_breach.get("dispute_txid") == dispute_txid
-    )
+    penalty_txid, penalty_rawtx = watcher.check_breach(uuid, appointment, dispute_txid)
+    assert Cryptographer.encrypt(penalty_rawtx, dispute_txid) == appointment.encrypted_blob
 def test_check_breach_random_data(watcher):