Wrap JSONRPC messages with SessionMessage for metadata support (#590)

This commit is contained in:
ihrpr
2025-05-02 14:29:00 +01:00
committed by GitHub
parent 3978c6e1b9
commit da0cf22355
22 changed files with 286 additions and 173 deletions

View File

@@ -12,6 +12,7 @@ from pydantic import BaseModel
from typing_extensions import Self
from mcp.shared.exceptions import McpError
from mcp.shared.message import SessionMessage
from mcp.types import (
CancelledNotification,
ClientNotification,
@@ -172,8 +173,8 @@ class BaseSession(
def __init__(
self,
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage],
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
write_stream: MemoryObjectSendStream[SessionMessage],
receive_request_type: type[ReceiveRequestT],
receive_notification_type: type[ReceiveNotificationT],
# If none, reading will never time out
@@ -240,7 +241,9 @@ class BaseSession(
# TODO: Support progress callbacks
await self._write_stream.send(JSONRPCMessage(jsonrpc_request))
await self._write_stream.send(
SessionMessage(message=JSONRPCMessage(jsonrpc_request))
)
# request read timeout takes precedence over session read timeout
timeout = None
@@ -300,14 +303,16 @@ class BaseSession(
jsonrpc="2.0",
**notification.model_dump(by_alias=True, mode="json", exclude_none=True),
)
await self._write_stream.send(JSONRPCMessage(jsonrpc_notification))
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_notification))
await self._write_stream.send(session_message)
async def _send_response(
self, request_id: RequestId, response: SendResultT | ErrorData
) -> None:
if isinstance(response, ErrorData):
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
await self._write_stream.send(JSONRPCMessage(jsonrpc_error))
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
await self._write_stream.send(session_message)
else:
jsonrpc_response = JSONRPCResponse(
jsonrpc="2.0",
@@ -316,7 +321,8 @@ class BaseSession(
by_alias=True, mode="json", exclude_none=True
),
)
await self._write_stream.send(JSONRPCMessage(jsonrpc_response))
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response))
await self._write_stream.send(session_message)
async def _receive_loop(self) -> None:
async with (
@@ -326,15 +332,15 @@ class BaseSession(
async for message in self._read_stream:
if isinstance(message, Exception):
await self._handle_incoming(message)
elif isinstance(message.root, JSONRPCRequest):
elif isinstance(message.message.root, JSONRPCRequest):
validated_request = self._receive_request_type.model_validate(
message.root.model_dump(
message.message.root.model_dump(
by_alias=True, mode="json", exclude_none=True
)
)
responder = RequestResponder(
request_id=message.root.id,
request_id=message.message.root.id,
request_meta=validated_request.root.params.meta
if validated_request.root.params
else None,
@@ -349,10 +355,10 @@ class BaseSession(
if not responder._completed: # type: ignore[reportPrivateUsage]
await self._handle_incoming(responder)
elif isinstance(message.root, JSONRPCNotification):
elif isinstance(message.message.root, JSONRPCNotification):
try:
notification = self._receive_notification_type.model_validate(
message.root.model_dump(
message.message.root.model_dump(
by_alias=True, mode="json", exclude_none=True
)
)
@@ -368,12 +374,12 @@ class BaseSession(
# For other validation errors, log and continue
logging.warning(
f"Failed to validate notification: {e}. "
f"Message was: {message.root}"
f"Message was: {message.message.root}"
)
else: # Response or error
stream = self._response_streams.pop(message.root.id, None)
stream = self._response_streams.pop(message.message.root.id, None)
if stream:
await stream.send(message.root)
await stream.send(message.message.root)
else:
await self._handle_incoming(
RuntimeError(