mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2026-01-08 16:34:19 +01:00
Fix streamable http sampling (#693)
This commit is contained in:
@@ -15,6 +15,7 @@ from typing import Any
|
||||
|
||||
import anyio
|
||||
import httpx
|
||||
from anyio.abc import TaskGroup
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
|
||||
|
||||
@@ -239,7 +240,7 @@ class StreamableHTTPTransport:
|
||||
break
|
||||
|
||||
async def _handle_post_request(self, ctx: RequestContext) -> None:
|
||||
"""Handle a POST request with response processing."""
|
||||
"""Handle a POST request with response processing."""
|
||||
headers = self._update_headers_with_session(ctx.headers)
|
||||
message = ctx.session_message.message
|
||||
is_initialization = self._is_initialization_request(message)
|
||||
@@ -300,7 +301,7 @@ class StreamableHTTPTransport:
|
||||
try:
|
||||
event_source = EventSource(response)
|
||||
async for sse in event_source.aiter_sse():
|
||||
await self._handle_sse_event(
|
||||
is_complete = await self._handle_sse_event(
|
||||
sse,
|
||||
ctx.read_stream_writer,
|
||||
resumption_callback=(
|
||||
@@ -309,6 +310,10 @@ class StreamableHTTPTransport:
|
||||
else None
|
||||
),
|
||||
)
|
||||
# If the SSE event indicates completion, like returning respose/error
|
||||
# break the loop
|
||||
if is_complete:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.exception("Error reading SSE stream:")
|
||||
await ctx.read_stream_writer.send(e)
|
||||
@@ -344,6 +349,7 @@ class StreamableHTTPTransport:
|
||||
read_stream_writer: StreamWriter,
|
||||
write_stream: MemoryObjectSendStream[SessionMessage],
|
||||
start_get_stream: Callable[[], None],
|
||||
tg: TaskGroup,
|
||||
) -> None:
|
||||
"""Handle writing requests to the server."""
|
||||
try:
|
||||
@@ -375,10 +381,17 @@ class StreamableHTTPTransport:
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
)
|
||||
|
||||
if is_resumption:
|
||||
await self._handle_resumption_request(ctx)
|
||||
async def handle_request_async():
|
||||
if is_resumption:
|
||||
await self._handle_resumption_request(ctx)
|
||||
else:
|
||||
await self._handle_post_request(ctx)
|
||||
|
||||
# If this is a request, start a new task to handle it
|
||||
if isinstance(message.root, JSONRPCRequest):
|
||||
tg.start_soon(handle_request_async)
|
||||
else:
|
||||
await self._handle_post_request(ctx)
|
||||
await handle_request_async()
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"Error in post_writer: {exc}")
|
||||
@@ -466,6 +479,7 @@ async def streamablehttp_client(
|
||||
read_stream_writer,
|
||||
write_stream,
|
||||
start_get_stream,
|
||||
tg,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user