Support custom httpx client creation (#752)

This commit is contained in:
Nick Cooper
2025-05-23 05:55:01 -07:00
committed by GitHub
parent 7c8ad510b7
commit 10cf0f78a8
3 changed files with 16 additions and 5 deletions

View File

@@ -10,7 +10,7 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStre
from httpx_sse import aconnect_sse
import mcp.types as types
from mcp.shared._httpx_utils import create_mcp_http_client
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
from mcp.shared.message import SessionMessage
logger = logging.getLogger(__name__)
@@ -26,6 +26,7 @@ async def sse_client(
headers: dict[str, Any] | None = None,
timeout: float = 5,
sse_read_timeout: float = 60 * 5,
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
auth: httpx.Auth | None = None,
):
"""
@@ -53,7 +54,7 @@ async def sse_client(
async with anyio.create_task_group() as tg:
try:
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
async with create_mcp_http_client(headers=headers, auth=auth) as client:
async with httpx_client_factory(headers=headers, auth=auth) as client:
async with aconnect_sse(
client,
"GET",

View File

@@ -19,7 +19,7 @@ from anyio.abc import TaskGroup
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
from mcp.shared._httpx_utils import create_mcp_http_client
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
from mcp.shared.message import ClientMessageMetadata, SessionMessage
from mcp.types import (
ErrorData,
@@ -430,6 +430,7 @@ async def streamablehttp_client(
timeout: timedelta = timedelta(seconds=30),
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
terminate_on_close: bool = True,
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
auth: httpx.Auth | None = None,
) -> AsyncGenerator[
tuple[
@@ -464,7 +465,7 @@ async def streamablehttp_client(
try:
logger.info(f"Connecting to StreamableHTTP endpoint: {url}")
async with create_mcp_http_client(
async with httpx_client_factory(
headers=transport.request_headers,
timeout=httpx.Timeout(
transport.timeout.seconds, read=transport.sse_read_timeout.seconds

View File

@@ -1,12 +1,21 @@
"""Utilities for creating standardized httpx AsyncClient instances."""
from typing import Any
from typing import Any, Protocol
import httpx
__all__ = ["create_mcp_http_client"]
class McpHttpClientFactory(Protocol):
def __call__(
self,
headers: dict[str, str] | None = None,
timeout: httpx.Timeout | None = None,
auth: httpx.Auth | None = None,
) -> httpx.AsyncClient: ...
def create_mcp_http_client(
headers: dict[str, str] | None = None,
timeout: httpx.Timeout | None = None,