diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index ee3edb9..96c9758 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -3,6 +3,9 @@ import re import time import json import anyio +import threading +import uvicorn +import pytest from pydantic import AnyUrl from pydantic_core import Url import pytest @@ -11,17 +14,29 @@ from typing import AsyncGenerator from starlette.applications import Starlette from starlette.routing import Mount, Route +from mcp.shared.exceptions import McpError from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp.server import Server from mcp.server.sse import SseServerTransport -from mcp.types import EmptyResult, InitializeResult, TextContent, TextResourceContents, Tool - -SERVER_URL = "http://127.0.0.1:8765" -SERVER_SSE_URL = f"{SERVER_URL}/sse" +from mcp.types import EmptyResult, ErrorData, InitializeResult, TextContent, TextResourceContents, Tool SERVER_NAME = "test_server_for_SSE" +@pytest.fixture +def server_port() -> int: + import socket + + s = socket.socket() + s.bind(('', 0)) + port = s.getsockname()[1] + s.close() + return port + +@pytest.fixture +def server_url(server_port: int) -> str: + return f"http://127.0.0.1:{server_port}" + # Test server implementation class TestServer(Server): def __init__(self): @@ -32,7 +47,7 @@ class TestServer(Server): if uri.scheme == "foobar": return f"Read {uri.host}" # TODO: make this an error - return "NOT FOUND" + raise McpError(error=ErrorData(code=404, message="OOPS! no resource with that URI was found")) @self.list_tools() async def handle_list_tools(): @@ -48,9 +63,6 @@ class TestServer(Server): async def handle_call_tool(name: str, args: dict): return [TextContent(type="text", text=f"Called {name}")] -import threading -import uvicorn -import pytest # Test fixtures @@ -78,10 +90,10 @@ async def server_app()-> Starlette: return app @pytest.fixture() -def server(server_app: Starlette): - server = uvicorn.Server(config=uvicorn.Config(app=server_app, host="127.0.0.1", port=8765, log_level="error")) +def server(server_app: Starlette, server_port: int): + server = uvicorn.Server(config=uvicorn.Config(app=server_app, host="127.0.0.1", port=server_port, log_level="error")) server_thread = threading.Thread( target=server.run, daemon=True ) - print('starting server') + print(f'starting server on {server_port}') server_thread.start() # Give server time to start while not server.started: @@ -92,9 +104,9 @@ def server(server_app: Starlette): server_thread.join(timeout=0.1) @pytest.fixture() -async def http_client(server) -> AsyncGenerator[httpx.AsyncClient, None]: +async def http_client(server, server_url) -> AsyncGenerator[httpx.AsyncClient, None]: """Create test client""" - async with httpx.AsyncClient(base_url=SERVER_URL) as client: + async with httpx.AsyncClient(base_url=server_url) as client: yield client # Tests @@ -123,8 +135,8 @@ async def test_raw_sse_connection(http_client: httpx.AsyncClient): @pytest.mark.anyio -async def test_sse_client_basic_connection(server): - async with sse_client(SERVER_SSE_URL) as streams: +async def test_sse_client_basic_connection(server, server_url): + async with sse_client(server_url + "/sse") as streams: async with ClientSession(*streams) as session: # Test initialization result = await session.initialize() @@ -136,18 +148,22 @@ async def test_sse_client_basic_connection(server): assert isinstance(ping_result, EmptyResult) @pytest.fixture -async def initialized_sse_client_session(server) -> AsyncGenerator[ClientSession, None]: - async with sse_client(SERVER_SSE_URL) as streams: +async def initialized_sse_client_session(server, server_url: str) -> AsyncGenerator[ClientSession, None]: + async with sse_client(server_url + "/sse") as streams: async with ClientSession(*streams) as session: await session.initialize() yield session @pytest.mark.anyio -async def test_sse_client_request_and_response(initialized_sse_client_session: ClientSession): +async def test_sse_client_happy_request_and_response(initialized_sse_client_session: ClientSession): session = initialized_sse_client_session - # TODO: expect raise - await session.read_resource(uri=AnyUrl("xxx://will-not-work")) response = await session.read_resource(uri=AnyUrl("foobar://should-work")) assert len(response.contents) == 1 assert isinstance(response.contents[0], TextResourceContents) assert response.contents[0].text == "Read should-work" + +@pytest.mark.anyio +async def test_sse_client_exception_handling(initialized_sse_client_session: ClientSession): + session = initialized_sse_client_session + with pytest.raises(McpError, match="OOPS! no resource with that URI was found"): + await session.read_resource(uri=AnyUrl("xxx://will-not-work"))