Wrap JSONRPC messages with SessionMessage for metadata support (#590)

This commit is contained in:
ihrpr
2025-05-02 14:29:00 +01:00
committed by GitHub
parent 3978c6e1b9
commit da0cf22355
22 changed files with 286 additions and 173 deletions

View File

@@ -46,6 +46,7 @@ from starlette.responses import Response
from starlette.types import Receive, Scope, Send
import mcp.types as types
from mcp.shared.message import SessionMessage
logger = logging.getLogger(__name__)
@@ -63,9 +64,7 @@ class SseServerTransport:
"""
_endpoint: str
_read_stream_writers: dict[
UUID, MemoryObjectSendStream[types.JSONRPCMessage | Exception]
]
_read_stream_writers: dict[UUID, MemoryObjectSendStream[SessionMessage | Exception]]
def __init__(self, endpoint: str) -> None:
"""
@@ -85,11 +84,11 @@ class SseServerTransport:
raise ValueError("connect_sse can only handle HTTP requests")
logger.debug("Setting up SSE connection")
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
write_stream: MemoryObjectSendStream[SessionMessage]
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -109,12 +108,12 @@ class SseServerTransport:
await sse_stream_writer.send({"event": "endpoint", "data": session_uri})
logger.debug(f"Sent endpoint event: {session_uri}")
async for message in write_stream_reader:
logger.debug(f"Sending message via SSE: {message}")
async for session_message in write_stream_reader:
logger.debug(f"Sending message via SSE: {session_message}")
await sse_stream_writer.send(
{
"event": "message",
"data": message.model_dump_json(
"data": session_message.message.model_dump_json(
by_alias=True, exclude_none=True
),
}
@@ -169,7 +168,8 @@ class SseServerTransport:
await writer.send(err)
return
logger.debug(f"Sending message to writer: {message}")
session_message = SessionMessage(message)
logger.debug(f"Sending session message to writer: {session_message}")
response = Response("Accepted", status_code=202)
await response(scope, receive, send)
await writer.send(message)
await writer.send(session_message)