Files
mcp-python-sdk/tests/shared/test_sse.py
2025-01-14 10:34:32 -05:00

154 lines
5.0 KiB
Python

# test_sse.py
import re
import time
import json
import anyio
from pydantic import AnyUrl
from pydantic_core import Url
import pytest
import httpx
from typing import AsyncGenerator
from starlette.applications import Starlette
from starlette.routing import Mount, Route
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.server import Server
from mcp.server.sse import SseServerTransport
from mcp.types import EmptyResult, InitializeResult, TextContent, TextResourceContents, Tool
# Name handed to the MCP server constructor; the client tests compare the
# reported serverInfo.name against it.
SERVER_NAME = "test_server_for_SSE"

# Base address of the local HTTP server used by these tests (same host and
# port that the `server` fixture binds uvicorn to).
SERVER_URL = "http://127.0.0.1:8765"

# URL of the SSE route exposed by the test application.
SERVER_SSE_URL = f"{SERVER_URL}/sse"
# Test server implementation
class TestServer(Server):
    """MCP server that backs the SSE transport tests.

    On construction it registers three handlers on itself:

    * a resource reader that answers ``foobar://<host>`` URIs with
      ``"Read <host>"`` and every other URI with ``"NOT FOUND"``;
    * a tool lister that advertises a single tool named ``test_tool``;
    * a tool caller that echoes back the name of the tool it was asked to run.
    """

    # The class name matches pytest's default ``Test*`` collection pattern.
    # Opt out explicitly so pytest does not try to collect this helper as a
    # test class (and warn because it defines ``__init__``).
    __test__ = False

    def __init__(self):
        super().__init__(SERVER_NAME)

        @self.read_resource()
        async def handle_read_resource(uri: AnyUrl) -> str | bytes:
            """Return a canned payload for ``foobar://`` URIs."""
            if uri.scheme == "foobar":
                return f"Read {uri.host}"
            # TODO: make this an error
            return "NOT FOUND"

        @self.list_tools()
        async def handle_list_tools():
            """Advertise the single tool this server exposes."""
            return [
                Tool(
                    name="test_tool",
                    description="A test tool",
                    inputSchema={"type": "object", "properties": {}},
                )
            ]

        @self.call_tool()
        async def handle_call_tool(name: str, args: dict):
            """Report which tool was invoked; ``args`` is ignored."""
            return [TextContent(type="text", text=f"Called {name}")]
import threading
import uvicorn
import pytest
# Test fixtures
@pytest.fixture
def server_app() -> Starlette:
    """Create test Starlette app with SSE transport.

    The app exposes two routes:

    * ``/sse`` - opens an SSE connection and runs a ``TestServer`` over the
      pair of streams produced by the transport;
    * ``/messages/`` - mounted ASGI handler that receives the client's POSTed
      messages.

    This is a plain (non-async) fixture: nothing here needs to be awaited, and
    a synchronous fixture hands a real ``Starlette`` instance to the
    synchronous ``server`` fixture without relying on an async pytest plugin
    to resolve a coroutine first.
    """
    sse = SseServerTransport("/messages/")
    server = TestServer()

    async def handle_sse(request):
        # The transport is driven with the raw ASGI triple of the request;
        # ``_send`` is the request's underlying ASGI send callable.
        async with sse.connect_sse(
            request.scope, request.receive, request._send
        ) as streams:
            await server.run(
                streams[0],
                streams[1],
                server.create_initialization_options(),
            )

    app = Starlette(
        routes=[
            Route("/sse", endpoint=handle_sse),
            Mount("/messages/", app=sse.handle_post_message),
        ]
    )
    return app
