mirror of
https://github.com/aljazceru/CTFd.git
synced 2025-12-17 05:54:19 +01:00
Pool pubsub connections for notifications (#1626)
* Add a `listen()` method to `CTFd.utils.events.EventManager` and `CTFd.utils.events.RedisEventManager`. * This method should implement subscription for a CTFd worker to whatever underlying notification system there is. This should be implemented with gevent or a background thread. * The `subscribe()` method (which used to also implement the functionality of the new `listen()` function) now only handles passing notifications from CTFd to the browser. This should also be implemented with gevent or a background thread. * Pool PubSub connections to Redis behind gevent. This improves the notification system by not having a pubsub connection per browser but instead per CTFd worker. This should reduce the difficulty in deploying the Notification system. * Closes #1622 * Make gevent default in serve.py and add a `--disable-gevent` switch in serve.py * Revert to recommending `serve.py` first in README. `flask run` works but we don't get a lot of control. * Add `tenacity` library for retrying logic * Add `pytest-sugar` for slightly prettier pytest output
This commit is contained in:
@@ -2,7 +2,8 @@ import json
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
|
|
||||||
from gevent import Timeout
|
from gevent import Timeout, spawn
|
||||||
|
from tenacity import retry, wait_exponential
|
||||||
|
|
||||||
from CTFd.cache import cache
|
from CTFd.cache import cache
|
||||||
from CTFd.utils import string_types
|
from CTFd.utils import string_types
|
||||||
@@ -37,60 +38,85 @@ class ServerSentEvent(object):
|
|||||||
|
|
||||||
class EventManager(object):
|
class EventManager(object):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.clients = []
|
self.clients = {}
|
||||||
|
|
||||||
def publish(self, data, type=None, channel="ctf"):
|
def publish(self, data, type=None, channel="ctf"):
|
||||||
event = ServerSentEvent(data, type=type)
|
event = ServerSentEvent(data, type=type)
|
||||||
message = event.to_dict()
|
message = event.to_dict()
|
||||||
for client in self.clients:
|
for client in list(self.clients.values()):
|
||||||
client[channel].put(message)
|
client[channel].put(message)
|
||||||
return len(self.clients)
|
return len(self.clients)
|
||||||
|
|
||||||
|
def listen(self):
|
||||||
|
pass
|
||||||
|
|
||||||
def subscribe(self, channel="ctf"):
|
def subscribe(self, channel="ctf"):
|
||||||
q = defaultdict(Queue)
|
q = defaultdict(Queue)
|
||||||
self.clients.append(q)
|
self.clients[id(q)] = q
|
||||||
while True:
|
try:
|
||||||
try:
|
while True:
|
||||||
# Immediately yield a ping event to force Response headers to be set
|
try:
|
||||||
# or else some reverse proxies will incorrectly buffer SSE
|
# Immediately yield a ping event to force Response headers to be set
|
||||||
yield ServerSentEvent(data="", type="ping")
|
# or else some reverse proxies will incorrectly buffer SSE
|
||||||
|
yield ServerSentEvent(data="", type="ping")
|
||||||
|
|
||||||
with Timeout(10):
|
with Timeout(5):
|
||||||
message = q[channel].get()
|
message = q[channel].get()
|
||||||
yield ServerSentEvent(**message)
|
yield ServerSentEvent(**message)
|
||||||
except Timeout:
|
except Timeout:
|
||||||
yield ServerSentEvent(data="", type="ping")
|
yield ServerSentEvent(data="", type="ping")
|
||||||
except Exception:
|
finally:
|
||||||
raise
|
del self.clients[id(q)]
|
||||||
|
del q
|
||||||
|
|
||||||
|
|
||||||
class RedisEventManager(EventManager):
|
class RedisEventManager(EventManager):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(EventManager, self).__init__()
|
super(EventManager, self).__init__()
|
||||||
self.client = cache.cache._write_client
|
self.client = cache.cache._write_client
|
||||||
|
self.clients = {}
|
||||||
|
|
||||||
def publish(self, data, type=None, channel="ctf"):
|
def publish(self, data, type=None, channel="ctf"):
|
||||||
event = ServerSentEvent(data, type=type)
|
event = ServerSentEvent(data, type=type)
|
||||||
message = json.dumps(event.to_dict())
|
message = json.dumps(event.to_dict())
|
||||||
return self.client.publish(message=message, channel=channel)
|
return self.client.publish(message=message, channel=channel)
|
||||||
|
|
||||||
def subscribe(self, channel="ctf"):
|
def listen(self, channel="ctf"):
|
||||||
while True:
|
@retry(wait=wait_exponential(min=1, max=30))
|
||||||
pubsub = self.client.pubsub()
|
def _listen():
|
||||||
pubsub.subscribe(channel)
|
while True:
|
||||||
try:
|
pubsub = self.client.pubsub()
|
||||||
# Immediately yield a ping event to force Response headers to be set
|
pubsub.subscribe(channel)
|
||||||
# or else some reverse proxies will incorrectly buffer SSE
|
try:
|
||||||
yield ServerSentEvent(data="", type="ping")
|
while True:
|
||||||
|
message = pubsub.get_message(
|
||||||
|
ignore_subscribe_messages=True, timeout=5
|
||||||
|
)
|
||||||
|
if message:
|
||||||
|
if message["type"] == "message":
|
||||||
|
event = json.loads(message["data"])
|
||||||
|
for client in list(self.clients.values()):
|
||||||
|
client[channel].put(event)
|
||||||
|
finally:
|
||||||
|
pubsub.close()
|
||||||
|
|
||||||
with Timeout(10) as timeout:
|
spawn(_listen)
|
||||||
for message in pubsub.listen():
|
|
||||||
if message["type"] == "message":
|
def subscribe(self, channel="ctf"):
|
||||||
event = json.loads(message["data"])
|
q = defaultdict(Queue)
|
||||||
yield ServerSentEvent(**event)
|
self.clients[id(q)] = q
|
||||||
timeout.cancel()
|
try:
|
||||||
timeout.start()
|
while True:
|
||||||
except Timeout:
|
try:
|
||||||
yield ServerSentEvent(data="", type="ping")
|
# Immediately yield a ping event to force Response headers to be set
|
||||||
except Exception:
|
# or else some reverse proxies will incorrectly buffer SSE
|
||||||
raise
|
yield ServerSentEvent(data="", type="ping")
|
||||||
|
|
||||||
|
with Timeout(5):
|
||||||
|
message = q[channel].get()
|
||||||
|
yield ServerSentEvent(**message)
|
||||||
|
except Timeout:
|
||||||
|
yield ServerSentEvent(data="", type="ping")
|
||||||
|
finally:
|
||||||
|
del self.clients[id(q)]
|
||||||
|
del q
|
||||||
|
|||||||
@@ -176,6 +176,7 @@ def init_events(app):
|
|||||||
app.events_manager = EventManager()
|
app.events_manager = EventManager()
|
||||||
else:
|
else:
|
||||||
app.events_manager = EventManager()
|
app.events_manager = EventManager()
|
||||||
|
app.events_manager.listen()
|
||||||
|
|
||||||
|
|
||||||
def init_request_processors(app):
|
def init_request_processors(app):
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ CTFd is a Capture The Flag framework focusing on ease of use and customizability
|
|||||||
1. Install dependencies: `pip install -r requirements.txt`
|
1. Install dependencies: `pip install -r requirements.txt`
|
||||||
1. You can also use the `prepare.sh` script to install system dependencies using apt.
|
1. You can also use the `prepare.sh` script to install system dependencies using apt.
|
||||||
2. Modify [CTFd/config.ini](https://github.com/CTFd/CTFd/blob/master/CTFd/config.ini) to your liking.
|
2. Modify [CTFd/config.ini](https://github.com/CTFd/CTFd/blob/master/CTFd/config.ini) to your liking.
|
||||||
3. Use `flask run` in a terminal to drop into debug mode.
|
3. Use `python serve.py` or `flask run` in a terminal to drop into debug mode.
|
||||||
|
|
||||||
You can use the auto-generated Docker images with the following command:
|
You can use the auto-generated Docker images with the following command:
|
||||||
|
|
||||||
|
|||||||
@@ -19,3 +19,4 @@ flake8-isort==3.0.0
|
|||||||
Faker==4.1.0
|
Faker==4.1.0
|
||||||
pipdeptree==0.13.2
|
pipdeptree==0.13.2
|
||||||
black==19.10b0
|
black==19.10b0
|
||||||
|
pytest-sugar==0.9.4
|
||||||
|
|||||||
@@ -28,3 +28,4 @@ html5lib==1.0.1
|
|||||||
WTForms==2.3.1
|
WTForms==2.3.1
|
||||||
python-geoacumen==0.0.1
|
python-geoacumen==0.0.1
|
||||||
maxminddb==1.5.4
|
maxminddb==1.5.4
|
||||||
|
tenacity==6.2.0
|
||||||
14
serve.py
14
serve.py
@@ -1,4 +1,3 @@
|
|||||||
from CTFd import create_app
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
@@ -6,7 +5,20 @@ parser.add_argument("--port", help="Port for debug server to listen on", default
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--profile", help="Enable flask_profiler profiling", action="store_true"
|
"--profile", help="Enable flask_profiler profiling", action="store_true"
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--disable-gevent",
|
||||||
|
help="Disable importing gevent and monkey patching",
|
||||||
|
action="store_false",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
if args.disable_gevent:
|
||||||
|
print(" * Importing gevent and monkey patching. Use --disable-gevent to disable.")
|
||||||
|
from gevent import monkey
|
||||||
|
|
||||||
|
monkey.patch_all()
|
||||||
|
|
||||||
|
# Import not at top of file to allow gevent to monkey patch uninterrupted
|
||||||
|
from CTFd import create_app
|
||||||
|
|
||||||
app = create_app()
|
app = create_app()
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,7 @@
|
|||||||
import json
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import redis
|
|
||||||
from redis.exceptions import ConnectionError
|
from redis.exceptions import ConnectionError
|
||||||
|
|
||||||
from CTFd.config import TestingConfig
|
from CTFd.config import TestingConfig
|
||||||
@@ -63,10 +61,11 @@ def test_event_manager_publish():
|
|||||||
}
|
}
|
||||||
|
|
||||||
event_manager = EventManager()
|
event_manager = EventManager()
|
||||||
event_manager.clients.append(defaultdict(Queue))
|
q = defaultdict(Queue)
|
||||||
|
event_manager.clients[id(q)] = q
|
||||||
event_manager.publish(data=saved_data, type="notification", channel="ctf")
|
event_manager.publish(data=saved_data, type="notification", channel="ctf")
|
||||||
|
|
||||||
event = event_manager.clients[0]["ctf"].get()
|
event = event_manager.clients[id(q)]["ctf"].get()
|
||||||
event = ServerSentEvent(**event)
|
event = ServerSentEvent(**event)
|
||||||
assert event.data == saved_data
|
assert event.data == saved_data
|
||||||
|
|
||||||
@@ -129,27 +128,19 @@ def test_redis_event_manager_subscription():
|
|||||||
else:
|
else:
|
||||||
with app.app_context():
|
with app.app_context():
|
||||||
saved_data = {
|
saved_data = {
|
||||||
u"data": {
|
"user_id": None,
|
||||||
u"content": u"asdf",
|
"title": "asdf",
|
||||||
u"date": u"2019-01-28T05:02:19.830906+00:00",
|
"content": "asdf",
|
||||||
u"id": 13,
|
"team_id": None,
|
||||||
u"team": None,
|
"user": None,
|
||||||
u"team_id": None,
|
"team": None,
|
||||||
u"title": u"asdf",
|
"date": "2019-01-28T01:20:46.017649+00:00",
|
||||||
u"user": None,
|
"id": 10,
|
||||||
u"user_id": None,
|
|
||||||
},
|
|
||||||
u"type": u"notification",
|
|
||||||
}
|
}
|
||||||
|
saved_event = {"type": "notification", "data": saved_data}
|
||||||
|
|
||||||
saved_event = {
|
with patch.object(Queue, "get") as fake_queue:
|
||||||
"pattern": None,
|
fake_queue.return_value = saved_event
|
||||||
"type": "message",
|
|
||||||
"channel": "ctf",
|
|
||||||
"data": json.dumps(saved_data),
|
|
||||||
}
|
|
||||||
with patch.object(redis.client.PubSub, "listen") as fake_pubsub_listen:
|
|
||||||
fake_pubsub_listen.return_value = [saved_event]
|
|
||||||
event_manager = RedisEventManager()
|
event_manager = RedisEventManager()
|
||||||
|
|
||||||
events = event_manager.subscribe()
|
events = event_manager.subscribe()
|
||||||
@@ -160,7 +151,7 @@ def test_redis_event_manager_subscription():
|
|||||||
|
|
||||||
message = next(events)
|
message = next(events)
|
||||||
assert isinstance(message, ServerSentEvent)
|
assert isinstance(message, ServerSentEvent)
|
||||||
assert message.to_dict() == saved_data
|
assert message.to_dict() == saved_event
|
||||||
assert message.__str__().startswith("event:notification\ndata:")
|
assert message.__str__().startswith("event:notification\ndata:")
|
||||||
destroy_ctfd(app)
|
destroy_ctfd(app)
|
||||||
|
|
||||||
@@ -193,3 +184,67 @@ def test_redis_event_manager_publish():
|
|||||||
event_manager = RedisEventManager()
|
event_manager = RedisEventManager()
|
||||||
event_manager.publish(data=saved_data, type="notification", channel="ctf")
|
event_manager.publish(data=saved_data, type="notification", channel="ctf")
|
||||||
destroy_ctfd(app)
|
destroy_ctfd(app)
|
||||||
|
|
||||||
|
|
||||||
|
def test_redis_event_manager_listen():
|
||||||
|
"""Test that RedisEventManager listening pubsub works."""
|
||||||
|
# This test is nob currently working properly
|
||||||
|
# This test is sort of incomplete b/c we aren't also subscribing
|
||||||
|
# I wasnt able to get listening and subscribing to work at the same time
|
||||||
|
# But the code does work under gunicorn and serve.py
|
||||||
|
try:
|
||||||
|
# import importlib
|
||||||
|
# from gevent.monkey import patch_time, patch_socket
|
||||||
|
# from gevent import Timeout
|
||||||
|
|
||||||
|
# patch_time()
|
||||||
|
# patch_socket()
|
||||||
|
|
||||||
|
class RedisConfig(TestingConfig):
|
||||||
|
REDIS_URL = "redis://localhost:6379/4"
|
||||||
|
CACHE_REDIS_URL = "redis://localhost:6379/4"
|
||||||
|
CACHE_TYPE = "redis"
|
||||||
|
|
||||||
|
try:
|
||||||
|
app = create_ctfd(config=RedisConfig)
|
||||||
|
except ConnectionError:
|
||||||
|
print("Failed to connect to redis. Skipping test.")
|
||||||
|
else:
|
||||||
|
with app.app_context():
|
||||||
|
# saved_event = {
|
||||||
|
# "data": {
|
||||||
|
# "team_id": None,
|
||||||
|
# "user_id": None,
|
||||||
|
# "content": "asdf",
|
||||||
|
# "title": "asdf",
|
||||||
|
# "id": 1,
|
||||||
|
# "team": None,
|
||||||
|
# "user": None,
|
||||||
|
# "date": "2020-08-31T23:57:27.193081+00:00",
|
||||||
|
# "type": "toast",
|
||||||
|
# "sound": None,
|
||||||
|
# },
|
||||||
|
# "type": "notification",
|
||||||
|
# }
|
||||||
|
|
||||||
|
event_manager = RedisEventManager()
|
||||||
|
|
||||||
|
# def disable_retry(f, *args, **kwargs):
|
||||||
|
# return f()
|
||||||
|
|
||||||
|
# with patch("tenacity.retry", side_effect=disable_retry):
|
||||||
|
# with Timeout(10):
|
||||||
|
# event_manager.listen()
|
||||||
|
event_manager.listen()
|
||||||
|
|
||||||
|
# event_manager.publish(
|
||||||
|
# data=saved_event["data"], type="notification", channel="ctf"
|
||||||
|
# )
|
||||||
|
destroy_ctfd(app)
|
||||||
|
finally:
|
||||||
|
pass
|
||||||
|
# import socket
|
||||||
|
# import time
|
||||||
|
|
||||||
|
# importlib.reload(socket)
|
||||||
|
# importlib.reload(time)
|
||||||
|
|||||||
Reference in New Issue
Block a user