Fix streamable http sampling (#693)

This commit is contained in:
ihrpr
2025-05-12 18:31:35 +01:00
committed by GitHub
parent ed25167fa5
commit c6fb822c86
7 changed files with 152 additions and 23 deletions

View File

@@ -15,6 +15,7 @@ from typing import Any
import anyio
import httpx
from anyio.abc import TaskGroup
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
@@ -239,7 +240,7 @@ class StreamableHTTPTransport:
break
async def _handle_post_request(self, ctx: RequestContext) -> None:
"""Handle a POST request with response processing."""
"""Handle a POST request with response processing."""
headers = self._update_headers_with_session(ctx.headers)
message = ctx.session_message.message
is_initialization = self._is_initialization_request(message)
@@ -300,7 +301,7 @@ class StreamableHTTPTransport:
try:
event_source = EventSource(response)
async for sse in event_source.aiter_sse():
await self._handle_sse_event(
is_complete = await self._handle_sse_event(
sse,
ctx.read_stream_writer,
resumption_callback=(
@@ -309,6 +310,10 @@ class StreamableHTTPTransport:
else None
),
)
# If the SSE event indicates completion, like returning respose/error
# break the loop
if is_complete:
break
except Exception as e:
logger.exception("Error reading SSE stream:")
await ctx.read_stream_writer.send(e)
@@ -344,6 +349,7 @@ class StreamableHTTPTransport:
read_stream_writer: StreamWriter,
write_stream: MemoryObjectSendStream[SessionMessage],
start_get_stream: Callable[[], None],
tg: TaskGroup,
) -> None:
"""Handle writing requests to the server."""
try:
@@ -375,10 +381,17 @@ class StreamableHTTPTransport:
sse_read_timeout=self.sse_read_timeout,
)
if is_resumption:
await self._handle_resumption_request(ctx)
async def handle_request_async():
if is_resumption:
await self._handle_resumption_request(ctx)
else:
await self._handle_post_request(ctx)
# If this is a request, start a new task to handle it
if isinstance(message.root, JSONRPCRequest):
tg.start_soon(handle_request_async)
else:
await self._handle_post_request(ctx)
await handle_request_async()
except Exception as exc:
logger.error(f"Error in post_writer: {exc}")
@@ -466,6 +479,7 @@ async def streamablehttp_client(
read_stream_writer,
write_stream,
start_get_stream,
tg,
)
try: