diff --git a/CTFd/utils/events/__init__.py b/CTFd/utils/events/__init__.py index 7e95f6c9..c6b4b601 100644 --- a/CTFd/utils/events/__init__.py +++ b/CTFd/utils/events/__init__.py @@ -2,7 +2,8 @@ import json from collections import defaultdict from queue import Queue -from gevent import Timeout +from gevent import Timeout, spawn +from tenacity import retry, wait_exponential from CTFd.cache import cache from CTFd.utils import string_types @@ -37,60 +38,85 @@ class ServerSentEvent(object): class EventManager(object): def __init__(self): - self.clients = [] + self.clients = {} def publish(self, data, type=None, channel="ctf"): event = ServerSentEvent(data, type=type) message = event.to_dict() - for client in self.clients: + for client in list(self.clients.values()): client[channel].put(message) return len(self.clients) + def listen(self): + pass + def subscribe(self, channel="ctf"): q = defaultdict(Queue) - self.clients.append(q) - while True: - try: - # Immediately yield a ping event to force Response headers to be set - # or else some reverse proxies will incorrectly buffer SSE - yield ServerSentEvent(data="", type="ping") + self.clients[id(q)] = q + try: + while True: + try: + # Immediately yield a ping event to force Response headers to be set + # or else some reverse proxies will incorrectly buffer SSE + yield ServerSentEvent(data="", type="ping") - with Timeout(10): - message = q[channel].get() - yield ServerSentEvent(**message) - except Timeout: - yield ServerSentEvent(data="", type="ping") - except Exception: - raise + with Timeout(5): + message = q[channel].get() + yield ServerSentEvent(**message) + except Timeout: + yield ServerSentEvent(data="", type="ping") + finally: + del self.clients[id(q)] + del q class RedisEventManager(EventManager): def __init__(self): super(EventManager, self).__init__() self.client = cache.cache._write_client + self.clients = {} def publish(self, data, type=None, channel="ctf"): event = ServerSentEvent(data, type=type) message = json.dumps(event.to_dict()) return self.client.publish(message=message, channel=channel) - def subscribe(self, channel="ctf"): - while True: - pubsub = self.client.pubsub() - pubsub.subscribe(channel) - try: - # Immediately yield a ping event to force Response headers to be set - # or else some reverse proxies will incorrectly buffer SSE - yield ServerSentEvent(data="", type="ping") + def listen(self, channel="ctf"): + @retry(wait=wait_exponential(min=1, max=30)) + def _listen(): + while True: + pubsub = self.client.pubsub() + pubsub.subscribe(channel) + try: + while True: + message = pubsub.get_message( + ignore_subscribe_messages=True, timeout=5 + ) + if message: + if message["type"] == "message": + event = json.loads(message["data"]) + for client in list(self.clients.values()): + client[channel].put(event) + finally: + pubsub.close() - with Timeout(10) as timeout: - for message in pubsub.listen(): - if message["type"] == "message": - event = json.loads(message["data"]) - yield ServerSentEvent(**event) - timeout.cancel() - timeout.start() - except Timeout: - yield ServerSentEvent(data="", type="ping") - except Exception: - raise + spawn(_listen) + + def subscribe(self, channel="ctf"): + q = defaultdict(Queue) + self.clients[id(q)] = q + try: + while True: + try: + # Immediately yield a ping event to force Response headers to be set + # or else some reverse proxies will incorrectly buffer SSE + yield ServerSentEvent(data="", type="ping") + + with Timeout(5): + message = q[channel].get() + yield ServerSentEvent(**message) + except Timeout: + yield ServerSentEvent(data="", type="ping") + finally: + del self.clients[id(q)] + del q diff --git a/CTFd/utils/initialization/__init__.py b/CTFd/utils/initialization/__init__.py index cc948e8b..ca530c8d 100644 --- a/CTFd/utils/initialization/__init__.py +++ b/CTFd/utils/initialization/__init__.py @@ -176,6 +176,7 @@ def init_events(app): app.events_manager = EventManager() else: app.events_manager = EventManager() + app.events_manager.listen() def init_request_processors(app): diff --git a/README.md b/README.md index 4dbd3a98..e9000611 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,7 @@ CTFd is a Capture The Flag framework focusing on ease of use and customizability 1. Install dependencies: `pip install -r requirements.txt` 1. You can also use the `prepare.sh` script to install system dependencies using apt. 2. Modify [CTFd/config.ini](https://github.com/CTFd/CTFd/blob/master/CTFd/config.ini) to your liking. -3. Use `flask run` in a terminal to drop into debug mode. +3. Use `python serve.py` or `flask run` in a terminal to drop into debug mode. You can use the auto-generated Docker images with the following command: diff --git a/development.txt b/development.txt index dec3b4ec..8f4d6848 100644 --- a/development.txt +++ b/development.txt @@ -19,3 +19,4 @@ flake8-isort==3.0.0 Faker==4.1.0 pipdeptree==0.13.2 black==19.10b0 +pytest-sugar==0.9.4 diff --git a/requirements.txt b/requirements.txt index 7b151f94..638df443 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,3 +28,4 @@ html5lib==1.0.1 WTForms==2.3.1 python-geoacumen==0.0.1 maxminddb==1.5.4 +tenacity==6.2.0 \ No newline at end of file diff --git a/serve.py b/serve.py index 962e5772..f209a240 100644 --- a/serve.py +++ b/serve.py @@ -1,4 +1,3 @@ -from CTFd import create_app import argparse parser = argparse.ArgumentParser() @@ -6,7 +5,20 @@ parser.add_argument("--port", help="Port for debug server to listen on", default parser.add_argument( "--profile", help="Enable flask_profiler profiling", action="store_true" ) +parser.add_argument( + "--disable-gevent", + help="Disable importing gevent and monkey patching", + action="store_false", +) args = parser.parse_args() +if args.disable_gevent: + print(" * Importing gevent and monkey patching. Use --disable-gevent to disable.") + from gevent import monkey + + monkey.patch_all() + +# Import not at top of file to allow gevent to monkey patch uninterrupted +from CTFd import create_app app = create_app() diff --git a/tests/utils/test_events.py b/tests/utils/test_events.py index a3c97151..74d87e0b 100644 --- a/tests/utils/test_events.py +++ b/tests/utils/test_events.py @@ -1,9 +1,7 @@ -import json from collections import defaultdict from queue import Queue from unittest.mock import patch -import redis from redis.exceptions import ConnectionError from CTFd.config import TestingConfig @@ -63,10 +61,11 @@ def test_event_manager_publish(): } event_manager = EventManager() - event_manager.clients.append(defaultdict(Queue)) + q = defaultdict(Queue) + event_manager.clients[id(q)] = q event_manager.publish(data=saved_data, type="notification", channel="ctf") - event = event_manager.clients[0]["ctf"].get() + event = event_manager.clients[id(q)]["ctf"].get() event = ServerSentEvent(**event) assert event.data == saved_data @@ -129,27 +128,19 @@ def test_redis_event_manager_subscription(): else: with app.app_context(): saved_data = { - u"data": { - u"content": u"asdf", - u"date": u"2019-01-28T05:02:19.830906+00:00", - u"id": 13, - u"team": None, - u"team_id": None, - u"title": u"asdf", - u"user": None, - u"user_id": None, - }, - u"type": u"notification", + "user_id": None, + "title": "asdf", + "content": "asdf", + "team_id": None, + "user": None, + "team": None, + "date": "2019-01-28T01:20:46.017649+00:00", + "id": 10, } + saved_event = {"type": "notification", "data": saved_data} - saved_event = { - "pattern": None, - "type": "message", - "channel": "ctf", - "data": json.dumps(saved_data), - } - with patch.object(redis.client.PubSub, "listen") as fake_pubsub_listen: - fake_pubsub_listen.return_value = [saved_event] + with patch.object(Queue, "get") as fake_queue: + fake_queue.return_value = saved_event event_manager = RedisEventManager() events = event_manager.subscribe() @@ -160,7 +151,7 @@ def test_redis_event_manager_subscription(): message = next(events) assert isinstance(message, ServerSentEvent) - assert message.to_dict() == saved_data + assert message.to_dict() == saved_event assert message.__str__().startswith("event:notification\ndata:") destroy_ctfd(app) @@ -193,3 +184,67 @@ def test_redis_event_manager_publish(): event_manager = RedisEventManager() event_manager.publish(data=saved_data, type="notification", channel="ctf") destroy_ctfd(app) + + +def test_redis_event_manager_listen(): + """Test that RedisEventManager listening pubsub works.""" + # This test is nob currently working properly + # This test is sort of incomplete b/c we aren't also subscribing + # I wasnt able to get listening and subscribing to work at the same time + # But the code does work under gunicorn and serve.py + try: + # import importlib + # from gevent.monkey import patch_time, patch_socket + # from gevent import Timeout + + # patch_time() + # patch_socket() + + class RedisConfig(TestingConfig): + REDIS_URL = "redis://localhost:6379/4" + CACHE_REDIS_URL = "redis://localhost:6379/4" + CACHE_TYPE = "redis" + + try: + app = create_ctfd(config=RedisConfig) + except ConnectionError: + print("Failed to connect to redis. Skipping test.") + else: + with app.app_context(): + # saved_event = { + # "data": { + # "team_id": None, + # "user_id": None, + # "content": "asdf", + # "title": "asdf", + # "id": 1, + # "team": None, + # "user": None, + # "date": "2020-08-31T23:57:27.193081+00:00", + # "type": "toast", + # "sound": None, + # }, + # "type": "notification", + # } + + event_manager = RedisEventManager() + + # def disable_retry(f, *args, **kwargs): + # return f() + + # with patch("tenacity.retry", side_effect=disable_retry): + # with Timeout(10): + # event_manager.listen() + event_manager.listen() + + # event_manager.publish( + # data=saved_event["data"], type="notification", channel="ctf" + # ) + destroy_ctfd(app) + finally: + pass + # import socket + # import time + + # importlib.reload(socket) + # importlib.reload(time)