From 66ccd1c515814ff4631b69c2bf0d1916aada91e8 Mon Sep 17 00:00:00 2001 From: Nick Merrill Date: Tue, 14 Jan 2025 10:18:11 -0500 Subject: [PATCH] test_sse_connection is passing --- tests/client/test_sse_attempt.py | 173 ---------------------------- tests/shared/test_sse.py | 188 +++++++++++++++++++++++++++++++ 2 files changed, 188 insertions(+), 173 deletions(-) delete mode 100644 tests/client/test_sse_attempt.py create mode 100644 tests/shared/test_sse.py diff --git a/tests/client/test_sse_attempt.py b/tests/client/test_sse_attempt.py deleted file mode 100644 index 7d73291..0000000 --- a/tests/client/test_sse_attempt.py +++ /dev/null @@ -1,173 +0,0 @@ -import anyio -import asyncio -import pytest -from starlette.applications import Starlette -from starlette.routing import Mount, Route -import httpx -from httpx import ReadTimeout, ASGITransport -from starlette.responses import Response -from sse_starlette.sse import EventSourceResponse - -from mcp.client.sse import sse_client -from mcp.server.sse import SseServerTransport -from mcp.types import JSONRPCMessage - - -@pytest.fixture -async def sse_transport(): - """Fixture that creates an SSE transport instance.""" - return SseServerTransport("/messages/") - - -@pytest.fixture -async def sse_app(sse_transport): - """Fixture that creates a Starlette app with SSE endpoints.""" - async def handle_sse(request): - """Handler for SSE connections.""" - async def event_generator(): - try: - async with sse_transport.connect_sse( - request.scope, request.receive, request._send - ) as streams: - client_to_server, server_to_client = streams - # Send initial connection event - yield { - "event": "endpoint", - "data": "/messages", - } - - # Process messages - async with anyio.create_task_group() as tg: - try: - async for message in client_to_server: - if isinstance(message, Exception): - break - yield { - "event": "message", - "data": message.model_dump_json(), - } - except (asyncio.CancelledError, GeneratorExit): - print('cancelled') - return - except Exception as e: - print("unhandled exception:", e) - return - except Exception: - # Log any unexpected errors but allow connection to close gracefully - pass - - return EventSourceResponse(event_generator()) - - routes = [ - Route("/sse", endpoint=handle_sse), - Mount("/messages", app=sse_transport.handle_post_message), - ] - - return Starlette(routes=routes) - - -@pytest.fixture -async def test_client(sse_app): - """Create a test client with ASGI transport.""" - transport = ASGITransport(app=sse_app) - async with httpx.AsyncClient( - transport=transport, - base_url="http://testserver", - timeout=10.0, - ) as client: - yield client - - -@pytest.mark.anyio -async def test_sse_connection(test_client): - """Test basic SSE connection and message exchange.""" - async with anyio.create_task_group() as tg: - try: - async with sse_client( - "http://testserver/sse", - headers={"Host": "testserver"}, - timeout=5, - sse_read_timeout=5, - client=test_client, - ) as (read_stream, write_stream): - # First get the initial endpoint message - async with read_stream: - init_message = await read_stream.__anext__() - assert isinstance(init_message, JSONRPCMessage) - - # Send a test message - test_message = JSONRPCMessage.model_validate({"jsonrpc": "2.0", "method": "test"}) - await write_stream.send(test_message) - - # Receive echoed message - async with read_stream: - message = await read_stream.__anext__() - assert isinstance(message, JSONRPCMessage) - assert message.model_dump() == test_message.model_dump() - - # Explicitly close streams - await write_stream.aclose() - await read_stream.aclose() - except Exception as e: - pytest.fail(f"Test failed with error: {str(e)}") - - -# @pytest.mark.anyio -# async def test_sse_read_timeout(test_client): -# """Test that SSE client properly handles read timeouts.""" -# with pytest.raises(ReadTimeout): -# async with sse_client( -# "http://testserver/sse", -# headers={"Host": "testserver"}, -# timeout=5, -# sse_read_timeout=2, -# client=test_client, -# ) as (read_stream, write_stream): -# async with read_stream: -# # This should timeout since no messages are being sent -# await read_stream.__anext__() - - -# @pytest.mark.anyio -# async def test_sse_connection_error(test_client): -# """Test SSE client behavior with connection errors.""" -# with pytest.raises(httpx.HTTPError): -# async with sse_client( -# "http://testserver/nonexistent", -# headers={"Host": "testserver"}, -# timeout=5, -# client=test_client, -# ): -# pass # Should not reach here - - -# @pytest.mark.anyio -# async def test_sse_multiple_messages(test_client): -# """Test sending and receiving multiple SSE messages.""" -# async with sse_client( -# "http://testserver/sse", -# headers={"Host": "testserver"}, -# timeout=5, -# sse_read_timeout=5, -# client=test_client, -# ) as (read_stream, write_stream): -# # Send multiple test messages -# messages = [ -# JSONRPCMessage.model_validate({"jsonrpc": "2.0", "method": f"test{i}"}) -# for i in range(3) -# ] - -# for msg in messages: -# await write_stream.send(msg) - -# # Receive all echoed messages -# received = [] -# async with read_stream: -# for _ in range(len(messages)): -# message = await read_stream.__anext__() -# assert isinstance(message, JSONRPCMessage) -# received.append(message) - -# # Verify all messages were received in order -# for sent, received in zip(messages, received): -# assert sent.model_dump() == received.model_dump() diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py new file mode 100644 index 0000000..07a859f --- /dev/null +++ b/tests/shared/test_sse.py @@ -0,0 +1,188 @@ +# test_sse.py +import re +import time +import json +import anyio +import pytest +import httpx +from typing import AsyncGenerator +from starlette.applications import Starlette +from starlette.routing import Mount, Route + +from mcp.server import Server +from mcp.server.sse import SseServerTransport +from mcp.types import TextContent, Tool + +# Test server implementation +class TestServer(Server): + def __init__(self): + super().__init__("test_server") + + @self.list_tools() + async def handle_list_tools(): + return [ + Tool( + name="test_tool", + description="A test tool", + inputSchema={"type": "object", "properties": {}} + ) + ] + + @self.call_tool() + async def handle_call_tool(name: str, args: dict): + return [TextContent(type="text", text=f"Called {name}")] + +import threading +import uvicorn +import pytest + + +# Test fixtures +@pytest.fixture +async def server_app()-> Starlette: + """Create test Starlette app with SSE transport""" + sse = SseServerTransport("/messages/") + server = TestServer() + + async def handle_sse(request): + async with sse.connect_sse( + request.scope, request.receive, request._send + ) as streams: + await server.run( + streams[0], + streams[1], + server.create_initialization_options() + ) + + app = Starlette(routes=[ + Route("/sse", endpoint=handle_sse), + Mount("/messages/", app=sse.handle_post_message), + ]) + + return app + +@pytest.fixture() +def server(server_app: Starlette): + server = uvicorn.Server(config=uvicorn.Config(app=server_app, host="127.0.0.1", port=8765, log_level="error")) + server_thread = threading.Thread( target=server.run, daemon=True ) + print('starting server') + server_thread.start() + # Give server time to start + while not server.started: + print('waiting for server to start') + time.sleep(0.5) + yield + print('killing server') + server_thread.join(timeout=0.1) + +@pytest.fixture() +async def client(server) -> AsyncGenerator[httpx.AsyncClient, None]: + """Create test client""" + async with httpx.AsyncClient(base_url="http://127.0.0.1:8765") as client: + yield client + +# Tests +@pytest.mark.anyio +async def test_sse_connection(client: httpx.AsyncClient): + """Test SSE connection establishment""" + async with anyio.create_task_group() as tg: + async def connection_test(): + async with client.stream("GET", "/sse") as response: + assert response.status_code == 200 + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" + + line_number = 0 + async for line in response.aiter_lines(): + if line_number == 0: + assert line == "event: endpoint" + elif line_number == 1: + assert line.startswith("data: /messages/?session_id=") + else: + return + line_number += 1 + + # Add timeout to prevent test from hanging if it fails + with anyio.fail_after(3): + await connection_test() + +@pytest.mark.anyio +async def test_message_exchange(client: httpx.AsyncClient): + """Test full message exchange flow""" + # Connect to SSE endpoint + session_id = None + endpoint_url = None + + async with client.stream("GET", "/sse") as sse_response: + assert sse_response.status_code == 200 + + # Get endpoint URL and session ID + async for line in sse_response.aiter_lines(): + if line.startswith("data: "): + endpoint_url = json.loads(line[6:]) + session_id = endpoint_url.split("session_id=")[1] + break + + assert endpoint_url and session_id + + # Send initialize request + init_request = { + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": { + "name": "test_client", + "version": "1.0" + } + } + } + + response = await client.post( + endpoint_url, + json=init_request + ) + assert response.status_code == 202 + + # Get initialize response from SSE stream + async for line in sse_response.aiter_lines(): + if line.startswith("event: message"): + data_line = next(sse_response.aiter_lines()) + response = json.loads(data_line[6:]) # Strip "data: " prefix + assert response["jsonrpc"] == "2.0" + assert response["id"] == 1 + assert "result" in response + break + +@pytest.mark.anyio +async def test_invalid_session(client: httpx.AsyncClient): + """Test sending message with invalid session ID""" + response = await client.post( + "/messages/?session_id=invalid", + json={"jsonrpc": "2.0", "method": "ping"} + ) + assert response.status_code == 400 + +@pytest.mark.anyio +async def test_connection_cleanup(server_app): + """Test that resources are cleaned up when client disconnects""" + sse = next( + route.app for route in server_app.routes + if isinstance(route, Mount) and route.path == "/messages/" + ).transport + + async with httpx.AsyncClient(app=server_app, base_url="http://test") as client: + # Connect and get session ID + async with client.stream("GET", "/sse") as response: + for line in response.iter_lines(): + if line.startswith("data: "): + endpoint_url = json.loads(line[6:]) + session_id = endpoint_url.split("session_id=")[1] + break + + assert len(sse._read_stream_writers) == 1 + + # After connection closes, writer should be cleaned up + await anyio.sleep(0.1) # Give cleanup a moment + assert len(sse._read_stream_writers) == 0