test_sse_connection is passing

This commit is contained in:
Nick Merrill
2025-01-14 10:18:11 -05:00
parent a0e2f7fab7
commit 66ccd1c515
2 changed files with 188 additions and 173 deletions

View File

@@ -1,173 +0,0 @@
import anyio
import asyncio
import pytest
from starlette.applications import Starlette
from starlette.routing import Mount, Route
import httpx
from httpx import ReadTimeout, ASGITransport
from starlette.responses import Response
from sse_starlette.sse import EventSourceResponse
from mcp.client.sse import sse_client
from mcp.server.sse import SseServerTransport
from mcp.types import JSONRPCMessage
@pytest.fixture
async def sse_transport():
"""Fixture that creates an SSE transport instance."""
return SseServerTransport("/messages/")
@pytest.fixture
async def sse_app(sse_transport):
"""Fixture that creates a Starlette app with SSE endpoints."""
async def handle_sse(request):
"""Handler for SSE connections."""
async def event_generator():
try:
async with sse_transport.connect_sse(
request.scope, request.receive, request._send
) as streams:
client_to_server, server_to_client = streams
# Send initial connection event
yield {
"event": "endpoint",
"data": "/messages",
}
# Process messages
async with anyio.create_task_group() as tg:
try:
async for message in client_to_server:
if isinstance(message, Exception):
break
yield {
"event": "message",
"data": message.model_dump_json(),
}
except (asyncio.CancelledError, GeneratorExit):
print('cancelled')
return
except Exception as e:
print("unhandled exception:", e)
return
except Exception:
# Log any unexpected errors but allow connection to close gracefully
pass
return EventSourceResponse(event_generator())
routes = [
Route("/sse", endpoint=handle_sse),
Mount("/messages", app=sse_transport.handle_post_message),
]
return Starlette(routes=routes)
@pytest.fixture
async def test_client(sse_app):
"""Create a test client with ASGI transport."""
transport = ASGITransport(app=sse_app)
async with httpx.AsyncClient(
transport=transport,
base_url="http://testserver",
timeout=10.0,
) as client:
yield client
@pytest.mark.anyio
async def test_sse_connection(test_client):
"""Test basic SSE connection and message exchange."""
async with anyio.create_task_group() as tg:
try:
async with sse_client(
"http://testserver/sse",
headers={"Host": "testserver"},
timeout=5,
sse_read_timeout=5,
client=test_client,
) as (read_stream, write_stream):
# First get the initial endpoint message
async with read_stream:
init_message = await read_stream.__anext__()
assert isinstance(init_message, JSONRPCMessage)
# Send a test message
test_message = JSONRPCMessage.model_validate({"jsonrpc": "2.0", "method": "test"})
await write_stream.send(test_message)
# Receive echoed message
async with read_stream:
message = await read_stream.__anext__()
assert isinstance(message, JSONRPCMessage)
assert message.model_dump() == test_message.model_dump()
# Explicitly close streams
await write_stream.aclose()
await read_stream.aclose()
except Exception as e:
pytest.fail(f"Test failed with error: {str(e)}")
# @pytest.mark.anyio
# async def test_sse_read_timeout(test_client):
# """Test that SSE client properly handles read timeouts."""
# with pytest.raises(ReadTimeout):
# async with sse_client(
# "http://testserver/sse",
# headers={"Host": "testserver"},
# timeout=5,
# sse_read_timeout=2,
# client=test_client,
# ) as (read_stream, write_stream):
# async with read_stream:
# # This should timeout since no messages are being sent
# await read_stream.__anext__()
# @pytest.mark.anyio
# async def test_sse_connection_error(test_client):
# """Test SSE client behavior with connection errors."""
# with pytest.raises(httpx.HTTPError):
# async with sse_client(
# "http://testserver/nonexistent",
# headers={"Host": "testserver"},
# timeout=5,
# client=test_client,
# ):
# pass # Should not reach here
# @pytest.mark.anyio
# async def test_sse_multiple_messages(test_client):
# """Test sending and receiving multiple SSE messages."""
# async with sse_client(
# "http://testserver/sse",
# headers={"Host": "testserver"},
# timeout=5,
# sse_read_timeout=5,
# client=test_client,
# ) as (read_stream, write_stream):
# # Send multiple test messages
# messages = [
# JSONRPCMessage.model_validate({"jsonrpc": "2.0", "method": f"test{i}"})
# for i in range(3)
# ]
# for msg in messages:
# await write_stream.send(msg)
# # Receive all echoed messages
# received = []
# async with read_stream:
# for _ in range(len(messages)):
# message = await read_stream.__anext__()
# assert isinstance(message, JSONRPCMessage)
# received.append(message)
# # Verify all messages were received in order
# for sent, received in zip(messages, received):
# assert sent.model_dump() == received.model_dump()

