Allow to pass timeout as float (#941)

This commit is contained in:
Marcelo Trylesinski
2025-06-12 09:31:31 +02:00
committed by GitHub
parent d69b290b65
commit 20dc0fbabb

View File

@@ -11,7 +11,6 @@ from collections.abc import AsyncGenerator, Awaitable, Callable
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from dataclasses import dataclass from dataclasses import dataclass
from datetime import timedelta from datetime import timedelta
from typing import Any
import anyio import anyio
import httpx import httpx
@@ -52,14 +51,10 @@ SSE = "text/event-stream"
class StreamableHTTPError(Exception): class StreamableHTTPError(Exception):
"""Base exception for StreamableHTTP transport errors.""" """Base exception for StreamableHTTP transport errors."""
pass
class ResumptionError(StreamableHTTPError): class ResumptionError(StreamableHTTPError):
"""Raised when resumption request is invalid.""" """Raised when resumption request is invalid."""
pass
@dataclass @dataclass
class RequestContext: class RequestContext:
@@ -71,7 +66,7 @@ class RequestContext:
session_message: SessionMessage session_message: SessionMessage
metadata: ClientMessageMetadata | None metadata: ClientMessageMetadata | None
read_stream_writer: StreamWriter read_stream_writer: StreamWriter
sse_read_timeout: timedelta sse_read_timeout: float
class StreamableHTTPTransport: class StreamableHTTPTransport:
@@ -80,9 +75,9 @@ class StreamableHTTPTransport:
def __init__( def __init__(
self, self,
url: str, url: str,
headers: dict[str, Any] | None = None, headers: dict[str, str] | None = None,
timeout: timedelta = timedelta(seconds=30), timeout: float | timedelta = 30,
sse_read_timeout: timedelta = timedelta(seconds=60 * 5), sse_read_timeout: float | timedelta = 60 * 5,
auth: httpx.Auth | None = None, auth: httpx.Auth | None = None,
) -> None: ) -> None:
"""Initialize the StreamableHTTP transport. """Initialize the StreamableHTTP transport.
@@ -96,10 +91,12 @@ class StreamableHTTPTransport:
""" """
self.url = url self.url = url
self.headers = headers or {} self.headers = headers or {}
self.timeout = timeout self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout
self.sse_read_timeout = sse_read_timeout self.sse_read_timeout = (
sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout
)
self.auth = auth self.auth = auth
self.session_id: str | None = None self.session_id = None
self.request_headers = { self.request_headers = {
ACCEPT: f"{JSON}, {SSE}", ACCEPT: f"{JSON}, {SSE}",
CONTENT_TYPE: JSON, CONTENT_TYPE: JSON,
@@ -160,7 +157,7 @@ class StreamableHTTPTransport:
return isinstance(message.root, JSONRPCResponse | JSONRPCError) return isinstance(message.root, JSONRPCResponse | JSONRPCError)
except Exception as exc: except Exception as exc:
logger.error(f"Error parsing SSE message: {exc}") logger.exception("Error parsing SSE message")
await read_stream_writer.send(exc) await read_stream_writer.send(exc)
return False return False
else: else:
@@ -184,10 +181,7 @@ class StreamableHTTPTransport:
"GET", "GET",
self.url, self.url,
headers=headers, headers=headers,
timeout=httpx.Timeout( timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
self.timeout.total_seconds(),
read=self.sse_read_timeout.total_seconds(),
),
) as event_source: ) as event_source:
event_source.response.raise_for_status() event_source.response.raise_for_status()
logger.debug("GET SSE connection established") logger.debug("GET SSE connection established")
@@ -216,10 +210,7 @@ class StreamableHTTPTransport:
"GET", "GET",
self.url, self.url,
headers=headers, headers=headers,
timeout=httpx.Timeout( timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
self.timeout.total_seconds(),
read=ctx.sse_read_timeout.total_seconds(),
),
) as event_source: ) as event_source:
event_source.response.raise_for_status() event_source.response.raise_for_status()
logger.debug("Resumption GET SSE connection established") logger.debug("Resumption GET SSE connection established")
@@ -412,9 +403,9 @@ class StreamableHTTPTransport:
@asynccontextmanager @asynccontextmanager
async def streamablehttp_client( async def streamablehttp_client(
url: str, url: str,
headers: dict[str, Any] | None = None, headers: dict[str, str] | None = None,
timeout: timedelta = timedelta(seconds=30), timeout: float | timedelta = 30,
sse_read_timeout: timedelta = timedelta(seconds=60 * 5), sse_read_timeout: float | timedelta = 60 * 5,
terminate_on_close: bool = True, terminate_on_close: bool = True,
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client, httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
auth: httpx.Auth | None = None, auth: httpx.Auth | None = None,
@@ -449,10 +440,7 @@ async def streamablehttp_client(
async with httpx_client_factory( async with httpx_client_factory(
headers=transport.request_headers, headers=transport.request_headers,
timeout=httpx.Timeout( timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
transport.timeout.total_seconds(),
read=transport.sse_read_timeout.total_seconds(),
),
auth=transport.auth, auth=transport.auth,
) as client: ) as client:
# Define callbacks that need access to tg # Define callbacks that need access to tg