diff --git a/CTFd/api/v1/notifications.py b/CTFd/api/v1/notifications.py index 70143367..072430e7 100644 --- a/CTFd/api/v1/notifications.py +++ b/CTFd/api/v1/notifications.py @@ -1,6 +1,6 @@ from typing import List -from flask import current_app, request +from flask import current_app, make_response, request from flask_restx import Namespace, Resource from CTFd.api.v1.helpers.request import validate_args @@ -60,6 +60,7 @@ class NotificantionList(Resource): RawEnum("NotificationFields", {"title": "title", "content": "content"}), None, ), + "since_id": (int, None), }, location="query", ) @@ -68,6 +69,10 @@ class NotificantionList(Resource): field = str(query_args.pop("field", None)) filters = build_model_filters(model=Notifications, query=q, field=field) + since_id = query_args.pop("since_id", None) + if since_id: + filters.append((Notifications.id > since_id)) + notifications = ( Notifications.query.filter_by(**query_args).filter(*filters).all() ) @@ -77,6 +82,41 @@ class NotificantionList(Resource): return {"success": False, "errors": result.errors}, 400 return {"success": True, "data": result.data} + @notifications_namespace.doc( + description="Endpoint to get statistics for notification objects in bulk", + responses={200: ("Success", "APISimpleSuccessResponse")}, + ) + @validate_args( + { + "title": (str, None), + "content": (str, None), + "user_id": (int, None), + "team_id": (int, None), + "q": (str, None), + "field": ( + RawEnum("NotificationFields", {"title": "title", "content": "content"}), + None, + ), + "since_id": (int, None), + }, + location="query", + ) + def head(self, query_args): + q = query_args.pop("q", None) + field = str(query_args.pop("field", None)) + filters = build_model_filters(model=Notifications, query=q, field=field) + + since_id = query_args.pop("since_id", None) + if since_id: + filters.append((Notifications.id > since_id)) + + notification_count = ( + Notifications.query.filter_by(**query_args).filter(*filters).count() + ) + response = make_response() + response.headers["Result-Count"] = notification_count + return response + @admins_only @notifications_namespace.doc( description="Endpoint to create a notification object", diff --git a/CTFd/utils/events/__init__.py b/CTFd/utils/events/__init__.py index d9697c57..d0f30a82 100644 --- a/CTFd/utils/events/__init__.py +++ b/CTFd/utils/events/__init__.py @@ -40,8 +40,8 @@ class EventManager(object): def __init__(self): self.clients = {} - def publish(self, data, type=None, channel="ctf"): - event = ServerSentEvent(data, type=type) + def publish(self, data, type=None, id=None, channel="ctf"): + event = ServerSentEvent(data, type=type, id=id) message = event.to_dict() for client in list(self.clients.values()): client[channel].put(message) @@ -56,14 +56,14 @@ class EventManager(object): try: # Immediately yield a ping event to force Response headers to be set # or else some reverse proxies will incorrectly buffer SSE - yield ServerSentEvent(data="", type="ping") + yield ServerSentEvent(data="ping", type="ping") while True: try: with Timeout(5): message = q[channel].get() yield ServerSentEvent(**message) except Timeout: - yield ServerSentEvent(data="", type="ping") + yield ServerSentEvent(data="ping", type="ping") finally: del self.clients[id(q)] del q @@ -75,8 +75,8 @@ class RedisEventManager(EventManager): self.client = cache.cache._write_client self.clients = {} - def publish(self, data, type=None, channel="ctf"): - event = ServerSentEvent(data, type=type) + def publish(self, data, type=None, id=None, channel="ctf"): + event = ServerSentEvent(data, type=type, id=id) message = json.dumps(event.to_dict()) return self.client.publish(message=message, channel=channel) @@ -107,14 +107,14 @@ class RedisEventManager(EventManager): try: # Immediately yield a ping event to force Response headers to be set # or else some reverse proxies will incorrectly buffer SSE - yield ServerSentEvent(data="", type="ping") + yield ServerSentEvent(data="ping", type="ping") while True: try: with Timeout(5): message = q[channel].get() yield ServerSentEvent(**message) except Timeout: - yield ServerSentEvent(data="", type="ping") + yield ServerSentEvent(data="ping", type="ping") finally: del self.clients[id(q)] del q diff --git a/tests/utils/test_events.py b/tests/utils/test_events.py index 74d87e0b..30eddf06 100644 --- a/tests/utils/test_events.py +++ b/tests/utils/test_events.py @@ -36,7 +36,7 @@ def test_event_manager_subscription(): events = event_manager.subscribe() message = next(events) assert isinstance(message, ServerSentEvent) - assert message.to_dict() == {"data": "", "type": "ping"} + assert message.to_dict() == {"data": "ping", "type": "ping"} assert message.__str__().startswith("event:ping") assert len(event_manager.clients) == 1 @@ -146,7 +146,7 @@ def test_redis_event_manager_subscription(): events = event_manager.subscribe() message = next(events) assert isinstance(message, ServerSentEvent) - assert message.to_dict() == {"data": "", "type": "ping"} + assert message.to_dict() == {"data": "ping", "type": "ping"} assert message.__str__().startswith("event:ping") message = next(events)