use metadata from SessionMessage to propagate related_request_id (#591)

This commit is contained in:
ihrpr
2025-05-02 14:35:17 +01:00
committed by GitHub
parent da0cf22355
commit cf8b66b82f
3 changed files with 18 additions and 24 deletions

View File

@@ -24,7 +24,7 @@ from starlette.requests import Request
from starlette.responses import Response from starlette.responses import Response
from starlette.types import Receive, Scope, Send from starlette.types import Receive, Scope, Send
from mcp.shared.message import SessionMessage from mcp.shared.message import ServerMessageMetadata, SessionMessage
from mcp.types import ( from mcp.types import (
INTERNAL_ERROR, INTERNAL_ERROR,
INVALID_PARAMS, INVALID_PARAMS,
@@ -520,7 +520,7 @@ class StreamableHTTPServerTransport:
) )
await response(scope, receive, send) await response(scope, receive, send)
if writer: if writer:
await writer.send(err) await writer.send(Exception(err))
return return
async def _handle_get_request(self, request: Request, send: Send) -> None: async def _handle_get_request(self, request: Request, send: Send) -> None:
@@ -834,12 +834,17 @@ class StreamableHTTPServerTransport:
): ):
# Extract related_request_id from meta if it exists # Extract related_request_id from meta if it exists
if ( if (
(params := getattr(message.root, "params", None)) session_message.metadata is not None
and (meta := params.get("_meta")) and isinstance(
and (related_id := meta.get("related_request_id")) session_message.metadata,
ServerMessageMetadata,
)
and session_message.metadata.related_request_id
is not None is not None
): ):
target_request_id = str(related_id) target_request_id = str(
session_message.metadata.related_request_id
)
else: else:
target_request_id = str(message.root.id) target_request_id = str(message.root.id)

View File

@@ -12,7 +12,7 @@ from pydantic import BaseModel
from typing_extensions import Self from typing_extensions import Self
from mcp.shared.exceptions import McpError from mcp.shared.exceptions import McpError
from mcp.shared.message import SessionMessage from mcp.shared.message import ServerMessageMetadata, SessionMessage
from mcp.types import ( from mcp.types import (
CancelledNotification, CancelledNotification,
ClientNotification, ClientNotification,
@@ -24,7 +24,6 @@ from mcp.types import (
JSONRPCNotification, JSONRPCNotification,
JSONRPCRequest, JSONRPCRequest,
JSONRPCResponse, JSONRPCResponse,
NotificationParams,
RequestParams, RequestParams,
ServerNotification, ServerNotification,
ServerRequest, ServerRequest,
@@ -288,22 +287,16 @@ class BaseSession(
""" """
# Some transport implementations may need to set the related_request_id # Some transport implementations may need to set the related_request_id
# to attribute to the notifications to the request that triggered them. # to attribute to the notifications to the request that triggered them.
if related_request_id is not None and notification.root.params is not None:
# Create meta if it doesn't exist
if notification.root.params.meta is None:
meta_dict = {"related_request_id": related_request_id}
else:
meta_dict = notification.root.params.meta.model_dump(
by_alias=True, mode="json", exclude_none=True
)
meta_dict["related_request_id"] = related_request_id
notification.root.params.meta = NotificationParams.Meta(**meta_dict)
jsonrpc_notification = JSONRPCNotification( jsonrpc_notification = JSONRPCNotification(
jsonrpc="2.0", jsonrpc="2.0",
**notification.model_dump(by_alias=True, mode="json", exclude_none=True), **notification.model_dump(by_alias=True, mode="json", exclude_none=True),
) )
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_notification)) session_message = SessionMessage(
message=JSONRPCMessage(jsonrpc_notification),
metadata=ServerMessageMetadata(related_request_id=related_request_id)
if related_request_id
else None,
)
await self._write_stream.send(session_message) await self._write_stream.send(session_message)
async def _send_response( async def _send_response(

View File

@@ -9,7 +9,6 @@ from mcp.shared.memory import (
from mcp.shared.session import RequestResponder from mcp.shared.session import RequestResponder
from mcp.types import ( from mcp.types import (
LoggingMessageNotificationParams, LoggingMessageNotificationParams,
NotificationParams,
TextContent, TextContent,
) )
@@ -80,10 +79,7 @@ async def test_logging_callback():
assert log_result.isError is False assert log_result.isError is False
assert len(logging_collector.log_messages) == 1 assert len(logging_collector.log_messages) == 1
# Create meta object with related_request_id added dynamically # Create meta object with related_request_id added dynamically
meta = NotificationParams.Meta()
setattr(meta, "related_request_id", "2")
log = logging_collector.log_messages[0] log = logging_collector.log_messages[0]
assert log.level == "info" assert log.level == "info"
assert log.logger == "test_logger" assert log.logger == "test_logger"
assert log.data == "Test log message" assert log.data == "Test log message"
assert log.meta == meta