mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
Fix streamable http sampling (#693)
This commit is contained in:
@@ -31,6 +31,7 @@ def get_claude_config_path() -> Path | None:
|
||||
return path
|
||||
return None
|
||||
|
||||
|
||||
def get_uv_path() -> str:
|
||||
"""Get the full path to the uv executable."""
|
||||
uv_path = shutil.which("uv")
|
||||
@@ -42,6 +43,7 @@ def get_uv_path() -> str:
|
||||
return "uv" # Fall back to just "uv" if not found
|
||||
return uv_path
|
||||
|
||||
|
||||
def update_claude_config(
|
||||
file_spec: str,
|
||||
server_name: str,
|
||||
|
||||
@@ -15,6 +15,7 @@ from typing import Any
|
||||
|
||||
import anyio
|
||||
import httpx
|
||||
from anyio.abc import TaskGroup
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
|
||||
|
||||
@@ -239,7 +240,7 @@ class StreamableHTTPTransport:
|
||||
break
|
||||
|
||||
async def _handle_post_request(self, ctx: RequestContext) -> None:
|
||||
"""Handle a POST request with response processing."""
|
||||
"""Handle a POST request with response processing."""
|
||||
headers = self._update_headers_with_session(ctx.headers)
|
||||
message = ctx.session_message.message
|
||||
is_initialization = self._is_initialization_request(message)
|
||||
@@ -300,7 +301,7 @@ class StreamableHTTPTransport:
|
||||
try:
|
||||
event_source = EventSource(response)
|
||||
async for sse in event_source.aiter_sse():
|
||||
await self._handle_sse_event(
|
||||
is_complete = await self._handle_sse_event(
|
||||
sse,
|
||||
ctx.read_stream_writer,
|
||||
resumption_callback=(
|
||||
@@ -309,6 +310,10 @@ class StreamableHTTPTransport:
|
||||
else None
|
||||
),
|
||||
)
|
||||
# If the SSE event indicates completion, like returning respose/error
|
||||
# break the loop
|
||||
if is_complete:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.exception("Error reading SSE stream:")
|
||||
await ctx.read_stream_writer.send(e)
|
||||
@@ -344,6 +349,7 @@ class StreamableHTTPTransport:
|
||||
read_stream_writer: StreamWriter,
|
||||
write_stream: MemoryObjectSendStream[SessionMessage],
|
||||
start_get_stream: Callable[[], None],
|
||||
tg: TaskGroup,
|
||||
) -> None:
|
||||
"""Handle writing requests to the server."""
|
||||
try:
|
||||
@@ -375,10 +381,17 @@ class StreamableHTTPTransport:
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
)
|
||||
|
||||
if is_resumption:
|
||||
await self._handle_resumption_request(ctx)
|
||||
async def handle_request_async():
|
||||
if is_resumption:
|
||||
await self._handle_resumption_request(ctx)
|
||||
else:
|
||||
await self._handle_post_request(ctx)
|
||||
|
||||
# If this is a request, start a new task to handle it
|
||||
if isinstance(message.root, JSONRPCRequest):
|
||||
tg.start_soon(handle_request_async)
|
||||
else:
|
||||
await self._handle_post_request(ctx)
|
||||
await handle_request_async()
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"Error in post_writer: {exc}")
|
||||
@@ -466,6 +479,7 @@ async def streamablehttp_client(
|
||||
read_stream_writer,
|
||||
write_stream,
|
||||
start_get_stream,
|
||||
tg,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -47,7 +47,7 @@ from pydantic import AnyUrl
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.server.models import InitializationOptions
|
||||
from mcp.shared.message import SessionMessage
|
||||
from mcp.shared.message import ServerMessageMetadata, SessionMessage
|
||||
from mcp.shared.session import (
|
||||
BaseSession,
|
||||
RequestResponder,
|
||||
@@ -230,10 +230,11 @@ class ServerSession(
|
||||
stop_sequences: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
model_preferences: types.ModelPreferences | None = None,
|
||||
related_request_id: types.RequestId | None = None,
|
||||
) -> types.CreateMessageResult:
|
||||
"""Send a sampling/create_message request."""
|
||||
return await self.send_request(
|
||||
types.ServerRequest(
|
||||
request=types.ServerRequest(
|
||||
types.CreateMessageRequest(
|
||||
method="sampling/createMessage",
|
||||
params=types.CreateMessageRequestParams(
|
||||
@@ -248,7 +249,10 @@ class ServerSession(
|
||||
),
|
||||
)
|
||||
),
|
||||
types.CreateMessageResult,
|
||||
result_type=types.CreateMessageResult,
|
||||
metadata=ServerMessageMetadata(
|
||||
related_request_id=related_request_id,
|
||||
),
|
||||
)
|
||||
|
||||
async def list_roots(self) -> types.ListRootsResult:
|
||||
|
||||
@@ -33,7 +33,6 @@ from mcp.types import (
|
||||
ErrorData,
|
||||
JSONRPCError,
|
||||
JSONRPCMessage,
|
||||
JSONRPCNotification,
|
||||
JSONRPCRequest,
|
||||
JSONRPCResponse,
|
||||
RequestId,
|
||||
@@ -849,9 +848,15 @@ class StreamableHTTPServerTransport:
|
||||
# Determine which request stream(s) should receive this message
|
||||
message = session_message.message
|
||||
target_request_id = None
|
||||
if isinstance(
|
||||
message.root, JSONRPCNotification | JSONRPCRequest
|
||||
):
|
||||
# Check if this is a response
|
||||
if isinstance(message.root, JSONRPCResponse | JSONRPCError):
|
||||
response_id = str(message.root.id)
|
||||
# If this response is for an existing request stream,
|
||||
# send it there
|
||||
if response_id in self._request_streams:
|
||||
target_request_id = response_id
|
||||
|
||||
else:
|
||||
# Extract related_request_id from meta if it exists
|
||||
if (
|
||||
session_message.metadata is not None
|
||||
@@ -865,10 +870,12 @@ class StreamableHTTPServerTransport:
|
||||
target_request_id = str(
|
||||
session_message.metadata.related_request_id
|
||||
)
|
||||
else:
|
||||
target_request_id = str(message.root.id)
|
||||
|
||||
request_stream_id = target_request_id or GET_STREAM_KEY
|
||||
request_stream_id = (
|
||||
target_request_id
|
||||
if target_request_id is not None
|
||||
else GET_STREAM_KEY
|
||||
)
|
||||
|
||||
# Store the event if we have an event store,
|
||||
# regardless of whether a client is connected
|
||||
|
||||
@@ -223,7 +223,6 @@ class BaseSession(
|
||||
Do not use this method to emit notifications! Use send_notification()
|
||||
instead.
|
||||
"""
|
||||
|
||||
request_id = self._request_id
|
||||
self._request_id = request_id + 1
|
||||
|
||||
|
||||
Reference in New Issue
Block a user