passing SSE client test

This commit is contained in:
Nick Merrill
2025-01-14 10:34:32 -05:00
parent 66ccd1c515
commit e79a56435a

View File

@@ -3,20 +3,36 @@ import re
import time import time
import json import json
import anyio import anyio
from pydantic import AnyUrl
from pydantic_core import Url
import pytest import pytest
import httpx import httpx
from typing import AsyncGenerator from typing import AsyncGenerator
from starlette.applications import Starlette from starlette.applications import Starlette
from starlette.routing import Mount, Route from starlette.routing import Mount, Route
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.server import Server from mcp.server import Server
from mcp.server.sse import SseServerTransport from mcp.server.sse import SseServerTransport
from mcp.types import TextContent, Tool from mcp.types import EmptyResult, InitializeResult, TextContent, TextResourceContents, Tool
SERVER_URL = "http://127.0.0.1:8765"
SERVER_SSE_URL = f"{SERVER_URL}/sse"
SERVER_NAME = "test_server_for_SSE"
# Test server implementation # Test server implementation
class TestServer(Server): class TestServer(Server):
def __init__(self): def __init__(self):
super().__init__("test_server") super().__init__(SERVER_NAME)
@self.read_resource()
async def handle_read_resource(uri: AnyUrl) -> str | bytes:
if uri.scheme == "foobar":
return f"Read {uri.host}"
# TODO: make this an error
return "NOT FOUND"
@self.list_tools() @self.list_tools()
async def handle_list_tools(): async def handle_list_tools():
@@ -76,18 +92,18 @@ def server(server_app: Starlette):
server_thread.join(timeout=0.1) server_thread.join(timeout=0.1)
@pytest.fixture() @pytest.fixture()
async def client(server) -> AsyncGenerator[httpx.AsyncClient, None]: async def http_client(server) -> AsyncGenerator[httpx.AsyncClient, None]:
"""Create test client""" """Create test client"""
async with httpx.AsyncClient(base_url="http://127.0.0.1:8765") as client: async with httpx.AsyncClient(base_url=SERVER_URL) as client:
yield client yield client
# Tests # Tests
@pytest.mark.anyio @pytest.mark.anyio
async def test_sse_connection(client: httpx.AsyncClient): async def test_raw_sse_connection(http_client: httpx.AsyncClient):
"""Test SSE connection establishment""" """Test the SSE connection establishment simply with an HTTP client."""
async with anyio.create_task_group() as tg: async with anyio.create_task_group() as tg:
async def connection_test(): async def connection_test():
async with client.stream("GET", "/sse") as response: async with http_client.stream("GET", "/sse") as response:
assert response.status_code == 200 assert response.status_code == 200
assert response.headers["content-type"] == "text/event-stream; charset=utf-8" assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
@@ -105,84 +121,33 @@ async def test_sse_connection(client: httpx.AsyncClient):
with anyio.fail_after(3): with anyio.fail_after(3):
await connection_test() await connection_test()
@pytest.mark.anyio
async def test_message_exchange(client: httpx.AsyncClient):
"""Test full message exchange flow"""
# Connect to SSE endpoint
session_id = None
endpoint_url = None
async with client.stream("GET", "/sse") as sse_response:
assert sse_response.status_code == 200
# Get endpoint URL and session ID
async for line in sse_response.aiter_lines():
if line.startswith("data: "):
endpoint_url = json.loads(line[6:])
session_id = endpoint_url.split("session_id=")[1]
break
assert endpoint_url and session_id
# Send initialize request
init_request = {
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
"clientInfo": {
"name": "test_client",
"version": "1.0"
}
}
}
response = await client.post(
endpoint_url,
json=init_request
)
assert response.status_code == 202
# Get initialize response from SSE stream
async for line in sse_response.aiter_lines():
if line.startswith("event: message"):
data_line = next(sse_response.aiter_lines())
response = json.loads(data_line[6:]) # Strip "data: " prefix
assert response["jsonrpc"] == "2.0"
assert response["id"] == 1
assert "result" in response
break
@pytest.mark.anyio @pytest.mark.anyio
async def test_invalid_session(client: httpx.AsyncClient): async def test_sse_client_basic_connection(server):
"""Test sending message with invalid session ID""" async with sse_client(SERVER_SSE_URL) as streams:
response = await client.post( async with ClientSession(*streams) as session:
"/messages/?session_id=invalid", # Test initialization
json={"jsonrpc": "2.0", "method": "ping"} result = await session.initialize()
) assert isinstance(result, InitializeResult)
assert response.status_code == 400 assert result.serverInfo.name == SERVER_NAME
# Test ping
ping_result = await session.send_ping()
assert isinstance(ping_result, EmptyResult)
@pytest.fixture
async def initialized_sse_client_session(server) -> AsyncGenerator[ClientSession, None]:
async with sse_client(SERVER_SSE_URL) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
yield session
@pytest.mark.anyio @pytest.mark.anyio
async def test_connection_cleanup(server_app): async def test_sse_client_request_and_response(initialized_sse_client_session: ClientSession):
"""Test that resources are cleaned up when client disconnects""" session = initialized_sse_client_session
sse = next( # TODO: expect raise
route.app for route in server_app.routes await session.read_resource(uri=AnyUrl("xxx://will-not-work"))
if isinstance(route, Mount) and route.path == "/messages/" response = await session.read_resource(uri=AnyUrl("foobar://should-work"))
).transport assert len(response.contents) == 1
assert isinstance(response.contents[0], TextResourceContents)
async with httpx.AsyncClient(app=server_app, base_url="http://test") as client: assert response.contents[0].text == "Read should-work"
# Connect and get session ID
async with client.stream("GET", "/sse") as response:
for line in response.iter_lines():
if line.startswith("data: "):
endpoint_url = json.loads(line[6:])
session_id = endpoint_url.split("session_id=")[1]
break
assert len(sse._read_stream_writers) == 1
# After connection closes, writer should be cleaned up
await anyio.sleep(0.1) # Give cleanup a moment
assert len(sse._read_stream_writers) == 0