This commit is contained in:
Nick Merrill
2025-01-14 09:24:44 -05:00
parent f164291483
commit b0a6aafaf6
2 changed files with 117 additions and 55 deletions

View File

@@ -24,12 +24,20 @@ async def sse_client(
headers: dict[str, Any] | None = None, headers: dict[str, Any] | None = None,
timeout: float = 5, timeout: float = 5,
sse_read_timeout: float = 60 * 5, sse_read_timeout: float = 60 * 5,
client: httpx.AsyncClient | None = None,
): ):
""" """
Client transport for SSE. Client transport for SSE.
`sse_read_timeout` determines how long (in seconds) the client will wait for a new `sse_read_timeout` determines how long (in seconds) the client will wait for a new
event before disconnecting. All other HTTP operations are controlled by `timeout`. event before disconnecting. All other HTTP operations are controlled by `timeout`.
Args:
url: The URL to connect to
headers: Optional headers to send with the request
timeout: Connection timeout in seconds
sse_read_timeout: Read timeout in seconds
client: Optional httpx.AsyncClient instance to use for requests
""" """
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
@@ -43,7 +51,13 @@ async def sse_client(
async with anyio.create_task_group() as tg: async with anyio.create_task_group() as tg:
try: try:
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}") logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
async with httpx.AsyncClient(headers=headers) as client: if client is None:
client = httpx.AsyncClient(headers=headers)
should_close_client = True
else:
should_close_client = False
try:
async with aconnect_sse( async with aconnect_sse(
client, client,
"GET", "GET",
@@ -137,6 +151,9 @@ async def sse_client(
yield read_stream, write_stream yield read_stream, write_stream
finally: finally:
tg.cancel_scope.cancel() tg.cancel_scope.cancel()
finally:
if should_close_client:
await client.aclose()
finally: finally:
await read_stream_writer.aclose() await read_stream_writer.aclose()
await write_stream.aclose() await write_stream.aclose()

View File

@@ -1,82 +1,127 @@
import pytest
import anyio import anyio
import pytest
from starlette.applications import Starlette from starlette.applications import Starlette
from starlette.routing import Mount, Route from starlette.routing import Mount, Route
import uvicorn
from mcp.client.sse import sse_client
from exceptiongroup import ExceptionGroup
import asyncio
import httpx import httpx
from httpx import ReadTimeout from httpx import ReadTimeout, ASGITransport
from mcp.client.sse import sse_client
from mcp.server.sse import SseServerTransport from mcp.server.sse import SseServerTransport
from mcp.types import JSONRPCMessage
@pytest.fixture @pytest.fixture
async def sse_server(): async def sse_transport():
"""Fixture that creates an SSE transport instance."""
return SseServerTransport("/messages/")
# Create an SSE transport at an endpoint
sse = SseServerTransport("/messages/")
# Create Starlette routes for SSE and message handling @pytest.fixture
routes = [ async def sse_app(sse_transport):
Route("/sse", endpoint=handle_sse), """Fixture that creates a Starlette app with SSE endpoints."""
Mount("/messages/", app=sse.handle_post_message),
]
#
# Create and run Starlette app
app = Starlette(routes=routes)
# Define handler functions
async def handle_sse(request): async def handle_sse(request):
async with sse.connect_sse( """Handler for SSE connections."""
async with sse_transport.connect_sse(
request.scope, request.receive, request._send request.scope, request.receive, request._send
) as streams: ) as streams:
await app.run( client_to_server, server_to_client = streams
streams[0], streams[1], app.create_initialization_options() async for message in client_to_server:
) # Echo messages back for testing
await server_to_client.send(message)
uvicorn.run(app, host="127.0.0.1", port=34891) routes = [
Route("/sse", endpoint=handle_sse),
Mount("/messages", app=sse_transport.handle_post_message),
]
async def sse_handler(request): return Starlette(routes=routes)
response = httpx.Response(200, content_type="text/event-stream")
response.send_headers()
response.write("data: test\n\n")
await response.aclose()
async with httpx.AsyncServer(sse_handler) as server:
yield server.url
@pytest.fixture @pytest.fixture
async def sse_client(): async def test_client(sse_app):
async with sse_client("http://test/sse") as (read_stream, write_stream): """Create a test client with ASGI transport."""
async with read_stream: async with httpx.AsyncClient(
async for message in read_stream: transport=ASGITransport(app=sse_app),
if isinstance(message, Exception): base_url="http://testserver",
raise message ) as client:
yield client
return read_stream, write_stream
@pytest.mark.anyio @pytest.mark.anyio
async def test_sse_happy_path(monkeypatch): async def test_sse_connection(test_client):
# Mock httpx.AsyncClient to return our mock response """Test basic SSE connection and message exchange."""
monkeypatch.setattr(httpx, "AsyncClient", MockClient)
with pytest.raises(ReadTimeout) as exc_info:
async with sse_client( async with sse_client(
"http://test/sse", "http://testserver/sse",
timeout=5, # Connection timeout - make this longer headers={"Host": "testserver"},
sse_read_timeout=1 # Read timeout - this should trigger timeout=5,
client=test_client,
) as (read_stream, write_stream):
# Send a test message
test_message = JSONRPCMessage.model_validate({"jsonrpc": "2.0", "method": "test"})
await write_stream.send(test_message)
# Receive echoed message
async with read_stream:
message = await read_stream.__anext__()
assert isinstance(message, JSONRPCMessage)
assert message.model_dump() == test_message.model_dump()
@pytest.mark.anyio
async def test_sse_read_timeout(test_client):
"""Test that SSE client properly handles read timeouts."""
with pytest.raises(ReadTimeout):
async with sse_client(
"http://testserver/sse",
headers={"Host": "testserver"},
timeout=5,
sse_read_timeout=1,
client=test_client,
) as (read_stream, write_stream): ) as (read_stream, write_stream):
async with read_stream: async with read_stream:
async for message in read_stream: # This should timeout since no messages are being sent
if isinstance(message, Exception): await read_stream.__anext__()
raise message
error = exc_info.value
assert isinstance(error, ReadTimeout)
assert str(error) == "Read timeout"
@pytest.mark.anyio @pytest.mark.anyio
async def test_sse_read_timeouts(monkeypatch): async def test_sse_connection_error(test_client):
"""Test that the SSE client properly handles read timeouts between SSE messages.""" """Test SSE client behavior with connection errors."""
with pytest.raises(httpx.HTTPError):
async with sse_client(
"http://testserver/nonexistent",
headers={"Host": "testserver"},
timeout=5,
client=test_client,
):
pass # Should not reach here
@pytest.mark.anyio
async def test_sse_multiple_messages(test_client):
"""Test sending and receiving multiple SSE messages."""
async with sse_client(
"http://testserver/sse",
headers={"Host": "testserver"},
timeout=5,
client=test_client,
) as (read_stream, write_stream):
# Send multiple test messages
messages = [
JSONRPCMessage.model_validate({"jsonrpc": "2.0", "method": f"test{i}"})
for i in range(3)
]
for msg in messages:
await write_stream.send(msg)
# Receive all echoed messages
received = []
async with read_stream:
for _ in range(len(messages)):
message = await read_stream.__anext__()
assert isinstance(message, JSONRPCMessage)
received.append(message)
# Verify all messages were received in order
for sent, received in zip(messages, received):
assert sent.model_dump() == received.model_dump()