@pytest.fixture()
def server(server_app: Starlette):
    """Serve ``server_app`` with uvicorn on 127.0.0.1:8765 in a daemon thread.

    The fixture blocks until uvicorn reports that it has started, yields
    control to the test, and then asks uvicorn to shut down so the port is
    free again for the next test that requests this fixture.

    Raises:
        RuntimeError: if the server thread dies during startup (for example
            because the port is already in use) or startup takes longer than
            ``startup_timeout`` seconds.
    """
    startup_timeout = 10.0  # seconds to wait for uvicorn before giving up
    shutdown_timeout = 5.0  # seconds to wait for the thread to finish on teardown

    server = uvicorn.Server(
        config=uvicorn.Config(
            app=server_app, host="127.0.0.1", port=8765, log_level="error"
        )
    )
    server_thread = threading.Thread(target=server.run, daemon=True)
    print('starting server')
    server_thread.start()

    # Give server time to start, but never spin forever: a failed bind makes
    # the thread exit without ever setting ``started``.
    deadline = time.monotonic() + startup_timeout
    while not server.started:
        if not server_thread.is_alive():
            raise RuntimeError(
                "uvicorn server thread exited before startup completed"
            )
        if time.monotonic() > deadline:
            server.should_exit = True
            raise RuntimeError(
                f"uvicorn server did not start within {startup_timeout} seconds"
            )
        print('waiting for server to start')
        time.sleep(0.5)

    yield

    print('killing server')
    # Joining alone does not stop uvicorn; its main loop only exits once
    # ``should_exit`` is set. The join is bounded so a stuck shutdown cannot
    # hang the test run (the thread is a daemon and dies with the process).
    server.should_exit = True
    server_thread.join(timeout=shutdown_timeout)
@pytest.fixture()
async def http_client(server) -> AsyncGenerator[httpx.AsyncClient, None]:
    """Yield an ``httpx.AsyncClient`` rooted at ``SERVER_URL``.

    Requesting ``server`` ensures the test server fixture is set up before the
    client is created; the client is closed when the ``async with`` block
    exits after the test.
    """
    async with httpx.AsyncClient(base_url=SERVER_URL) as async_client:
        yield async_client
# Tests
@pytest.mark.anyio
async def test_raw_sse_connection(http_client: httpx.AsyncClient):
    """Test the SSE connection establishment simply with an HTTP client.

    Opens ``GET /sse`` as a stream, checks the status and content type, then
    reads the first two lines and checks that they form the ``endpoint`` event
    whose data points at ``/messages/`` with a session id.
    """

    async def connection_test() -> None:
        async with http_client.stream("GET", "/sse") as response:
            assert response.status_code == 200
            assert (
                response.headers["content-type"]
                == "text/event-stream; charset=utf-8"
            )

            # Only the first two lines matter; stop reading as soon as we
            # have them so the open-ended SSE stream does not keep us here.
            received: list[str] = []
            async for line in response.aiter_lines():
                received.append(line)
                if len(received) == 2:
                    break

            # Assert after the loop so a stream that ends early fails the
            # test instead of passing without checking anything.
            assert len(received) == 2, f"expected 2 SSE lines, got {received!r}"
            assert received[0] == "event: endpoint"
            assert received[1].startswith("data: /messages/?session_id=")

    # Add timeout to prevent test from hanging if it fails
    with anyio.fail_after(3):
        await connection_test()
@pytest.mark.anyio
async def test_sse_client_basic_connection(server):
    """Connect with the SSE client, initialize a session and ping the server."""
    async with (
        sse_client(SERVER_SSE_URL) as streams,
        ClientSession(*streams) as session,
    ):
        # Test initialization
        init_result = await session.initialize()
        assert isinstance(init_result, InitializeResult)
        assert init_result.serverInfo.name == SERVER_NAME

        # Test ping
        pong = await session.send_ping()
        assert isinstance(pong, EmptyResult)
@pytest.fixture
async def initialized_sse_client_session(server) -> AsyncGenerator[ClientSession, None]:
    """Yield a ``ClientSession`` connected over SSE and already initialized.

    Both the SSE connection and the session stay open while the test runs and
    are closed when the ``async with`` block exits afterwards.
    """
    async with (
        sse_client(SERVER_SSE_URL) as streams,
        ClientSession(*streams) as client_session,
    ):
        await client_session.initialize()
        yield client_session
@pytest.mark.anyio
async def test_sse_client_request_and_response(initialized_sse_client_session: ClientSession):
    """Read resources through an initialized session and check the reply.

    A URI with an unsupported scheme is requested first (its result is
    currently ignored), then a ``foobar://`` URI whose single text content
    must echo the host part of the URI.
    """
    session = initialized_sse_client_session

    # TODO: expect raise
    await session.read_resource(uri=AnyUrl("xxx://will-not-work"))

    response = await session.read_resource(uri=AnyUrl("foobar://should-work"))
    contents = response.contents
    assert len(contents) == 1

    only_item = contents[0]
    assert isinstance(only_item, TextResourceContents)
    assert only_item.text == "Read should-work"