mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
Wrap JSONRPC messages with SessionMessage for metadata support (#590)
This commit is contained in:
@@ -12,6 +12,7 @@ from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from mcp.shared.exceptions import McpError
|
||||
from mcp.shared.message import SessionMessage
|
||||
from mcp.types import (
|
||||
CancelledNotification,
|
||||
ClientNotification,
|
||||
@@ -172,8 +173,8 @@ class BaseSession(
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[JSONRPCMessage],
|
||||
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[SessionMessage],
|
||||
receive_request_type: type[ReceiveRequestT],
|
||||
receive_notification_type: type[ReceiveNotificationT],
|
||||
# If none, reading will never time out
|
||||
@@ -240,7 +241,9 @@ class BaseSession(
|
||||
|
||||
# TODO: Support progress callbacks
|
||||
|
||||
await self._write_stream.send(JSONRPCMessage(jsonrpc_request))
|
||||
await self._write_stream.send(
|
||||
SessionMessage(message=JSONRPCMessage(jsonrpc_request))
|
||||
)
|
||||
|
||||
# request read timeout takes precedence over session read timeout
|
||||
timeout = None
|
||||
@@ -300,14 +303,16 @@ class BaseSession(
|
||||
jsonrpc="2.0",
|
||||
**notification.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
await self._write_stream.send(JSONRPCMessage(jsonrpc_notification))
|
||||
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_notification))
|
||||
await self._write_stream.send(session_message)
|
||||
|
||||
async def _send_response(
|
||||
self, request_id: RequestId, response: SendResultT | ErrorData
|
||||
) -> None:
|
||||
if isinstance(response, ErrorData):
|
||||
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
|
||||
await self._write_stream.send(JSONRPCMessage(jsonrpc_error))
|
||||
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
|
||||
await self._write_stream.send(session_message)
|
||||
else:
|
||||
jsonrpc_response = JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
@@ -316,7 +321,8 @@ class BaseSession(
|
||||
by_alias=True, mode="json", exclude_none=True
|
||||
),
|
||||
)
|
||||
await self._write_stream.send(JSONRPCMessage(jsonrpc_response))
|
||||
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response))
|
||||
await self._write_stream.send(session_message)
|
||||
|
||||
async def _receive_loop(self) -> None:
|
||||
async with (
|
||||
@@ -326,15 +332,15 @@ class BaseSession(
|
||||
async for message in self._read_stream:
|
||||
if isinstance(message, Exception):
|
||||
await self._handle_incoming(message)
|
||||
elif isinstance(message.root, JSONRPCRequest):
|
||||
elif isinstance(message.message.root, JSONRPCRequest):
|
||||
validated_request = self._receive_request_type.model_validate(
|
||||
message.root.model_dump(
|
||||
message.message.root.model_dump(
|
||||
by_alias=True, mode="json", exclude_none=True
|
||||
)
|
||||
)
|
||||
|
||||
responder = RequestResponder(
|
||||
request_id=message.root.id,
|
||||
request_id=message.message.root.id,
|
||||
request_meta=validated_request.root.params.meta
|
||||
if validated_request.root.params
|
||||
else None,
|
||||
@@ -349,10 +355,10 @@ class BaseSession(
|
||||
if not responder._completed: # type: ignore[reportPrivateUsage]
|
||||
await self._handle_incoming(responder)
|
||||
|
||||
elif isinstance(message.root, JSONRPCNotification):
|
||||
elif isinstance(message.message.root, JSONRPCNotification):
|
||||
try:
|
||||
notification = self._receive_notification_type.model_validate(
|
||||
message.root.model_dump(
|
||||
message.message.root.model_dump(
|
||||
by_alias=True, mode="json", exclude_none=True
|
||||
)
|
||||
)
|
||||
@@ -368,12 +374,12 @@ class BaseSession(
|
||||
# For other validation errors, log and continue
|
||||
logging.warning(
|
||||
f"Failed to validate notification: {e}. "
|
||||
f"Message was: {message.root}"
|
||||
f"Message was: {message.message.root}"
|
||||
)
|
||||
else: # Response or error
|
||||
stream = self._response_streams.pop(message.root.id, None)
|
||||
stream = self._response_streams.pop(message.message.root.id, None)
|
||||
if stream:
|
||||
await stream.send(message.root)
|
||||
await stream.send(message.message.root)
|
||||
else:
|
||||
await self._handle_incoming(
|
||||
RuntimeError(
|
||||
|
||||
Reference in New Issue
Block a user