This commit is contained in:
Nick Merrill
2025-01-14 09:24:44 -05:00
parent f164291483
commit b0a6aafaf6
2 changed files with 117 additions and 55 deletions

View File

@@ -1,82 +1,127 @@
import pytest
import anyio
import pytest
from starlette.applications import Starlette
from starlette.routing import Mount, Route
import uvicorn
from mcp.client.sse import sse_client
from exceptiongroup import ExceptionGroup
import asyncio
import httpx
from httpx import ReadTimeout
from httpx import ReadTimeout, ASGITransport
from mcp.client.sse import sse_client
from mcp.server.sse import SseServerTransport
from mcp.types import JSONRPCMessage
@pytest.fixture
async def sse_server():
async def sse_transport():
"""Fixture that creates an SSE transport instance."""
return SseServerTransport("/messages/")
# Create an SSE transport at an endpoint
sse = SseServerTransport("/messages/")
# Create Starlette routes for SSE and message handling
routes = [
Route("/sse", endpoint=handle_sse),
Mount("/messages/", app=sse.handle_post_message),
]
#
# Create and run Starlette app
app = Starlette(routes=routes)
# Define handler functions
@pytest.fixture
async def sse_app(sse_transport):
"""Fixture that creates a Starlette app with SSE endpoints."""
async def handle_sse(request):
async with sse.connect_sse(
"""Handler for SSE connections."""
async with sse_transport.connect_sse(
request.scope, request.receive, request._send
) as streams:
await app.run(
streams[0], streams[1], app.create_initialization_options()
)
client_to_server, server_to_client = streams
async for message in client_to_server:
# Echo messages back for testing
await server_to_client.send(message)
uvicorn.run(app, host="127.0.0.1", port=34891)
routes = [
Route("/sse", endpoint=handle_sse),
Mount("/messages", app=sse_transport.handle_post_message),
]
async def sse_handler(request):
response = httpx.Response(200, content_type="text/event-stream")
response.send_headers()
response.write("data: test\n\n")
await response.aclose()
async with httpx.AsyncServer(sse_handler) as server:
yield server.url
return Starlette(routes=routes)
@pytest.fixture
async def sse_client():
async with sse_client("http://test/sse") as (read_stream, write_stream):
async with read_stream:
async for message in read_stream:
if isinstance(message, Exception):
raise message
async def test_client(sse_app):
"""Create a test client with ASGI transport."""
async with httpx.AsyncClient(
transport=ASGITransport(app=sse_app),
base_url="http://testserver",
) as client:
yield client
return read_stream, write_stream
@pytest.mark.anyio
async def test_sse_happy_path(monkeypatch):
# Mock httpx.AsyncClient to return our mock response
monkeypatch.setattr(httpx, "AsyncClient", MockClient)
async def test_sse_connection(test_client):
"""Test basic SSE connection and message exchange."""
async with sse_client(
"http://testserver/sse",
headers={"Host": "testserver"},
timeout=5,
client=test_client,
) as (read_stream, write_stream):
# Send a test message
test_message = JSONRPCMessage.model_validate({"jsonrpc": "2.0", "method": "test"})
await write_stream.send(test_message)
with pytest.raises(ReadTimeout) as exc_info:
# Receive echoed message
async with read_stream:
message = await read_stream.__anext__()
assert isinstance(message, JSONRPCMessage)
assert message.model_dump() == test_message.model_dump()
@pytest.mark.anyio
async def test_sse_read_timeout(test_client):
"""Test that SSE client properly handles read timeouts."""
with pytest.raises(ReadTimeout):
async with sse_client(
"http://test/sse",
timeout=5, # Connection timeout - make this longer
sse_read_timeout=1 # Read timeout - this should trigger
"http://testserver/sse",
headers={"Host": "testserver"},
timeout=5,
sse_read_timeout=1,
client=test_client,
) as (read_stream, write_stream):
async with read_stream:
async for message in read_stream:
if isinstance(message, Exception):
raise message
# This should timeout since no messages are being sent
await read_stream.__anext__()
error = exc_info.value
assert isinstance(error, ReadTimeout)
assert str(error) == "Read timeout"
@pytest.mark.anyio
async def test_sse_read_timeouts(monkeypatch):
"""Test that the SSE client properly handles read timeouts between SSE messages."""
async def test_sse_connection_error(test_client):
"""Test SSE client behavior with connection errors."""
with pytest.raises(httpx.HTTPError):
async with sse_client(
"http://testserver/nonexistent",
headers={"Host": "testserver"},
timeout=5,
client=test_client,
):
pass # Should not reach here
@pytest.mark.anyio
async def test_sse_multiple_messages(test_client):
"""Test sending and receiving multiple SSE messages."""
async with sse_client(
"http://testserver/sse",
headers={"Host": "testserver"},
timeout=5,
client=test_client,
) as (read_stream, write_stream):
# Send multiple test messages
messages = [
JSONRPCMessage.model_validate({"jsonrpc": "2.0", "method": f"test{i}"})
for i in range(3)
]
for msg in messages:
await write_stream.send(msg)
# Receive all echoed messages
received = []
async with read_stream:
for _ in range(len(messages)):
message = await read_stream.__anext__()
assert isinstance(message, JSONRPCMessage)
received.append(message)
# Verify all messages were received in order
for sent, received in zip(messages, received):
assert sent.model_dump() == received.model_dump()