mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-18 22:44:20 +01:00
use metadata from SessionMessage to propagate related_request_id (#591)
This commit is contained in:
@@ -24,7 +24,7 @@ from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
from mcp.shared.message import SessionMessage
|
||||
from mcp.shared.message import ServerMessageMetadata, SessionMessage
|
||||
from mcp.types import (
|
||||
INTERNAL_ERROR,
|
||||
INVALID_PARAMS,
|
||||
@@ -520,7 +520,7 @@ class StreamableHTTPServerTransport:
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
if writer:
|
||||
await writer.send(err)
|
||||
await writer.send(Exception(err))
|
||||
return
|
||||
|
||||
async def _handle_get_request(self, request: Request, send: Send) -> None:
|
||||
@@ -834,12 +834,17 @@ class StreamableHTTPServerTransport:
|
||||
):
|
||||
# Extract related_request_id from meta if it exists
|
||||
if (
|
||||
(params := getattr(message.root, "params", None))
|
||||
and (meta := params.get("_meta"))
|
||||
and (related_id := meta.get("related_request_id"))
|
||||
session_message.metadata is not None
|
||||
and isinstance(
|
||||
session_message.metadata,
|
||||
ServerMessageMetadata,
|
||||
)
|
||||
and session_message.metadata.related_request_id
|
||||
is not None
|
||||
):
|
||||
target_request_id = str(related_id)
|
||||
target_request_id = str(
|
||||
session_message.metadata.related_request_id
|
||||
)
|
||||
else:
|
||||
target_request_id = str(message.root.id)
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from mcp.shared.exceptions import McpError
|
||||
from mcp.shared.message import SessionMessage
|
||||
from mcp.shared.message import ServerMessageMetadata, SessionMessage
|
||||
from mcp.types import (
|
||||
CancelledNotification,
|
||||
ClientNotification,
|
||||
@@ -24,7 +24,6 @@ from mcp.types import (
|
||||
JSONRPCNotification,
|
||||
JSONRPCRequest,
|
||||
JSONRPCResponse,
|
||||
NotificationParams,
|
||||
RequestParams,
|
||||
ServerNotification,
|
||||
ServerRequest,
|
||||
@@ -288,22 +287,16 @@ class BaseSession(
|
||||
"""
|
||||
# Some transport implementations may need to set the related_request_id
|
||||
# to attribute to the notifications to the request that triggered them.
|
||||
if related_request_id is not None and notification.root.params is not None:
|
||||
# Create meta if it doesn't exist
|
||||
if notification.root.params.meta is None:
|
||||
meta_dict = {"related_request_id": related_request_id}
|
||||
|
||||
else:
|
||||
meta_dict = notification.root.params.meta.model_dump(
|
||||
by_alias=True, mode="json", exclude_none=True
|
||||
)
|
||||
meta_dict["related_request_id"] = related_request_id
|
||||
notification.root.params.meta = NotificationParams.Meta(**meta_dict)
|
||||
jsonrpc_notification = JSONRPCNotification(
|
||||
jsonrpc="2.0",
|
||||
**notification.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_notification))
|
||||
session_message = SessionMessage(
|
||||
message=JSONRPCMessage(jsonrpc_notification),
|
||||
metadata=ServerMessageMetadata(related_request_id=related_request_id)
|
||||
if related_request_id
|
||||
else None,
|
||||
)
|
||||
await self._write_stream.send(session_message)
|
||||
|
||||
async def _send_response(
|
||||
|
||||
@@ -9,7 +9,6 @@ from mcp.shared.memory import (
|
||||
from mcp.shared.session import RequestResponder
|
||||
from mcp.types import (
|
||||
LoggingMessageNotificationParams,
|
||||
NotificationParams,
|
||||
TextContent,
|
||||
)
|
||||
|
||||
@@ -80,10 +79,7 @@ async def test_logging_callback():
|
||||
assert log_result.isError is False
|
||||
assert len(logging_collector.log_messages) == 1
|
||||
# Create meta object with related_request_id added dynamically
|
||||
meta = NotificationParams.Meta()
|
||||
setattr(meta, "related_request_id", "2")
|
||||
log = logging_collector.log_messages[0]
|
||||
assert log.level == "info"
|
||||
assert log.logger == "test_logger"
|
||||
assert log.data == "Test log message"
|
||||
assert log.meta == meta
|
||||
|
||||
Reference in New Issue
Block a user