mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-23 08:44:22 +01:00
Add message queue for SSE messages POST endpoint (#459)
This commit is contained in:
@@ -52,9 +52,11 @@ from starlette.responses import Response
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.server.message_queue import InMemoryMessageDispatch, MessageDispatch
|
||||
from mcp.shared.message import SessionMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
|
||||
class SseServerTransport:
|
||||
@@ -70,17 +72,24 @@ class SseServerTransport:
|
||||
"""
|
||||
|
||||
_endpoint: str
|
||||
_message_dispatch: MessageDispatch
|
||||
_read_stream_writers: dict[UUID, MemoryObjectSendStream[SessionMessage | Exception]]
|
||||
|
||||
def __init__(self, endpoint: str) -> None:
|
||||
def __init__(
|
||||
self, endpoint: str, message_dispatch: MessageDispatch | None = None
|
||||
) -> None:
|
||||
"""
|
||||
Creates a new SSE server transport, which will direct the client to POST
|
||||
messages to the relative or absolute URL given.
|
||||
|
||||
Args:
|
||||
endpoint: The endpoint URL for SSE connections
|
||||
message_dispatch: Optional message dispatch to use
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
self._endpoint = endpoint
|
||||
self._read_stream_writers = {}
|
||||
self._message_dispatch = message_dispatch or InMemoryMessageDispatch()
|
||||
logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}")
|
||||
|
||||
@asynccontextmanager
|
||||
@@ -101,7 +110,12 @@ class SseServerTransport:
|
||||
|
||||
session_id = uuid4()
|
||||
session_uri = f"{quote(self._endpoint)}?session_id={session_id.hex}"
|
||||
self._read_stream_writers[session_id] = read_stream_writer
|
||||
|
||||
async def message_callback(message: SessionMessage | Exception) -> None:
|
||||
"""Callback that receives messages from the message queue"""
|
||||
logger.debug(f"Got message from queue for session {session_id}")
|
||||
await read_stream_writer.send(message)
|
||||
|
||||
logger.debug(f"Created new session with ID: {session_id}")
|
||||
|
||||
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[
|
||||
@@ -138,13 +152,16 @@ class SseServerTransport:
|
||||
)(scope, receive, send)
|
||||
await read_stream_writer.aclose()
|
||||
await write_stream_reader.aclose()
|
||||
await sse_stream_writer.aclose()
|
||||
await sse_stream_reader.aclose()
|
||||
logging.debug(f"Client session disconnected {session_id}")
|
||||
|
||||
logger.debug("Starting SSE response task")
|
||||
tg.start_soon(response_wrapper, scope, receive, send)
|
||||
|
||||
logger.debug("Yielding read and write streams")
|
||||
yield (read_stream, write_stream)
|
||||
async with self._message_dispatch.subscribe(session_id, message_callback):
|
||||
logger.debug("Yielding read and write streams")
|
||||
yield (read_stream, write_stream)
|
||||
|
||||
async def handle_post_message(
|
||||
self, scope: Scope, receive: Receive, send: Send
|
||||
@@ -166,8 +183,7 @@ class SseServerTransport:
|
||||
response = Response("Invalid session ID", status_code=400)
|
||||
return await response(scope, receive, send)
|
||||
|
||||
writer = self._read_stream_writers.get(session_id)
|
||||
if not writer:
|
||||
if not await self._message_dispatch.session_exists(session_id):
|
||||
logger.warning(f"Could not find session for ID: {session_id}")
|
||||
response = Response("Could not find session", status_code=404)
|
||||
return await response(scope, receive, send)
|
||||
@@ -182,11 +198,15 @@ class SseServerTransport:
|
||||
logger.error(f"Failed to parse message: {err}")
|
||||
response = Response("Could not parse message", status_code=400)
|
||||
await response(scope, receive, send)
|
||||
await writer.send(err)
|
||||
# Pass raw JSON string; receiver will recreate identical ValidationError
|
||||
# when parsing the same invalid JSON
|
||||
await self._message_dispatch.publish_message(session_id, body.decode())
|
||||
return
|
||||
|
||||
session_message = SessionMessage(message)
|
||||
logger.debug(f"Sending session message to writer: {session_message}")
|
||||
logger.debug(f"Publishing message for session {session_id}: {message}")
|
||||
response = Response("Accepted", status_code=202)
|
||||
await response(scope, receive, send)
|
||||
await writer.send(session_message)
|
||||
await self._message_dispatch.publish_message(
|
||||
session_id, SessionMessage(message=message)
|
||||
)
|
||||
logger.debug(f"Sending session message to writer: {message}")
|
||||
|
||||
Reference in New Issue
Block a user