Fix streamable http sampling (#693)

This commit is contained in:
ihrpr
2025-05-12 18:31:35 +01:00
committed by GitHub
parent ed25167fa5
commit c6fb822c86
7 changed files with 152 additions and 23 deletions

View File

@@ -31,6 +31,7 @@ def get_claude_config_path() -> Path | None:
return path
return None
def get_uv_path() -> str:
"""Get the full path to the uv executable."""
uv_path = shutil.which("uv")
@@ -42,6 +43,7 @@ def get_uv_path() -> str:
return "uv" # Fall back to just "uv" if not found
return uv_path
def update_claude_config(
file_spec: str,
server_name: str,

View File

@@ -15,6 +15,7 @@ from typing import Any
import anyio
import httpx
from anyio.abc import TaskGroup
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
@@ -239,7 +240,7 @@ class StreamableHTTPTransport:
break
async def _handle_post_request(self, ctx: RequestContext) -> None:
"""Handle a POST request with response processing."""
"""Handle a POST request with response processing."""
headers = self._update_headers_with_session(ctx.headers)
message = ctx.session_message.message
is_initialization = self._is_initialization_request(message)
@@ -300,7 +301,7 @@ class StreamableHTTPTransport:
try:
event_source = EventSource(response)
async for sse in event_source.aiter_sse():
await self._handle_sse_event(
is_complete = await self._handle_sse_event(
sse,
ctx.read_stream_writer,
resumption_callback=(
@@ -309,6 +310,10 @@ class StreamableHTTPTransport:
else None
),
)
# If the SSE event indicates completion, like returning respose/error
# break the loop
if is_complete:
break
except Exception as e:
logger.exception("Error reading SSE stream:")
await ctx.read_stream_writer.send(e)
@@ -344,6 +349,7 @@ class StreamableHTTPTransport:
read_stream_writer: StreamWriter,
write_stream: MemoryObjectSendStream[SessionMessage],
start_get_stream: Callable[[], None],
tg: TaskGroup,
) -> None:
"""Handle writing requests to the server."""
try:
@@ -375,10 +381,17 @@ class StreamableHTTPTransport:
sse_read_timeout=self.sse_read_timeout,
)
if is_resumption:
await self._handle_resumption_request(ctx)
async def handle_request_async():
if is_resumption:
await self._handle_resumption_request(ctx)
else:
await self._handle_post_request(ctx)
# If this is a request, start a new task to handle it
if isinstance(message.root, JSONRPCRequest):
tg.start_soon(handle_request_async)
else:
await self._handle_post_request(ctx)
await handle_request_async()
except Exception as exc:
logger.error(f"Error in post_writer: {exc}")
@@ -466,6 +479,7 @@ async def streamablehttp_client(
read_stream_writer,
write_stream,
start_get_stream,
tg,
)
try:

View File

@@ -47,7 +47,7 @@ from pydantic import AnyUrl
import mcp.types as types
from mcp.server.models import InitializationOptions
from mcp.shared.message import SessionMessage
from mcp.shared.message import ServerMessageMetadata, SessionMessage
from mcp.shared.session import (
BaseSession,
RequestResponder,
@@ -230,10 +230,11 @@ class ServerSession(
stop_sequences: list[str] | None = None,
metadata: dict[str, Any] | None = None,
model_preferences: types.ModelPreferences | None = None,
related_request_id: types.RequestId | None = None,
) -> types.CreateMessageResult:
"""Send a sampling/create_message request."""
return await self.send_request(
types.ServerRequest(
request=types.ServerRequest(
types.CreateMessageRequest(
method="sampling/createMessage",
params=types.CreateMessageRequestParams(
@@ -248,7 +249,10 @@ class ServerSession(
),
)
),
types.CreateMessageResult,
result_type=types.CreateMessageResult,
metadata=ServerMessageMetadata(
related_request_id=related_request_id,
),
)
async def list_roots(self) -> types.ListRootsResult:

View File

@@ -33,7 +33,6 @@ from mcp.types import (
ErrorData,
JSONRPCError,
JSONRPCMessage,
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
RequestId,
@@ -849,9 +848,15 @@ class StreamableHTTPServerTransport:
# Determine which request stream(s) should receive this message
message = session_message.message
target_request_id = None
if isinstance(
message.root, JSONRPCNotification | JSONRPCRequest
):
# Check if this is a response
if isinstance(message.root, JSONRPCResponse | JSONRPCError):
response_id = str(message.root.id)
# If this response is for an existing request stream,
# send it there
if response_id in self._request_streams:
target_request_id = response_id
else:
# Extract related_request_id from meta if it exists
if (
session_message.metadata is not None
@@ -865,10 +870,12 @@ class StreamableHTTPServerTransport:
target_request_id = str(
session_message.metadata.related_request_id
)
else:
target_request_id = str(message.root.id)
request_stream_id = target_request_id or GET_STREAM_KEY
request_stream_id = (
target_request_id
if target_request_id is not None
else GET_STREAM_KEY
)
# Store the event if we have an event store,
# regardless of whether a client is connected

View File

@@ -223,7 +223,6 @@ class BaseSession(
Do not use this method to emit notifications! Use send_notification()
instead.
"""
request_id = self._request_id
self._request_id = request_id + 1