Allow to pass timeout as float (#941)

This commit is contained in:
Marcelo Trylesinski
2025-06-12 09:31:31 +02:00
committed by GitHub
parent d69b290b65
commit 20dc0fbabb

View File

@@ -11,7 +11,6 @@ from collections.abc import AsyncGenerator, Awaitable, Callable
from contextlib import asynccontextmanager
from dataclasses import dataclass
from datetime import timedelta
from typing import Any
import anyio
import httpx
@@ -52,14 +51,10 @@ SSE = "text/event-stream"
class StreamableHTTPError(Exception):
"""Base exception for StreamableHTTP transport errors."""
pass
class ResumptionError(StreamableHTTPError):
"""Raised when resumption request is invalid."""
pass
@dataclass
class RequestContext:
@@ -71,7 +66,7 @@ class RequestContext:
session_message: SessionMessage
metadata: ClientMessageMetadata | None
read_stream_writer: StreamWriter
sse_read_timeout: timedelta
sse_read_timeout: float
class StreamableHTTPTransport:
@@ -80,9 +75,9 @@ class StreamableHTTPTransport:
def __init__(
self,
url: str,
headers: dict[str, Any] | None = None,
timeout: timedelta = timedelta(seconds=30),
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
headers: dict[str, str] | None = None,
timeout: float | timedelta = 30,
sse_read_timeout: float | timedelta = 60 * 5,
auth: httpx.Auth | None = None,
) -> None:
"""Initialize the StreamableHTTP transport.
@@ -96,10 +91,12 @@ class StreamableHTTPTransport:
"""
self.url = url
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout
self.sse_read_timeout = (
sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout
)
self.auth = auth
self.session_id: str | None = None
self.session_id = None
self.request_headers = {
ACCEPT: f"{JSON}, {SSE}",
CONTENT_TYPE: JSON,
@@ -160,7 +157,7 @@ class StreamableHTTPTransport:
return isinstance(message.root, JSONRPCResponse | JSONRPCError)
except Exception as exc:
logger.error(f"Error parsing SSE message: {exc}")
logger.exception("Error parsing SSE message")
await read_stream_writer.send(exc)
return False
else:
@@ -184,10 +181,7 @@ class StreamableHTTPTransport:
"GET",
self.url,
headers=headers,
timeout=httpx.Timeout(
self.timeout.total_seconds(),
read=self.sse_read_timeout.total_seconds(),
),
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
) as event_source:
event_source.response.raise_for_status()
logger.debug("GET SSE connection established")
@@ -216,10 +210,7 @@ class StreamableHTTPTransport:
"GET",
self.url,
headers=headers,
timeout=httpx.Timeout(
self.timeout.total_seconds(),
read=ctx.sse_read_timeout.total_seconds(),
),
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
) as event_source:
event_source.response.raise_for_status()
logger.debug("Resumption GET SSE connection established")
@@ -412,9 +403,9 @@ class StreamableHTTPTransport:
@asynccontextmanager
async def streamablehttp_client(
url: str,
headers: dict[str, Any] | None = None,
timeout: timedelta = timedelta(seconds=30),
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
headers: dict[str, str] | None = None,
timeout: float | timedelta = 30,
sse_read_timeout: float | timedelta = 60 * 5,
terminate_on_close: bool = True,
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
auth: httpx.Auth | None = None,
@@ -449,10 +440,7 @@ async def streamablehttp_client(
async with httpx_client_factory(
headers=transport.request_headers,
timeout=httpx.Timeout(
transport.timeout.total_seconds(),
read=transport.sse_read_timeout.total_seconds(),
),
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
auth=transport.auth,
) as client:
# Define callbacks that need access to tg