mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 06:54:18 +01:00
Fix streamable http sampling (#693)
This commit is contained in:
@@ -47,7 +47,7 @@ from pydantic import AnyUrl
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.server.models import InitializationOptions
|
||||
from mcp.shared.message import SessionMessage
|
||||
from mcp.shared.message import ServerMessageMetadata, SessionMessage
|
||||
from mcp.shared.session import (
|
||||
BaseSession,
|
||||
RequestResponder,
|
||||
@@ -230,10 +230,11 @@ class ServerSession(
|
||||
stop_sequences: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
model_preferences: types.ModelPreferences | None = None,
|
||||
related_request_id: types.RequestId | None = None,
|
||||
) -> types.CreateMessageResult:
|
||||
"""Send a sampling/create_message request."""
|
||||
return await self.send_request(
|
||||
types.ServerRequest(
|
||||
request=types.ServerRequest(
|
||||
types.CreateMessageRequest(
|
||||
method="sampling/createMessage",
|
||||
params=types.CreateMessageRequestParams(
|
||||
@@ -248,7 +249,10 @@ class ServerSession(
|
||||
),
|
||||
)
|
||||
),
|
||||
types.CreateMessageResult,
|
||||
result_type=types.CreateMessageResult,
|
||||
metadata=ServerMessageMetadata(
|
||||
related_request_id=related_request_id,
|
||||
),
|
||||
)
|
||||
|
||||
async def list_roots(self) -> types.ListRootsResult:
|
||||
|
||||
@@ -33,7 +33,6 @@ from mcp.types import (
|
||||
ErrorData,
|
||||
JSONRPCError,
|
||||
JSONRPCMessage,
|
||||
JSONRPCNotification,
|
||||
JSONRPCRequest,
|
||||
JSONRPCResponse,
|
||||
RequestId,
|
||||
@@ -849,9 +848,15 @@ class StreamableHTTPServerTransport:
|
||||
# Determine which request stream(s) should receive this message
|
||||
message = session_message.message
|
||||
target_request_id = None
|
||||
if isinstance(
|
||||
message.root, JSONRPCNotification | JSONRPCRequest
|
||||
):
|
||||
# Check if this is a response
|
||||
if isinstance(message.root, JSONRPCResponse | JSONRPCError):
|
||||
response_id = str(message.root.id)
|
||||
# If this response is for an existing request stream,
|
||||
# send it there
|
||||
if response_id in self._request_streams:
|
||||
target_request_id = response_id
|
||||
|
||||
else:
|
||||
# Extract related_request_id from meta if it exists
|
||||
if (
|
||||
session_message.metadata is not None
|
||||
@@ -865,10 +870,12 @@ class StreamableHTTPServerTransport:
|
||||
target_request_id = str(
|
||||
session_message.metadata.related_request_id
|
||||
)
|
||||
else:
|
||||
target_request_id = str(message.root.id)
|
||||
|
||||
request_stream_id = target_request_id or GET_STREAM_KEY
|
||||
request_stream_id = (
|
||||
target_request_id
|
||||
if target_request_id is not None
|
||||
else GET_STREAM_KEY
|
||||
)
|
||||
|
||||
# Store the event if we have an event store,
|
||||
# regardless of whether a client is connected
|
||||
|
||||
Reference in New Issue
Block a user