Fix streamable http sampling (#693)

This commit is contained in:
ihrpr
2025-05-12 18:31:35 +01:00
committed by GitHub
parent ed25167fa5
commit c6fb822c86
7 changed files with 152 additions and 23 deletions

View File

@@ -47,7 +47,7 @@ from pydantic import AnyUrl
import mcp.types as types
from mcp.server.models import InitializationOptions
from mcp.shared.message import SessionMessage
from mcp.shared.message import ServerMessageMetadata, SessionMessage
from mcp.shared.session import (
BaseSession,
RequestResponder,
@@ -230,10 +230,11 @@ class ServerSession(
stop_sequences: list[str] | None = None,
metadata: dict[str, Any] | None = None,
model_preferences: types.ModelPreferences | None = None,
related_request_id: types.RequestId | None = None,
) -> types.CreateMessageResult:
"""Send a sampling/create_message request."""
return await self.send_request(
types.ServerRequest(
request=types.ServerRequest(
types.CreateMessageRequest(
method="sampling/createMessage",
params=types.CreateMessageRequestParams(
@@ -248,7 +249,10 @@ class ServerSession(
),
)
),
types.CreateMessageResult,
result_type=types.CreateMessageResult,
metadata=ServerMessageMetadata(
related_request_id=related_request_id,
),
)
async def list_roots(self) -> types.ListRootsResult:

View File

@@ -33,7 +33,6 @@ from mcp.types import (
ErrorData,
JSONRPCError,
JSONRPCMessage,
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
RequestId,
@@ -849,9 +848,15 @@ class StreamableHTTPServerTransport:
# Determine which request stream(s) should receive this message
message = session_message.message
target_request_id = None
if isinstance(
message.root, JSONRPCNotification | JSONRPCRequest
):
# Check if this is a response
if isinstance(message.root, JSONRPCResponse | JSONRPCError):
response_id = str(message.root.id)
# If this response is for an existing request stream,
# send it there
if response_id in self._request_streams:
target_request_id = response_id
else:
# Extract related_request_id from meta if it exists
if (
session_message.metadata is not None
@@ -865,10 +870,12 @@ class StreamableHTTPServerTransport:
target_request_id = str(
session_message.metadata.related_request_id
)
else:
target_request_id = str(message.root.id)
request_stream_id = target_request_id or GET_STREAM_KEY
request_stream_id = (
target_request_id
if target_request_id is not None
else GET_STREAM_KEY
)
# Store the event if we have an event store,
# regardless of whether a client is connected