mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 06:54:18 +01:00
use metadata from SessionMessage to propagate related_request_id (#591)
This commit is contained in:
@@ -24,7 +24,7 @@ from starlette.requests import Request
|
|||||||
from starlette.responses import Response
|
from starlette.responses import Response
|
||||||
from starlette.types import Receive, Scope, Send
|
from starlette.types import Receive, Scope, Send
|
||||||
|
|
||||||
from mcp.shared.message import SessionMessage
|
from mcp.shared.message import ServerMessageMetadata, SessionMessage
|
||||||
from mcp.types import (
|
from mcp.types import (
|
||||||
INTERNAL_ERROR,
|
INTERNAL_ERROR,
|
||||||
INVALID_PARAMS,
|
INVALID_PARAMS,
|
||||||
@@ -520,7 +520,7 @@ class StreamableHTTPServerTransport:
|
|||||||
)
|
)
|
||||||
await response(scope, receive, send)
|
await response(scope, receive, send)
|
||||||
if writer:
|
if writer:
|
||||||
await writer.send(err)
|
await writer.send(Exception(err))
|
||||||
return
|
return
|
||||||
|
|
||||||
async def _handle_get_request(self, request: Request, send: Send) -> None:
|
async def _handle_get_request(self, request: Request, send: Send) -> None:
|
||||||
@@ -834,12 +834,17 @@ class StreamableHTTPServerTransport:
|
|||||||
):
|
):
|
||||||
# Extract related_request_id from meta if it exists
|
# Extract related_request_id from meta if it exists
|
||||||
if (
|
if (
|
||||||
(params := getattr(message.root, "params", None))
|
session_message.metadata is not None
|
||||||
and (meta := params.get("_meta"))
|
and isinstance(
|
||||||
and (related_id := meta.get("related_request_id"))
|
session_message.metadata,
|
||||||
|
ServerMessageMetadata,
|
||||||
|
)
|
||||||
|
and session_message.metadata.related_request_id
|
||||||
is not None
|
is not None
|
||||||
):
|
):
|
||||||
target_request_id = str(related_id)
|
target_request_id = str(
|
||||||
|
session_message.metadata.related_request_id
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
target_request_id = str(message.root.id)
|
target_request_id = str(message.root.id)
|
||||||
|
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from pydantic import BaseModel
|
|||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from mcp.shared.exceptions import McpError
|
from mcp.shared.exceptions import McpError
|
||||||
from mcp.shared.message import SessionMessage
|
from mcp.shared.message import ServerMessageMetadata, SessionMessage
|
||||||
from mcp.types import (
|
from mcp.types import (
|
||||||
CancelledNotification,
|
CancelledNotification,
|
||||||
ClientNotification,
|
ClientNotification,
|
||||||
@@ -24,7 +24,6 @@ from mcp.types import (
|
|||||||
JSONRPCNotification,
|
JSONRPCNotification,
|
||||||
JSONRPCRequest,
|
JSONRPCRequest,
|
||||||
JSONRPCResponse,
|
JSONRPCResponse,
|
||||||
NotificationParams,
|
|
||||||
RequestParams,
|
RequestParams,
|
||||||
ServerNotification,
|
ServerNotification,
|
||||||
ServerRequest,
|
ServerRequest,
|
||||||
@@ -288,22 +287,16 @@ class BaseSession(
|
|||||||
"""
|
"""
|
||||||
# Some transport implementations may need to set the related_request_id
|
# Some transport implementations may need to set the related_request_id
|
||||||
# to attribute to the notifications to the request that triggered them.
|
# to attribute to the notifications to the request that triggered them.
|
||||||
if related_request_id is not None and notification.root.params is not None:
|
|
||||||
# Create meta if it doesn't exist
|
|
||||||
if notification.root.params.meta is None:
|
|
||||||
meta_dict = {"related_request_id": related_request_id}
|
|
||||||
|
|
||||||
else:
|
|
||||||
meta_dict = notification.root.params.meta.model_dump(
|
|
||||||
by_alias=True, mode="json", exclude_none=True
|
|
||||||
)
|
|
||||||
meta_dict["related_request_id"] = related_request_id
|
|
||||||
notification.root.params.meta = NotificationParams.Meta(**meta_dict)
|
|
||||||
jsonrpc_notification = JSONRPCNotification(
|
jsonrpc_notification = JSONRPCNotification(
|
||||||
jsonrpc="2.0",
|
jsonrpc="2.0",
|
||||||
**notification.model_dump(by_alias=True, mode="json", exclude_none=True),
|
**notification.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||||
)
|
)
|
||||||
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_notification))
|
session_message = SessionMessage(
|
||||||
|
message=JSONRPCMessage(jsonrpc_notification),
|
||||||
|
metadata=ServerMessageMetadata(related_request_id=related_request_id)
|
||||||
|
if related_request_id
|
||||||
|
else None,
|
||||||
|
)
|
||||||
await self._write_stream.send(session_message)
|
await self._write_stream.send(session_message)
|
||||||
|
|
||||||
async def _send_response(
|
async def _send_response(
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from mcp.shared.memory import (
|
|||||||
from mcp.shared.session import RequestResponder
|
from mcp.shared.session import RequestResponder
|
||||||
from mcp.types import (
|
from mcp.types import (
|
||||||
LoggingMessageNotificationParams,
|
LoggingMessageNotificationParams,
|
||||||
NotificationParams,
|
|
||||||
TextContent,
|
TextContent,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -80,10 +79,7 @@ async def test_logging_callback():
|
|||||||
assert log_result.isError is False
|
assert log_result.isError is False
|
||||||
assert len(logging_collector.log_messages) == 1
|
assert len(logging_collector.log_messages) == 1
|
||||||
# Create meta object with related_request_id added dynamically
|
# Create meta object with related_request_id added dynamically
|
||||||
meta = NotificationParams.Meta()
|
|
||||||
setattr(meta, "related_request_id", "2")
|
|
||||||
log = logging_collector.log_messages[0]
|
log = logging_collector.log_messages[0]
|
||||||
assert log.level == "info"
|
assert log.level == "info"
|
||||||
assert log.logger == "test_logger"
|
assert log.logger == "test_logger"
|
||||||
assert log.data == "Test log message"
|
assert log.data == "Test log message"
|
||||||
assert log.meta == meta
|
|
||||||
|
|||||||
Reference in New Issue
Block a user