From c18abefe4b7865aa608484826ec69550cca74dbb Mon Sep 17 00:00:00 2001 From: Sergi Delgado Segura Date: Wed, 6 May 2020 12:50:44 +0200 Subject: [PATCH] plugin - adds getappointment tests and tests with multiple towers --- watchtower-plugin/test_watchtower.py | 235 +++++++++++++++++++-------- 1 file changed, 171 insertions(+), 64 deletions(-) diff --git a/watchtower-plugin/test_watchtower.py b/watchtower-plugin/test_watchtower.py index e35db02..aa15664 100644 --- a/watchtower-plugin/test_watchtower.py +++ b/watchtower-plugin/test_watchtower.py @@ -12,6 +12,7 @@ from common.appointment import Appointment from common.cryptographer import Cryptographer plugin_path = os.path.join(os.path.dirname(__file__), "watchtower.py") + tower_netaddr = "localhost" tower_port = "1234" tower_sk = PrivateKey() @@ -20,14 +21,103 @@ tower_id = Cryptographer.get_compressed_pk(tower_sk.public_key) mocked_return = None -def add_appointment_success(appointment, available_slots, subscription_expiry): +class TowerMock: + def __init__(self, tower_sk): + self.sk = tower_sk + self.users = {} + self.app = Flask(__name__) + + # Adds all the routes to the functions listed above. + routes = { + "/register": (self.register, ["POST"]), + "/add_appointment": (self.add_appointment, ["POST"]), + "/get_appointment": (self.get_appointment, ["POST"]), + } + + for url, params in routes.items(): + self.app.add_url_rule(url, view_func=params[0], methods=params[1]) + + # Setting Flask log to ERROR only so it does not mess with our logging. Also disabling flask initial messages + logging.getLogger("werkzeug").setLevel(logging.ERROR) + os.environ["WERKZEUG_RUN_MAIN"] = "true" + + # Thread(target=app.run, kwargs={"host": tower_netaddr, "port": tower_port}, daemon=True).start() + + def register(self): + user_id = request.get_json().get("public_key") + + if user_id not in self.users: + self.users[user_id] = {"available_slots": 100, "subscription_expiry": 4320} + else: + self.users[user_id]["available_slots"] += 100 + self.users[user_id]["subscription_expiry"] = 4320 + + rcode = constants.HTTP_OK + response = { + "public_key": user_id, + "available_slots": self.users[user_id].get("available_slots"), + "subscription_expiry": self.users[user_id].get("subscription_expiry"), + } + + return response, rcode + + def add_appointment(self): + appointment = Appointment.from_dict(request.get_json().get("appointment")) + user_id = Cryptographer.get_compressed_pk( + Cryptographer.recover_pk(appointment.serialize(), request.get_json().get("signature")) + ) + + if mocked_return == "success": + response, rtype = add_appointment_success(appointment, self.users[user_id], self.sk) + elif mocked_return == "reject_no_slots": + response, rtype = add_appointment_reject_no_slots() + elif mocked_return == "reject_invalid": + response, rtype = add_appointment_reject_invalid() + elif mocked_return == "misbehaving_tower": + response, rtype = add_appointment_misbehaving_tower(appointment, self.users[user_id], self.sk) + else: + response, rtype = add_appointment_service_unavailable() + + return jsonify(response), rtype + + def get_appointment(self): + locator = request.get_json().get("locator") + message = f"get appointment {locator}" + user_id = Cryptographer.get_compressed_pk( + Cryptographer.recover_pk(message.encode(), request.get_json().get("signature")) + ) + + if ( + user_id in self.users + and "appointments" in self.users[user_id] + and locator in self.users[user_id]["appointments"] + ): + rcode = constants.HTTP_OK + response = self.users[user_id]["appointments"][locator] + response["status"] = "being_watched" + + else: + rcode = constants.HTTP_NOT_FOUND + response = {"locator": locator, "status": "not_found"} + + return jsonify(response), rcode + + +def add_appointment_success(appointment, user, tower_sk): rcode = constants.HTTP_OK response = { "locator": appointment.locator, "signature": Cryptographer.sign(appointment.serialize(), tower_sk), - "available_slots": available_slots - 1, - "subscription_expiry": subscription_expiry, + "available_slots": user.get("available_slots") - 1, + "subscription_expiry": user.get("subscription_expiry"), } + + user["available_slots"] = response.get("available_slots") + if user.get("appointments"): + user["appointments"][appointment.locator] = appointment.to_dict() + else: + user["appointments"] = {appointment.locator: appointment.to_dict()} + return response, rcode @@ -62,12 +152,14 @@ def add_appointment_service_unavailable(): return response, rcode -def add_appointment_misbehaving_tower(appointment, **kwargs): +def add_appointment_misbehaving_tower(appointment, user, tower_sk): # This covers a tower signing with invalid keys wrong_sk = PrivateKey.from_hex(get_random_value_hex(32)) - response, rcode = add_appointment_success(appointment, **kwargs) - response["signature"] = Cryptographer.sign(appointment.serialize(), wrong_sk) + wrong_sig = Cryptographer.sign(appointment.serialize(), wrong_sk) + response, rcode = add_appointment_success(appointment, user, tower_sk) + user["appointments"][appointment.locator]["signature"] = wrong_sig + response["signature"] = wrong_sig return response, rcode @@ -100,56 +192,9 @@ def prng_seed(): @pytest.fixture(scope="session", autouse=True) -def tower_mock(): - app = Flask(__name__) - - users = {} - - @app.route("/register", methods=["POST"]) - def register(): - - user_id = request.get_json().get("public_key") - - if user_id not in users: - users[user_id] = {"available_slots": 100, "subscription_expiry": 4320} - else: - users[user_id]["available_slots"] = 100 - users[user_id]["subscription_expiry"] = 4320 - - rcode = constants.HTTP_OK - response = {"public_key": user_id, **users[user_id]} - - return response, rcode - - @app.route("/add_appointment", methods=["POST"]) - def add_appointment(): - appointment = Appointment.from_dict(request.get_json().get("appointment")) - user_id = Cryptographer.get_compressed_pk( - Cryptographer.recover_pk(appointment.serialize(), request.get_json().get("signature")) - ) - - if mocked_return == "success": - data, rtype = add_appointment_success(appointment, **users[user_id]) - elif mocked_return == "reject_no_slots": - data, rtype = add_appointment_reject_no_slots() - elif mocked_return == "reject_invalid": - data, rtype = add_appointment_reject_invalid() - elif mocked_return == "misbehaving_tower": - data, rtype = add_appointment_misbehaving_tower(appointment, **users[user_id]) - else: - data, rtype = add_appointment_service_unavailable() - - return jsonify(data), rtype - - @app.route("/get_appointment", methods=["POST"]) - def get_appointment(): - pass - - # Setting Flask log to ERROR only so it does not mess with our logging. Also disabling flask initial messages - logging.getLogger("werkzeug").setLevel(logging.ERROR) - os.environ["WERKZEUG_RUN_MAIN"] = "true" - - Thread(target=app.run, kwargs={"host": tower_netaddr, "port": tower_port}, daemon=True).start() +def run_tower(): + tower = TowerMock(tower_sk) + Thread(target=tower.app.run, kwargs={"host": tower_netaddr, "port": tower_port}, daemon=True).start() def test_helpme_starts(node_factory): @@ -165,6 +210,8 @@ def test_helpme_starts(node_factory): def test_watchtower(node_factory): + """ Tests sending data to a single tower with short connection issue""" + global mocked_return # FIXME: node_factory is a function scope fixture, so I cannot reuse it while splitting the tests logically l1, l2 = node_factory.line_graph(2, opts=[{"may_fail": True, "allow_broken_log": True}, {"plugin": plugin_path}]) @@ -172,7 +219,7 @@ def test_watchtower(node_factory): # Register a new tower l2.rpc.registertower("{}@{}:{}".format(tower_id, tower_netaddr, tower_port)) - # Make sure we the tower in our list of towers + # Make sure the tower in our list of towers tower_ids = [tower.get("id") for tower in l2.rpc.listtowers().get("towers")] assert tower_id in tower_ids @@ -206,15 +253,17 @@ def test_watchtower(node_factory): while l2.rpc.gettowerinfo(tower_id).get("pending_appointments"): sleep(0.1) - # The previously pending appointment are now part of the sent appointments + # The previously pending appointments are now part of the sent appointments assert set(pending_appointments).issubset(l2.rpc.gettowerinfo(tower_id).get("appointments").keys()) def test_watchtower_retry_offline(node_factory): + """Tests sending data to a tower that gets offline for a while. Forces retry using ``retrytower``""" + global mocked_return l1, l2 = node_factory.line_graph(2, opts=[{"may_fail": True, "allow_broken_log": True}, {"plugin": plugin_path}]) - # Send some more with to tower "offline" + # Send some appointments with to tower "offline" mocked_return = "service_unavailable" # There are no pending appointment atm @@ -249,6 +298,8 @@ def test_watchtower_retry_offline(node_factory): def test_watchtower_no_slots(node_factory): + """Tests sending data to tower for a user that has no available slots""" + global mocked_return l1, l2 = node_factory.line_graph(2, opts=[{"may_fail": True, "allow_broken_log": True}, {"plugin": plugin_path}]) @@ -258,7 +309,7 @@ def test_watchtower_no_slots(node_factory): # There are no pending appointments atm assert not l2.rpc.gettowerinfo(tower_id).get("pending_appointments") - # Make a payment and the appointment should be as pending + # Make a payment and the appointment should be left as pending l1.rpc.pay(l2.rpc.invoice(25000000, "lbl3", "desc")["bolt11"]) pending_appointments = [ data.get("appointment").get("locator") for data in l2.rpc.gettowerinfo(tower_id).get("pending_appointments") @@ -271,7 +322,7 @@ def test_watchtower_no_slots(node_factory): data.get("appointment").get("locator") for data in l2.rpc.gettowerinfo(tower_id).get("pending_appointments") ] - # Adding appointments + retrying should work + # Adding slots + retrying should work mocked_return = "success" l2.rpc.retrytower(tower_id) while l2.rpc.gettowerinfo(tower_id).get("pending_appointments"): @@ -281,6 +332,8 @@ def test_watchtower_no_slots(node_factory): def test_watchtower_invalid_appointment(node_factory): + """Tests sending an invalid appointment to a tower""" + global mocked_return l1, l2 = node_factory.line_graph(2, opts=[{"may_fail": True, "allow_broken_log": True}, {"plugin": plugin_path}]) @@ -290,21 +343,58 @@ def test_watchtower_invalid_appointment(node_factory): # There are no invalid appointment atm assert not l2.rpc.gettowerinfo(tower_id).get("invalid_appointments") - # Make a payment and the appointment should be dropped + # Make a payment and the appointment should be flagged as invalid l1.rpc.pay(l2.rpc.invoice(25000000, "lbl4", "desc")["bolt11"]) - # The appointments have been saves as invalid + # The appointments have been saved as invalid assert l2.rpc.gettowerinfo(tower_id).get("invalid_appointments") +def test_watchtower_multiple_towers(node_factory): + """ Test sending data to multiple towers at the same time""" + global mocked_return + + # Create the new tower + another_tower_netaddr = "localhost" + another_tower_port = "5678" + another_tower_sk = PrivateKey() + another_tower_id = Cryptographer.get_compressed_pk(another_tower_sk.public_key) + + another_tower = TowerMock(another_tower_sk) + Thread( + target=another_tower.app.run, kwargs={"host": another_tower_netaddr, "port": another_tower_port}, daemon=True + ).start() + + l1, l2 = node_factory.line_graph(2, opts=[{"may_fail": True, "allow_broken_log": True}, {"plugin": plugin_path}]) + + # Register a new tower + l2.rpc.registertower("{}@{}:{}".format(another_tower_id, another_tower_netaddr, another_tower_port)) + + # Make sure the tower in our list of towers + tower_ids = [tower.get("id") for tower in l2.rpc.listtowers().get("towers")] + assert another_tower_id in tower_ids + + # Force a new commitment + mocked_return = "success" + l1.rpc.pay(l2.rpc.invoice(25000000, "lbl6", "desc")["bolt11"]) + + # Check that both towers got it + another_tower_appointments = l2.rpc.gettowerinfo(another_tower_id).get("appointments") + assert another_tower_appointments + assert not l2.rpc.gettowerinfo(another_tower_id).get("pending_appointments") + assert set(another_tower_appointments).issubset(l2.rpc.gettowerinfo(tower_id).get("appointments")) + + def test_watchtower_misbehaving(node_factory): + """Tests sending an appointment to a misbehaving tower""" + global mocked_return l1, l2 = node_factory.line_graph(2, opts=[{"may_fail": True, "allow_broken_log": True}, {"plugin": plugin_path}]) # Simulates a tower that replies with an invalid signature mocked_return = "misbehaving_tower" - # There is no proof of misbehaviour + # There is no proof of misbehaviour atm assert not l2.rpc.gettowerinfo(tower_id).get("misbehaving_proof") # Make a payment and the appointment make it to the tower, but the response will contain an invalid signature @@ -314,3 +404,20 @@ def test_watchtower_misbehaving(node_factory): tower_info = l2.rpc.gettowerinfo(tower_id) assert tower_info.get("status") == "misbehaving" assert tower_info.get("misbehaving_proof") + + +def test_get_appointment(node_factory): + l1, l2 = node_factory.line_graph(2, opts=[{"may_fail": True, "allow_broken_log": True}, {"plugin": plugin_path}]) + + local_appointments = l2.rpc.gettowerinfo(tower_id).get("appointments") + # Get should get a reply for every local appointment + for locator in local_appointments: + response = l2.rpc.getappointment(tower_id, locator) + assert response.get("locator") == locator + assert response.get("status") == "being_watched" + + # Made up appointments should return a 404 + rand_locator = get_random_value_hex(16) + response = l2.rpc.getappointment(tower_id, rand_locator) + assert response.get("locator") == rand_locator + assert response.get("status") == "not_found"