diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index abafacb..e09f6c5 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -24,12 +24,20 @@ async def sse_client( headers: dict[str, Any] | None = None, timeout: float = 5, sse_read_timeout: float = 60 * 5, + client: httpx.AsyncClient | None = None, ): """ Client transport for SSE. `sse_read_timeout` determines how long (in seconds) the client will wait for a new event before disconnecting. All other HTTP operations are controlled by `timeout`. + + Args: + url: The URL to connect to + headers: Optional headers to send with the request + timeout: Connection timeout in seconds + sse_read_timeout: Read timeout in seconds + client: Optional httpx.AsyncClient instance to use for requests """ read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] @@ -43,7 +51,13 @@ async def sse_client( async with anyio.create_task_group() as tg: try: logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}") - async with httpx.AsyncClient(headers=headers) as client: + if client is None: + client = httpx.AsyncClient(headers=headers) + should_close_client = True + else: + should_close_client = False + + try: async with aconnect_sse( client, "GET", @@ -137,6 +151,9 @@ async def sse_client( yield read_stream, write_stream finally: tg.cancel_scope.cancel() + finally: + if should_close_client: + await client.aclose() finally: await read_stream_writer.aclose() await write_stream.aclose() diff --git a/tests/client/test_sse_attempt.py b/tests/client/test_sse_attempt.py index 8df152b..4005467 100644 --- a/tests/client/test_sse_attempt.py +++ b/tests/client/test_sse_attempt.py @@ -1,82 +1,127 @@ -import pytest import anyio +import pytest from starlette.applications import Starlette from starlette.routing import Mount, Route -import uvicorn -from mcp.client.sse import sse_client -from exceptiongroup import ExceptionGroup -import asyncio import httpx -from httpx import ReadTimeout +from httpx import ReadTimeout, ASGITransport +from mcp.client.sse import sse_client from mcp.server.sse import SseServerTransport +from mcp.types import JSONRPCMessage + @pytest.fixture -async def sse_server(): +async def sse_transport(): + """Fixture that creates an SSE transport instance.""" + return SseServerTransport("/messages/") - # Create an SSE transport at an endpoint - sse = SseServerTransport("/messages/") - # Create Starlette routes for SSE and message handling - routes = [ - Route("/sse", endpoint=handle_sse), - Mount("/messages/", app=sse.handle_post_message), - ] - # - # Create and run Starlette app - app = Starlette(routes=routes) - - # Define handler functions +@pytest.fixture +async def sse_app(sse_transport): + """Fixture that creates a Starlette app with SSE endpoints.""" async def handle_sse(request): - async with sse.connect_sse( + """Handler for SSE connections.""" + async with sse_transport.connect_sse( request.scope, request.receive, request._send ) as streams: - await app.run( - streams[0], streams[1], app.create_initialization_options() - ) + client_to_server, server_to_client = streams + async for message in client_to_server: + # Echo messages back for testing + await server_to_client.send(message) - uvicorn.run(app, host="127.0.0.1", port=34891) + routes = [ + Route("/sse", endpoint=handle_sse), + Mount("/messages", app=sse_transport.handle_post_message), + ] - async def sse_handler(request): - response = httpx.Response(200, content_type="text/event-stream") - response.send_headers() - response.write("data: test\n\n") - await response.aclose() - - async with httpx.AsyncServer(sse_handler) as server: - yield server.url + return Starlette(routes=routes) @pytest.fixture -async def sse_client(): - async with sse_client("http://test/sse") as (read_stream, write_stream): - async with read_stream: - async for message in read_stream: - if isinstance(message, Exception): - raise message +async def test_client(sse_app): + """Create a test client with ASGI transport.""" + async with httpx.AsyncClient( + transport=ASGITransport(app=sse_app), + base_url="http://testserver", + ) as client: + yield client - return read_stream, write_stream @pytest.mark.anyio -async def test_sse_happy_path(monkeypatch): - # Mock httpx.AsyncClient to return our mock response - monkeypatch.setattr(httpx, "AsyncClient", MockClient) +async def test_sse_connection(test_client): + """Test basic SSE connection and message exchange.""" + async with sse_client( + "http://testserver/sse", + headers={"Host": "testserver"}, + timeout=5, + client=test_client, + ) as (read_stream, write_stream): + # Send a test message + test_message = JSONRPCMessage.model_validate({"jsonrpc": "2.0", "method": "test"}) + await write_stream.send(test_message) - with pytest.raises(ReadTimeout) as exc_info: + # Receive echoed message + async with read_stream: + message = await read_stream.__anext__() + assert isinstance(message, JSONRPCMessage) + assert message.model_dump() == test_message.model_dump() + + +@pytest.mark.anyio +async def test_sse_read_timeout(test_client): + """Test that SSE client properly handles read timeouts.""" + with pytest.raises(ReadTimeout): async with sse_client( - "http://test/sse", - timeout=5, # Connection timeout - make this longer - sse_read_timeout=1 # Read timeout - this should trigger + "http://testserver/sse", + headers={"Host": "testserver"}, + timeout=5, + sse_read_timeout=1, + client=test_client, ) as (read_stream, write_stream): async with read_stream: - async for message in read_stream: - if isinstance(message, Exception): - raise message + # This should timeout since no messages are being sent + await read_stream.__anext__() - error = exc_info.value - assert isinstance(error, ReadTimeout) - assert str(error) == "Read timeout" @pytest.mark.anyio -async def test_sse_read_timeouts(monkeypatch): - """Test that the SSE client properly handles read timeouts between SSE messages.""" +async def test_sse_connection_error(test_client): + """Test SSE client behavior with connection errors.""" + with pytest.raises(httpx.HTTPError): + async with sse_client( + "http://testserver/nonexistent", + headers={"Host": "testserver"}, + timeout=5, + client=test_client, + ): + pass # Should not reach here + + +@pytest.mark.anyio +async def test_sse_multiple_messages(test_client): + """Test sending and receiving multiple SSE messages.""" + async with sse_client( + "http://testserver/sse", + headers={"Host": "testserver"}, + timeout=5, + client=test_client, + ) as (read_stream, write_stream): + # Send multiple test messages + messages = [ + JSONRPCMessage.model_validate({"jsonrpc": "2.0", "method": f"test{i}"}) + for i in range(3) + ] + + for msg in messages: + await write_stream.send(msg) + + # Receive all echoed messages + received = [] + async with read_stream: + for _ in range(len(messages)): + message = await read_stream.__anext__() + assert isinstance(message, JSONRPCMessage) + received.append(message) + + # Verify all messages were received in order + for sent, received in zip(messages, received): + assert sent.model_dump() == received.model_dump() \ No newline at end of file