mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 23:04:25 +01:00
Allow to pass timeout as float (#941)
This commit is contained in:
committed by
GitHub
parent
d69b290b65
commit
20dc0fbabb
@@ -11,7 +11,6 @@ from collections.abc import AsyncGenerator, Awaitable, Callable
|
|||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
import httpx
|
import httpx
|
||||||
@@ -52,14 +51,10 @@ SSE = "text/event-stream"
|
|||||||
class StreamableHTTPError(Exception):
|
class StreamableHTTPError(Exception):
|
||||||
"""Base exception for StreamableHTTP transport errors."""
|
"""Base exception for StreamableHTTP transport errors."""
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ResumptionError(StreamableHTTPError):
|
class ResumptionError(StreamableHTTPError):
|
||||||
"""Raised when resumption request is invalid."""
|
"""Raised when resumption request is invalid."""
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RequestContext:
|
class RequestContext:
|
||||||
@@ -71,7 +66,7 @@ class RequestContext:
|
|||||||
session_message: SessionMessage
|
session_message: SessionMessage
|
||||||
metadata: ClientMessageMetadata | None
|
metadata: ClientMessageMetadata | None
|
||||||
read_stream_writer: StreamWriter
|
read_stream_writer: StreamWriter
|
||||||
sse_read_timeout: timedelta
|
sse_read_timeout: float
|
||||||
|
|
||||||
|
|
||||||
class StreamableHTTPTransport:
|
class StreamableHTTPTransport:
|
||||||
@@ -80,9 +75,9 @@ class StreamableHTTPTransport:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
url: str,
|
url: str,
|
||||||
headers: dict[str, Any] | None = None,
|
headers: dict[str, str] | None = None,
|
||||||
timeout: timedelta = timedelta(seconds=30),
|
timeout: float | timedelta = 30,
|
||||||
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
|
sse_read_timeout: float | timedelta = 60 * 5,
|
||||||
auth: httpx.Auth | None = None,
|
auth: httpx.Auth | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the StreamableHTTP transport.
|
"""Initialize the StreamableHTTP transport.
|
||||||
@@ -96,10 +91,12 @@ class StreamableHTTPTransport:
|
|||||||
"""
|
"""
|
||||||
self.url = url
|
self.url = url
|
||||||
self.headers = headers or {}
|
self.headers = headers or {}
|
||||||
self.timeout = timeout
|
self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout
|
||||||
self.sse_read_timeout = sse_read_timeout
|
self.sse_read_timeout = (
|
||||||
|
sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout
|
||||||
|
)
|
||||||
self.auth = auth
|
self.auth = auth
|
||||||
self.session_id: str | None = None
|
self.session_id = None
|
||||||
self.request_headers = {
|
self.request_headers = {
|
||||||
ACCEPT: f"{JSON}, {SSE}",
|
ACCEPT: f"{JSON}, {SSE}",
|
||||||
CONTENT_TYPE: JSON,
|
CONTENT_TYPE: JSON,
|
||||||
@@ -160,7 +157,7 @@ class StreamableHTTPTransport:
|
|||||||
return isinstance(message.root, JSONRPCResponse | JSONRPCError)
|
return isinstance(message.root, JSONRPCResponse | JSONRPCError)
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error(f"Error parsing SSE message: {exc}")
|
logger.exception("Error parsing SSE message")
|
||||||
await read_stream_writer.send(exc)
|
await read_stream_writer.send(exc)
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
@@ -184,10 +181,7 @@ class StreamableHTTPTransport:
|
|||||||
"GET",
|
"GET",
|
||||||
self.url,
|
self.url,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
timeout=httpx.Timeout(
|
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
|
||||||
self.timeout.total_seconds(),
|
|
||||||
read=self.sse_read_timeout.total_seconds(),
|
|
||||||
),
|
|
||||||
) as event_source:
|
) as event_source:
|
||||||
event_source.response.raise_for_status()
|
event_source.response.raise_for_status()
|
||||||
logger.debug("GET SSE connection established")
|
logger.debug("GET SSE connection established")
|
||||||
@@ -216,10 +210,7 @@ class StreamableHTTPTransport:
|
|||||||
"GET",
|
"GET",
|
||||||
self.url,
|
self.url,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
timeout=httpx.Timeout(
|
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
|
||||||
self.timeout.total_seconds(),
|
|
||||||
read=ctx.sse_read_timeout.total_seconds(),
|
|
||||||
),
|
|
||||||
) as event_source:
|
) as event_source:
|
||||||
event_source.response.raise_for_status()
|
event_source.response.raise_for_status()
|
||||||
logger.debug("Resumption GET SSE connection established")
|
logger.debug("Resumption GET SSE connection established")
|
||||||
@@ -412,9 +403,9 @@ class StreamableHTTPTransport:
|
|||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def streamablehttp_client(
|
async def streamablehttp_client(
|
||||||
url: str,
|
url: str,
|
||||||
headers: dict[str, Any] | None = None,
|
headers: dict[str, str] | None = None,
|
||||||
timeout: timedelta = timedelta(seconds=30),
|
timeout: float | timedelta = 30,
|
||||||
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
|
sse_read_timeout: float | timedelta = 60 * 5,
|
||||||
terminate_on_close: bool = True,
|
terminate_on_close: bool = True,
|
||||||
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
|
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
|
||||||
auth: httpx.Auth | None = None,
|
auth: httpx.Auth | None = None,
|
||||||
@@ -449,10 +440,7 @@ async def streamablehttp_client(
|
|||||||
|
|
||||||
async with httpx_client_factory(
|
async with httpx_client_factory(
|
||||||
headers=transport.request_headers,
|
headers=transport.request_headers,
|
||||||
timeout=httpx.Timeout(
|
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
|
||||||
transport.timeout.total_seconds(),
|
|
||||||
read=transport.sse_read_timeout.total_seconds(),
|
|
||||||
),
|
|
||||||
auth=transport.auth,
|
auth=transport.auth,
|
||||||
) as client:
|
) as client:
|
||||||
# Define callbacks that need access to tg
|
# Define callbacks that need access to tg
|
||||||
|
|||||||
Reference in New Issue
Block a user