From e79a56435a9f6bfb84ffb5317501aceeda8ca48a Mon Sep 17 00:00:00 2001 From: Nick Merrill Date: Tue, 14 Jan 2025 10:34:32 -0500 Subject: [PATCH] passing SSE client test --- tests/shared/test_sse.py | 133 +++++++++++++++------------------------ 1 file changed, 49 insertions(+), 84 deletions(-) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 07a859f..ee3edb9 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -3,20 +3,36 @@ import re import time import json import anyio +from pydantic import AnyUrl +from pydantic_core import Url import pytest import httpx from typing import AsyncGenerator from starlette.applications import Starlette from starlette.routing import Mount, Route +from mcp.client.session import ClientSession +from mcp.client.sse import sse_client from mcp.server import Server from mcp.server.sse import SseServerTransport -from mcp.types import TextContent, Tool +from mcp.types import EmptyResult, InitializeResult, TextContent, TextResourceContents, Tool + +SERVER_URL = "http://127.0.0.1:8765" +SERVER_SSE_URL = f"{SERVER_URL}/sse" + +SERVER_NAME = "test_server_for_SSE" # Test server implementation class TestServer(Server): def __init__(self): - super().__init__("test_server") + super().__init__(SERVER_NAME) + + @self.read_resource() + async def handle_read_resource(uri: AnyUrl) -> str | bytes: + if uri.scheme == "foobar": + return f"Read {uri.host}" + # TODO: make this an error + return "NOT FOUND" @self.list_tools() async def handle_list_tools(): @@ -76,18 +92,18 @@ def server(server_app: Starlette): server_thread.join(timeout=0.1) @pytest.fixture() -async def client(server) -> AsyncGenerator[httpx.AsyncClient, None]: +async def http_client(server) -> AsyncGenerator[httpx.AsyncClient, None]: """Create test client""" - async with httpx.AsyncClient(base_url="http://127.0.0.1:8765") as client: + async with httpx.AsyncClient(base_url=SERVER_URL) as client: yield client # Tests @pytest.mark.anyio -async def test_sse_connection(client: httpx.AsyncClient): - """Test SSE connection establishment""" +async def test_raw_sse_connection(http_client: httpx.AsyncClient): + """Test the SSE connection establishment simply with an HTTP client.""" async with anyio.create_task_group() as tg: async def connection_test(): - async with client.stream("GET", "/sse") as response: + async with http_client.stream("GET", "/sse") as response: assert response.status_code == 200 assert response.headers["content-type"] == "text/event-stream; charset=utf-8" @@ -105,84 +121,33 @@ async def test_sse_connection(client: httpx.AsyncClient): with anyio.fail_after(3): await connection_test() -@pytest.mark.anyio -async def test_message_exchange(client: httpx.AsyncClient): - """Test full message exchange flow""" - # Connect to SSE endpoint - session_id = None - endpoint_url = None - - async with client.stream("GET", "/sse") as sse_response: - assert sse_response.status_code == 200 - - # Get endpoint URL and session ID - async for line in sse_response.aiter_lines(): - if line.startswith("data: "): - endpoint_url = json.loads(line[6:]) - session_id = endpoint_url.split("session_id=")[1] - break - - assert endpoint_url and session_id - - # Send initialize request - init_request = { - "jsonrpc": "2.0", - "id": 1, - "method": "initialize", - "params": { - "protocolVersion": "2024-11-05", - "capabilities": {}, - "clientInfo": { - "name": "test_client", - "version": "1.0" - } - } - } - - response = await client.post( - endpoint_url, - json=init_request - ) - assert response.status_code == 202 - - # Get initialize response from SSE stream - async for line in sse_response.aiter_lines(): - if line.startswith("event: message"): - data_line = next(sse_response.aiter_lines()) - response = json.loads(data_line[6:]) # Strip "data: " prefix - assert response["jsonrpc"] == "2.0" - assert response["id"] == 1 - assert "result" in response - break @pytest.mark.anyio -async def test_invalid_session(client: httpx.AsyncClient): - """Test sending message with invalid session ID""" - response = await client.post( - "/messages/?session_id=invalid", - json={"jsonrpc": "2.0", "method": "ping"} - ) - assert response.status_code == 400 +async def test_sse_client_basic_connection(server): + async with sse_client(SERVER_SSE_URL) as streams: + async with ClientSession(*streams) as session: + # Test initialization + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == SERVER_NAME + + # Test ping + ping_result = await session.send_ping() + assert isinstance(ping_result, EmptyResult) + +@pytest.fixture +async def initialized_sse_client_session(server) -> AsyncGenerator[ClientSession, None]: + async with sse_client(SERVER_SSE_URL) as streams: + async with ClientSession(*streams) as session: + await session.initialize() + yield session @pytest.mark.anyio -async def test_connection_cleanup(server_app): - """Test that resources are cleaned up when client disconnects""" - sse = next( - route.app for route in server_app.routes - if isinstance(route, Mount) and route.path == "/messages/" - ).transport - - async with httpx.AsyncClient(app=server_app, base_url="http://test") as client: - # Connect and get session ID - async with client.stream("GET", "/sse") as response: - for line in response.iter_lines(): - if line.startswith("data: "): - endpoint_url = json.loads(line[6:]) - session_id = endpoint_url.split("session_id=")[1] - break - - assert len(sse._read_stream_writers) == 1 - - # After connection closes, writer should be cleaned up - await anyio.sleep(0.1) # Give cleanup a moment - assert len(sse._read_stream_writers) == 0 +async def test_sse_client_request_and_response(initialized_sse_client_session: ClientSession): + session = initialized_sse_client_session + # TODO: expect raise + await session.read_resource(uri=AnyUrl("xxx://will-not-work")) + response = await session.read_resource(uri=AnyUrl("foobar://should-work")) + assert len(response.contents) == 1 + assert isinstance(response.contents[0], TextResourceContents) + assert response.contents[0].text == "Read should-work"