This commit is contained in:
Nick Merrill
2025-01-14 09:27:42 -05:00
parent b0a6aafaf6
commit 3f9f7c8311

View File

@@ -4,6 +4,8 @@ from starlette.applications import Starlette
from starlette.routing import Mount, Route
import httpx
from httpx import ReadTimeout, ASGITransport
from starlette.responses import Response
from sse_starlette.sse import EventSourceResponse
from mcp.client.sse import sse_client
from mcp.server.sse import SseServerTransport
@@ -21,17 +23,33 @@ async def sse_app(sse_transport):
"""Fixture that creates a Starlette app with SSE endpoints."""
async def handle_sse(request):
"""Handler for SSE connections."""
async def event_generator():
# Send initial connection event
yield {
"event": "endpoint",
"data": "/messages",
}
# Keep connection alive
async with sse_transport.connect_sse(
request.scope, request.receive, request._send
) as streams:
client_to_server, server_to_client = streams
async for message in client_to_server:
# Echo messages back for testing
await server_to_client.send(message)
yield {
"event": "message",
"data": message.model_dump_json(),
}
return EventSourceResponse(event_generator())
async def handle_post(request):
"""Handler for POST messages."""
return Response(status_code=200)
routes = [
Route("/sse", endpoint=handle_sse),
Mount("/messages", app=sse_transport.handle_post_message),
Route("/messages", endpoint=handle_post, methods=["POST"]),
]
return Starlette(routes=routes)
@@ -40,9 +58,11 @@ async def sse_app(sse_transport):
@pytest.fixture
async def test_client(sse_app):
"""Create a test client with ASGI transport."""
transport = ASGITransport(app=sse_app)
async with httpx.AsyncClient(
transport=ASGITransport(app=sse_app),
transport=transport,
base_url="http://testserver",
timeout=5.0,
) as client:
yield client
@@ -53,7 +73,8 @@ async def test_sse_connection(test_client):
async with sse_client(
"http://testserver/sse",
headers={"Host": "testserver"},
timeout=5,
timeout=2,
sse_read_timeout=1,
client=test_client,
) as (read_stream, write_stream):
# Send a test message
@@ -74,7 +95,7 @@ async def test_sse_read_timeout(test_client):
async with sse_client(
"http://testserver/sse",
headers={"Host": "testserver"},
timeout=5,
timeout=2,
sse_read_timeout=1,
client=test_client,
) as (read_stream, write_stream):
@@ -90,7 +111,7 @@ async def test_sse_connection_error(test_client):
async with sse_client(
"http://testserver/nonexistent",
headers={"Host": "testserver"},
timeout=5,
timeout=2,
client=test_client,
):
pass # Should not reach here
@@ -102,7 +123,8 @@ async def test_sse_multiple_messages(test_client):
async with sse_client(
"http://testserver/sse",
headers={"Host": "testserver"},
timeout=5,
timeout=2,
sse_read_timeout=1,
client=test_client,
) as (read_stream, write_stream):
# Send multiple test messages