Support custom httpx client creation (#752)

This commit is contained in:
Nick Cooper
2025-05-23 05:55:01 -07:00
committed by GitHub
parent 7c8ad510b7
commit 10cf0f78a8
3 changed files with 16 additions and 5 deletions

View File

@@ -10,7 +10,7 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStre
from httpx_sse import aconnect_sse from httpx_sse import aconnect_sse
import mcp.types as types import mcp.types as types
from mcp.shared._httpx_utils import create_mcp_http_client from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
from mcp.shared.message import SessionMessage from mcp.shared.message import SessionMessage
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -26,6 +26,7 @@ async def sse_client(
headers: dict[str, Any] | None = None, headers: dict[str, Any] | None = None,
timeout: float = 5, timeout: float = 5,
sse_read_timeout: float = 60 * 5, sse_read_timeout: float = 60 * 5,
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
auth: httpx.Auth | None = None, auth: httpx.Auth | None = None,
): ):
""" """
@@ -53,7 +54,7 @@ async def sse_client(
async with anyio.create_task_group() as tg: async with anyio.create_task_group() as tg:
try: try:
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}") logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
async with create_mcp_http_client(headers=headers, auth=auth) as client: async with httpx_client_factory(headers=headers, auth=auth) as client:
async with aconnect_sse( async with aconnect_sse(
client, client,
"GET", "GET",

View File

@@ -19,7 +19,7 @@ from anyio.abc import TaskGroup
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
from mcp.shared._httpx_utils import create_mcp_http_client from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
from mcp.shared.message import ClientMessageMetadata, SessionMessage from mcp.shared.message import ClientMessageMetadata, SessionMessage
from mcp.types import ( from mcp.types import (
ErrorData, ErrorData,
@@ -430,6 +430,7 @@ async def streamablehttp_client(
timeout: timedelta = timedelta(seconds=30), timeout: timedelta = timedelta(seconds=30),
sse_read_timeout: timedelta = timedelta(seconds=60 * 5), sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
terminate_on_close: bool = True, terminate_on_close: bool = True,
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
auth: httpx.Auth | None = None, auth: httpx.Auth | None = None,
) -> AsyncGenerator[ ) -> AsyncGenerator[
tuple[ tuple[
@@ -464,7 +465,7 @@ async def streamablehttp_client(
try: try:
logger.info(f"Connecting to StreamableHTTP endpoint: {url}") logger.info(f"Connecting to StreamableHTTP endpoint: {url}")
async with create_mcp_http_client( async with httpx_client_factory(
headers=transport.request_headers, headers=transport.request_headers,
timeout=httpx.Timeout( timeout=httpx.Timeout(
transport.timeout.seconds, read=transport.sse_read_timeout.seconds transport.timeout.seconds, read=transport.sse_read_timeout.seconds

View File

@@ -1,12 +1,21 @@
"""Utilities for creating standardized httpx AsyncClient instances.""" """Utilities for creating standardized httpx AsyncClient instances."""
from typing import Any from typing import Any, Protocol
import httpx import httpx
__all__ = ["create_mcp_http_client"] __all__ = ["create_mcp_http_client"]
class McpHttpClientFactory(Protocol):
def __call__(
self,
headers: dict[str, str] | None = None,
timeout: httpx.Timeout | None = None,
auth: httpx.Auth | None = None,
) -> httpx.AsyncClient: ...
def create_mcp_http_client( def create_mcp_http_client(
headers: dict[str, str] | None = None, headers: dict[str, str] | None = None,
timeout: httpx.Timeout | None = None, timeout: httpx.Timeout | None = None,