Revert "Add message queue for SSE messages POST endpoint (#459)" (#649)

This commit is contained in:
ihrpr
2025-05-07 16:35:20 +01:00
committed by GitHub
parent c8a14c9dba
commit 9d99aee014
26 changed files with 51 additions and 1247 deletions

View File

@@ -52,11 +52,9 @@ from starlette.responses import Response
from starlette.types import Receive, Scope, Send
import mcp.types as types
from mcp.server.message_queue import InMemoryMessageDispatch, MessageDispatch
from mcp.shared.message import SessionMessage
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)
class SseServerTransport:
@@ -72,24 +70,17 @@ class SseServerTransport:
"""
_endpoint: str
_message_dispatch: MessageDispatch
_read_stream_writers: dict[UUID, MemoryObjectSendStream[SessionMessage | Exception]]
def __init__(
self, endpoint: str, message_dispatch: MessageDispatch | None = None
) -> None:
def __init__(self, endpoint: str) -> None:
"""
Creates a new SSE server transport, which will direct the client to POST
messages to the relative or absolute URL given.
Args:
endpoint: The endpoint URL for SSE connections
message_dispatch: Optional message dispatch to use
"""
super().__init__()
self._endpoint = endpoint
self._message_dispatch = message_dispatch or InMemoryMessageDispatch()
self._read_stream_writers = {}
logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}")
@asynccontextmanager
@@ -110,12 +101,7 @@ class SseServerTransport:
session_id = uuid4()
session_uri = f"{quote(self._endpoint)}?session_id={session_id.hex}"
async def message_callback(message: SessionMessage | Exception) -> None:
"""Callback that receives messages from the message queue"""
logger.debug(f"Got message from queue for session {session_id}")
await read_stream_writer.send(message)
self._read_stream_writers[session_id] = read_stream_writer
logger.debug(f"Created new session with ID: {session_id}")
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[
@@ -152,16 +138,13 @@ class SseServerTransport:
)(scope, receive, send)
await read_stream_writer.aclose()
await write_stream_reader.aclose()
await sse_stream_writer.aclose()
await sse_stream_reader.aclose()
logging.debug(f"Client session disconnected {session_id}")
logger.debug("Starting SSE response task")
tg.start_soon(response_wrapper, scope, receive, send)
async with self._message_dispatch.subscribe(session_id, message_callback):
logger.debug("Yielding read and write streams")
yield (read_stream, write_stream)
logger.debug("Yielding read and write streams")
yield (read_stream, write_stream)
async def handle_post_message(
self, scope: Scope, receive: Receive, send: Send
@@ -183,7 +166,8 @@ class SseServerTransport:
response = Response("Invalid session ID", status_code=400)
return await response(scope, receive, send)
if not await self._message_dispatch.session_exists(session_id):
writer = self._read_stream_writers.get(session_id)
if not writer:
logger.warning(f"Could not find session for ID: {session_id}")
response = Response("Could not find session", status_code=404)
return await response(scope, receive, send)
@@ -198,15 +182,11 @@ class SseServerTransport:
logger.error(f"Failed to parse message: {err}")
response = Response("Could not parse message", status_code=400)
await response(scope, receive, send)
# Pass raw JSON string; receiver will recreate identical ValidationError
# when parsing the same invalid JSON
await self._message_dispatch.publish_message(session_id, body.decode())
await writer.send(err)
return
logger.debug(f"Publishing message for session {session_id}: {message}")
session_message = SessionMessage(message)
logger.debug(f"Sending session message to writer: {session_message}")
response = Response("Accepted", status_code=202)
await response(scope, receive, send)
await self._message_dispatch.publish_message(
session_id, SessionMessage(message=message)
)
logger.debug(f"Sending session message to writer: {message}")
await writer.send(session_message)