188
tests/shared/test_sse.py Normal file
View File

@@ -0,0 +1,188 @@
# test_sse.py
import re
import time
import json
import anyio
import pytest
import httpx
from typing import AsyncGenerator
from starlette.applications import Starlette
from starlette.routing import Mount, Route
from mcp.server import Server
from mcp.server.sse import SseServerTransport
from mcp.types import TextContent, Tool
# Test server implementation
class TestServer(Server):
def __init__(self):
super().__init__("test_server")
@self.list_tools()
async def handle_list_tools():
return [
Tool(
name="test_tool",
description="A test tool",
inputSchema={"type": "object", "properties": {}}
)
]
@self.call_tool()
async def handle_call_tool(name: str, args: dict):
return [TextContent(type="text", text=f"Called {name}")]
import threading
import uvicorn
import pytest
# Test fixtures
@pytest.fixture
async def server_app()-> Starlette:
"""Create test Starlette app with SSE transport"""
sse = SseServerTransport("/messages/")
server = TestServer()
async def handle_sse(request):
async with sse.connect_sse(
request.scope, request.receive, request._send
) as streams:
await server.run(
streams[0],
streams[1],
server.create_initialization_options()
)
app = Starlette(routes=[
Route("/sse", endpoint=handle_sse),
Mount("/messages/", app=sse.handle_post_message),
])
return app
@pytest.fixture()
def server(server_app: Starlette):
server = uvicorn.Server(config=uvicorn.Config(app=server_app, host="127.0.0.1", port=8765, log_level="error"))
server_thread = threading.Thread( target=server.run, daemon=True )
print('starting server')
server_thread.start()
# Give server time to start
while not server.started:
print('waiting for server to start')
time.sleep(0.5)
yield
print('killing server')
server_thread.join(timeout=0.1)
@pytest.fixture()
async def client(server) -> AsyncGenerator[httpx.AsyncClient, None]:
"""Create test client"""
async with httpx.AsyncClient(base_url="http://127.0.0.1:8765") as client:
yield client
# Tests
@pytest.mark.anyio
async def test_sse_connection(client: httpx.AsyncClient):
"""Test SSE connection establishment"""
async with anyio.create_task_group() as tg:
async def connection_test():
async with client.stream("GET", "/sse") as response:
assert response.status_code == 200
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
line_number = 0
async for line in response.aiter_lines():
if line_number == 0:
assert line == "event: endpoint"
elif line_number == 1:
assert line.startswith("data: /messages/?session_id=")
else:
return
line_number += 1
# Add timeout to prevent test from hanging if it fails
with anyio.fail_after(3):
await connection_test()
@pytest.mark.anyio
async def test_message_exchange(client: httpx.AsyncClient):
"""Test full message exchange flow"""
# Connect to SSE endpoint
session_id = None
endpoint_url = None
async with client.stream("GET", "/sse") as sse_response:
assert sse_response.status_code == 200
# Get endpoint URL and session ID
async for line in sse_response.aiter_lines():
if line.startswith("data: "):
endpoint_url = json.loads(line[6:])
session_id = endpoint_url.split("session_id=")[1]
break
assert endpoint_url and session_id
# Send initialize request
init_request = {
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
"clientInfo": {
"name": "test_client",
"version": "1.0"
}
}
}
response = await client.post(
endpoint_url,
json=init_request
)
assert response.status_code == 202
# Get initialize response from SSE stream
async for line in sse_response.aiter_lines():
if line.startswith("event: message"):
data_line = next(sse_response.aiter_lines())
response = json.loads(data_line[6:]) # Strip "data: " prefix
assert response["jsonrpc"] == "2.0"
assert response["id"] == 1
assert "result" in response
break
@pytest.mark.anyio
async def test_invalid_session(client: httpx.AsyncClient):
"""Test sending message with invalid session ID"""
response = await client.post(
"/messages/?session_id=invalid",
json={"jsonrpc": "2.0", "method": "ping"}
)
assert response.status_code == 400
@pytest.mark.anyio
async def test_connection_cleanup(server_app):
"""Test that resources are cleaned up when client disconnects"""
sse = next(
route.app for route in server_app.routes
if isinstance(route, Mount) and route.path == "/messages/"
).transport
async with httpx.AsyncClient(app=server_app, base_url="http://test") as client:
# Connect and get session ID
async with client.stream("GET", "/sse") as response:
for line in response.iter_lines():
if line.startswith("data: "):
endpoint_url = json.loads(line[6:])
session_id = endpoint_url.split("session_id=")[1]
break
assert len(sse._read_stream_writers) == 1
# After connection closes, writer should be cleaned up
await anyio.sleep(0.1) # Give cleanup a moment
assert len(sse._read_stream_writers) == 0