passing SSE client test

This commit is contained in:
Nick Merrill
2025-01-14 10:34:32 -05:00
parent 66ccd1c515
commit e79a56435a

View File

@@ -3,20 +3,36 @@ import re
import time
import json
import anyio
from pydantic import AnyUrl
from pydantic_core import Url
import pytest
import httpx
from typing import AsyncGenerator
from starlette.applications import Starlette
from starlette.routing import Mount, Route
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.server import Server
from mcp.server.sse import SseServerTransport
from mcp.types import TextContent, Tool
from mcp.types import EmptyResult, InitializeResult, TextContent, TextResourceContents, Tool
SERVER_URL = "http://127.0.0.1:8765"
SERVER_SSE_URL = f"{SERVER_URL}/sse"
SERVER_NAME = "test_server_for_SSE"
# Test server implementation
class TestServer(Server):
def __init__(self):
super().__init__("test_server")
super().__init__(SERVER_NAME)
@self.read_resource()
async def handle_read_resource(uri: AnyUrl) -> str | bytes:
if uri.scheme == "foobar":
return f"Read {uri.host}"
# TODO: make this an error
return "NOT FOUND"
@self.list_tools()
async def handle_list_tools():
@@ -76,18 +92,18 @@ def server(server_app: Starlette):
server_thread.join(timeout=0.1)
@pytest.fixture()
async def client(server) -> AsyncGenerator[httpx.AsyncClient, None]:
async def http_client(server) -> AsyncGenerator[httpx.AsyncClient, None]:
"""Create test client"""
async with httpx.AsyncClient(base_url="http://127.0.0.1:8765") as client:
async with httpx.AsyncClient(base_url=SERVER_URL) as client:
yield client
# Tests
@pytest.mark.anyio
async def test_sse_connection(client: httpx.AsyncClient):
"""Test SSE connection establishment"""
async def test_raw_sse_connection(http_client: httpx.AsyncClient):
"""Test the SSE connection establishment simply with an HTTP client."""
async with anyio.create_task_group() as tg:
async def connection_test():
async with client.stream("GET", "/sse") as response:
async with http_client.stream("GET", "/sse") as response:
assert response.status_code == 200
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
@@ -105,84 +121,33 @@ async def test_sse_connection(client: httpx.AsyncClient):
with anyio.fail_after(3):
await connection_test()
@pytest.mark.anyio
async def test_message_exchange(client: httpx.AsyncClient):
"""Test full message exchange flow"""
# Connect to SSE endpoint
session_id = None
endpoint_url = None
async with client.stream("GET", "/sse") as sse_response:
assert sse_response.status_code == 200
# Get endpoint URL and session ID
async for line in sse_response.aiter_lines():
if line.startswith("data: "):
endpoint_url = json.loads(line[6:])
session_id = endpoint_url.split("session_id=")[1]
break
assert endpoint_url and session_id
# Send initialize request
init_request = {
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
"clientInfo": {
"name": "test_client",
"version": "1.0"
}
}
}
response = await client.post(
endpoint_url,
json=init_request
)
assert response.status_code == 202
# Get initialize response from SSE stream
async for line in sse_response.aiter_lines():
if line.startswith("event: message"):
data_line = next(sse_response.aiter_lines())
response = json.loads(data_line[6:]) # Strip "data: " prefix
assert response["jsonrpc"] == "2.0"
assert response["id"] == 1
assert "result" in response
break
@pytest.mark.anyio
async def test_invalid_session(client: httpx.AsyncClient):
"""Test sending message with invalid session ID"""
response = await client.post(
"/messages/?session_id=invalid",
json={"jsonrpc": "2.0", "method": "ping"}
)
assert response.status_code == 400
async def test_sse_client_basic_connection(server):
async with sse_client(SERVER_SSE_URL) as streams:
async with ClientSession(*streams) as session:
# Test initialization
result = await session.initialize()
assert isinstance(result, InitializeResult)
assert result.serverInfo.name == SERVER_NAME
# Test ping
ping_result = await session.send_ping()
assert isinstance(ping_result, EmptyResult)
@pytest.fixture
async def initialized_sse_client_session(server) -> AsyncGenerator[ClientSession, None]:
async with sse_client(SERVER_SSE_URL) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
yield session
@pytest.mark.anyio
async def test_connection_cleanup(server_app):
"""Test that resources are cleaned up when client disconnects"""
sse = next(
route.app for route in server_app.routes
if isinstance(route, Mount) and route.path == "/messages/"
).transport
async with httpx.AsyncClient(app=server_app, base_url="http://test") as client:
# Connect and get session ID
async with client.stream("GET", "/sse") as response:
for line in response.iter_lines():
if line.startswith("data: "):
endpoint_url = json.loads(line[6:])
session_id = endpoint_url.split("session_id=")[1]
break
assert len(sse._read_stream_writers) == 1
# After connection closes, writer should be cleaned up
await anyio.sleep(0.1) # Give cleanup a moment
assert len(sse._read_stream_writers) == 0
async def test_sse_client_request_and_response(initialized_sse_client_session: ClientSession):
session = initialized_sse_client_session
# TODO: expect raise
await session.read_resource(uri=AnyUrl("xxx://will-not-work"))
response = await session.read_resource(uri=AnyUrl("foobar://should-work"))
assert len(response.contents) == 1
assert isinstance(response.contents[0], TextResourceContents)
assert response.contents[0].text == "Read should-work"