diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 30c15ac..8f9221a 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -1,4 +1,3 @@ -# test_sse.py import re import multiprocessing import socket @@ -21,20 +20,30 @@ from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp.server import Server from mcp.server.sse import SseServerTransport -from mcp.types import EmptyResult, ErrorData, InitializeResult, TextContent, TextResourceContents, Tool +from mcp.types import ( + EmptyResult, + ErrorData, + InitializeResult, + TextContent, + TextResourceContents, + Tool, +) SERVER_NAME = "test_server_for_SSE" + @pytest.fixture def server_port() -> int: with socket.socket() as s: - s.bind(('127.0.0.1', 0)) + s.bind(("127.0.0.1", 0)) return s.getsockname()[1] + @pytest.fixture def server_url(server_port: int) -> str: return f"http://127.0.0.1:{server_port}" + # Test server implementation class TestServer(Server): def __init__(self): @@ -45,7 +54,11 @@ class TestServer(Server): if uri.scheme == "foobar": return f"Read {uri.host}" # TODO: make this an error - raise McpError(error=ErrorData(code=404, message="OOPS! no resource with that URI was found")) + raise McpError( + error=ErrorData( + code=404, message="OOPS! no resource with that URI was found" + ) + ) @self.list_tools() async def handle_list_tools(): @@ -53,7 +66,7 @@ class TestServer(Server): Tool( name="test_tool", description="A test tool", - inputSchema={"type": "object", "properties": {}} + inputSchema={"type": "object", "properties": {}}, ) ] @@ -62,9 +75,8 @@ class TestServer(Server): return [TextContent(type="text", text=f"Called {name}")] - # Test fixtures -def make_server_app()-> Starlette: +def make_server_app() -> Starlette: """Create test Starlette app with SSE transport""" sse = SseServerTransport("/messages/") server = TestServer() @@ -74,80 +86,97 @@ def make_server_app()-> Starlette: request.scope, request.receive, request._send ) as streams: await server.run( - streams[0], - streams[1], - server.create_initialization_options() + streams[0], streams[1], server.create_initialization_options() ) - app = Starlette(routes=[ - Route("/sse", endpoint=handle_sse), - Mount("/messages/", app=sse.handle_post_message), - ]) + app = Starlette( + routes=[ + Route("/sse", endpoint=handle_sse), + Mount("/messages/", app=sse.handle_post_message), + ] + ) return app + @pytest.fixture(autouse=True) def space_around_test(): time.sleep(0.1) yield time.sleep(0.1) + def run_server(server_port: int): app = make_server_app() - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) - print(f'starting server on {server_port}') + server = uvicorn.Server( + config=uvicorn.Config( + app=app, host="127.0.0.1", port=server_port, log_level="error" + ) + ) + print(f"starting server on {server_port}") server.run() # Give server time to start while not server.started: - print('waiting for server to start') + print("waiting for server to start") time.sleep(0.5) + @pytest.fixture() def server(server_port: int): - proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True) - print('starting process') + proc = multiprocessing.Process( + target=run_server, kwargs={"server_port": server_port}, daemon=True + ) + print("starting process") proc.start() # Wait for server to be running max_attempts = 20 attempt = 0 - print('waiting for server to start') + print("waiting for server to start") while attempt < max_attempts: try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.connect(('127.0.0.1', server_port)) + s.connect(("127.0.0.1", server_port)) break except ConnectionRefusedError: time.sleep(0.1) attempt += 1 else: - raise RuntimeError("Server failed to start after {} attempts".format(max_attempts)) + raise RuntimeError( + "Server failed to start after {} attempts".format(max_attempts) + ) yield - print('killing server') + print("killing server") # Signal the server to stop proc.kill() proc.join(timeout=2) if proc.is_alive(): print("server process failed to terminate") + @pytest.fixture() async def http_client(server, server_url) -> AsyncGenerator[httpx.AsyncClient, None]: """Create test client""" async with httpx.AsyncClient(base_url=server_url) as client: yield client + # Tests @pytest.mark.anyio async def test_raw_sse_connection(http_client: httpx.AsyncClient): """Test the SSE connection establishment simply with an HTTP client.""" async with anyio.create_task_group() as tg: + async def connection_test(): async with http_client.stream("GET", "/sse") as response: assert response.status_code == 200 - assert response.headers["content-type"] == "text/event-stream; charset=utf-8" + assert ( + response.headers["content-type"] + == "text/event-stream; charset=utf-8" + ) line_number = 0 async for line in response.aiter_lines(): @@ -177,23 +206,32 @@ async def test_sse_client_basic_connection(server, server_url): ping_result = await session.send_ping() assert isinstance(ping_result, EmptyResult) + @pytest.fixture -async def initialized_sse_client_session(server, server_url: str) -> AsyncGenerator[ClientSession, None]: +async def initialized_sse_client_session( + server, server_url: str +) -> AsyncGenerator[ClientSession, None]: async with sse_client(server_url + "/sse") as streams: async with ClientSession(*streams) as session: await session.initialize() yield session + @pytest.mark.anyio -async def test_sse_client_happy_request_and_response(initialized_sse_client_session: ClientSession): +async def test_sse_client_happy_request_and_response( + initialized_sse_client_session: ClientSession, +): session = initialized_sse_client_session response = await session.read_resource(uri=AnyUrl("foobar://should-work")) assert len(response.contents) == 1 assert isinstance(response.contents[0], TextResourceContents) assert response.contents[0].text == "Read should-work" + @pytest.mark.anyio -async def test_sse_client_exception_handling(initialized_sse_client_session: ClientSession): +async def test_sse_client_exception_handling( + initialized_sse_client_session: ClientSession, +): session = initialized_sse_client_session with pytest.raises(McpError, match="OOPS! no resource with that URI was found"): await session.read_resource(uri=AnyUrl("xxx://will-not-work"))