diff --git a/tests/client/test_sse_attempt.py b/tests/client/test_sse_attempt.py index 4005467..1f856bb 100644 --- a/tests/client/test_sse_attempt.py +++ b/tests/client/test_sse_attempt.py @@ -4,6 +4,8 @@ from starlette.applications import Starlette from starlette.routing import Mount, Route import httpx from httpx import ReadTimeout, ASGITransport +from starlette.responses import Response +from sse_starlette.sse import EventSourceResponse from mcp.client.sse import sse_client from mcp.server.sse import SseServerTransport @@ -21,17 +23,33 @@ async def sse_app(sse_transport): """Fixture that creates a Starlette app with SSE endpoints.""" async def handle_sse(request): """Handler for SSE connections.""" - async with sse_transport.connect_sse( - request.scope, request.receive, request._send - ) as streams: - client_to_server, server_to_client = streams - async for message in client_to_server: - # Echo messages back for testing - await server_to_client.send(message) + async def event_generator(): + # Send initial connection event + yield { + "event": "endpoint", + "data": "/messages", + } + + # Keep connection alive + async with sse_transport.connect_sse( + request.scope, request.receive, request._send + ) as streams: + client_to_server, server_to_client = streams + async for message in client_to_server: + yield { + "event": "message", + "data": message.model_dump_json(), + } + + return EventSourceResponse(event_generator()) + + async def handle_post(request): + """Handler for POST messages.""" + return Response(status_code=200) routes = [ Route("/sse", endpoint=handle_sse), - Mount("/messages", app=sse_transport.handle_post_message), + Route("/messages", endpoint=handle_post, methods=["POST"]), ] return Starlette(routes=routes) @@ -40,9 +58,11 @@ async def sse_app(sse_transport): @pytest.fixture async def test_client(sse_app): """Create a test client with ASGI transport.""" + transport = ASGITransport(app=sse_app) async with httpx.AsyncClient( - transport=ASGITransport(app=sse_app), + transport=transport, base_url="http://testserver", + timeout=5.0, ) as client: yield client @@ -53,7 +73,8 @@ async def test_sse_connection(test_client): async with sse_client( "http://testserver/sse", headers={"Host": "testserver"}, - timeout=5, + timeout=2, + sse_read_timeout=1, client=test_client, ) as (read_stream, write_stream): # Send a test message @@ -74,7 +95,7 @@ async def test_sse_read_timeout(test_client): async with sse_client( "http://testserver/sse", headers={"Host": "testserver"}, - timeout=5, + timeout=2, sse_read_timeout=1, client=test_client, ) as (read_stream, write_stream): @@ -90,7 +111,7 @@ async def test_sse_connection_error(test_client): async with sse_client( "http://testserver/nonexistent", headers={"Host": "testserver"}, - timeout=5, + timeout=2, client=test_client, ): pass # Should not reach here @@ -102,7 +123,8 @@ async def test_sse_multiple_messages(test_client): async with sse_client( "http://testserver/sse", headers={"Host": "testserver"}, - timeout=5, + timeout=2, + sse_read_timeout=1, client=test_client, ) as (read_stream, write_stream): # Send multiple test messages