This commit is contained in:
Nick Merrill
2025-01-14 09:27:42 -05:00
parent b0a6aafaf6
commit 3f9f7c8311

View File

@@ -4,6 +4,8 @@ from starlette.applications import Starlette
from starlette.routing import Mount, Route from starlette.routing import Mount, Route
import httpx import httpx
from httpx import ReadTimeout, ASGITransport from httpx import ReadTimeout, ASGITransport
from starlette.responses import Response
from sse_starlette.sse import EventSourceResponse
from mcp.client.sse import sse_client from mcp.client.sse import sse_client
from mcp.server.sse import SseServerTransport from mcp.server.sse import SseServerTransport
@@ -21,17 +23,33 @@ async def sse_app(sse_transport):
"""Fixture that creates a Starlette app with SSE endpoints.""" """Fixture that creates a Starlette app with SSE endpoints."""
async def handle_sse(request): async def handle_sse(request):
"""Handler for SSE connections.""" """Handler for SSE connections."""
async def event_generator():
# Send initial connection event
yield {
"event": "endpoint",
"data": "/messages",
}
# Keep connection alive
async with sse_transport.connect_sse( async with sse_transport.connect_sse(
request.scope, request.receive, request._send request.scope, request.receive, request._send
) as streams: ) as streams:
client_to_server, server_to_client = streams client_to_server, server_to_client = streams
async for message in client_to_server: async for message in client_to_server:
# Echo messages back for testing yield {
await server_to_client.send(message) "event": "message",
"data": message.model_dump_json(),
}
return EventSourceResponse(event_generator())
async def handle_post(request):
"""Handler for POST messages."""
return Response(status_code=200)
routes = [ routes = [
Route("/sse", endpoint=handle_sse), Route("/sse", endpoint=handle_sse),
Mount("/messages", app=sse_transport.handle_post_message), Route("/messages", endpoint=handle_post, methods=["POST"]),
] ]
return Starlette(routes=routes) return Starlette(routes=routes)
@@ -40,9 +58,11 @@ async def sse_app(sse_transport):
@pytest.fixture @pytest.fixture
async def test_client(sse_app): async def test_client(sse_app):
"""Create a test client with ASGI transport.""" """Create a test client with ASGI transport."""
transport = ASGITransport(app=sse_app)
async with httpx.AsyncClient( async with httpx.AsyncClient(
transport=ASGITransport(app=sse_app), transport=transport,
base_url="http://testserver", base_url="http://testserver",
timeout=5.0,
) as client: ) as client:
yield client yield client
@@ -53,7 +73,8 @@ async def test_sse_connection(test_client):
async with sse_client( async with sse_client(
"http://testserver/sse", "http://testserver/sse",
headers={"Host": "testserver"}, headers={"Host": "testserver"},
timeout=5, timeout=2,
sse_read_timeout=1,
client=test_client, client=test_client,
) as (read_stream, write_stream): ) as (read_stream, write_stream):
# Send a test message # Send a test message
@@ -74,7 +95,7 @@ async def test_sse_read_timeout(test_client):
async with sse_client( async with sse_client(
"http://testserver/sse", "http://testserver/sse",
headers={"Host": "testserver"}, headers={"Host": "testserver"},
timeout=5, timeout=2,
sse_read_timeout=1, sse_read_timeout=1,
client=test_client, client=test_client,
) as (read_stream, write_stream): ) as (read_stream, write_stream):
@@ -90,7 +111,7 @@ async def test_sse_connection_error(test_client):
async with sse_client( async with sse_client(
"http://testserver/nonexistent", "http://testserver/nonexistent",
headers={"Host": "testserver"}, headers={"Host": "testserver"},
timeout=5, timeout=2,
client=test_client, client=test_client,
): ):
pass # Should not reach here pass # Should not reach here
@@ -102,7 +123,8 @@ async def test_sse_multiple_messages(test_client):
async with sse_client( async with sse_client(
"http://testserver/sse", "http://testserver/sse",
headers={"Host": "testserver"}, headers={"Host": "testserver"},
timeout=5, timeout=2,
sse_read_timeout=1,
client=test_client, client=test_client,
) as (read_stream, write_stream): ) as (read_stream, write_stream):
# Send multiple test messages # Send multiple test messages