mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 23:04:25 +01:00
test_sse_connection is passing
This commit is contained in:
@@ -1,173 +0,0 @@
|
||||
import anyio
|
||||
import asyncio
|
||||
import pytest
|
||||
from starlette.applications import Starlette
|
||||
from starlette.routing import Mount, Route
|
||||
import httpx
|
||||
from httpx import ReadTimeout, ASGITransport
|
||||
from starlette.responses import Response
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.server.sse import SseServerTransport
|
||||
from mcp.types import JSONRPCMessage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def sse_transport():
|
||||
"""Fixture that creates an SSE transport instance."""
|
||||
return SseServerTransport("/messages/")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def sse_app(sse_transport):
|
||||
"""Fixture that creates a Starlette app with SSE endpoints."""
|
||||
async def handle_sse(request):
|
||||
"""Handler for SSE connections."""
|
||||
async def event_generator():
|
||||
try:
|
||||
async with sse_transport.connect_sse(
|
||||
request.scope, request.receive, request._send
|
||||
) as streams:
|
||||
client_to_server, server_to_client = streams
|
||||
# Send initial connection event
|
||||
yield {
|
||||
"event": "endpoint",
|
||||
"data": "/messages",
|
||||
}
|
||||
|
||||
# Process messages
|
||||
async with anyio.create_task_group() as tg:
|
||||
try:
|
||||
async for message in client_to_server:
|
||||
if isinstance(message, Exception):
|
||||
break
|
||||
yield {
|
||||
"event": "message",
|
||||
"data": message.model_dump_json(),
|
||||
}
|
||||
except (asyncio.CancelledError, GeneratorExit):
|
||||
print('cancelled')
|
||||
return
|
||||
except Exception as e:
|
||||
print("unhandled exception:", e)
|
||||
return
|
||||
except Exception:
|
||||
# Log any unexpected errors but allow connection to close gracefully
|
||||
pass
|
||||
|
||||
return EventSourceResponse(event_generator())
|
||||
|
||||
routes = [
|
||||
Route("/sse", endpoint=handle_sse),
|
||||
Mount("/messages", app=sse_transport.handle_post_message),
|
||||
]
|
||||
|
||||
return Starlette(routes=routes)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_client(sse_app):
|
||||
"""Create a test client with ASGI transport."""
|
||||
transport = ASGITransport(app=sse_app)
|
||||
async with httpx.AsyncClient(
|
||||
transport=transport,
|
||||
base_url="http://testserver",
|
||||
timeout=10.0,
|
||||
) as client:
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_sse_connection(test_client):
|
||||
"""Test basic SSE connection and message exchange."""
|
||||
async with anyio.create_task_group() as tg:
|
||||
try:
|
||||
async with sse_client(
|
||||
"http://testserver/sse",
|
||||
headers={"Host": "testserver"},
|
||||
timeout=5,
|
||||
sse_read_timeout=5,
|
||||
client=test_client,
|
||||
) as (read_stream, write_stream):
|
||||
# First get the initial endpoint message
|
||||
async with read_stream:
|
||||
init_message = await read_stream.__anext__()
|
||||
assert isinstance(init_message, JSONRPCMessage)
|
||||
|
||||
# Send a test message
|
||||
test_message = JSONRPCMessage.model_validate({"jsonrpc": "2.0", "method": "test"})
|
||||
await write_stream.send(test_message)
|
||||
|
||||
# Receive echoed message
|
||||
async with read_stream:
|
||||
message = await read_stream.__anext__()
|
||||
assert isinstance(message, JSONRPCMessage)
|
||||
assert message.model_dump() == test_message.model_dump()
|
||||
|
||||
# Explicitly close streams
|
||||
await write_stream.aclose()
|
||||
await read_stream.aclose()
|
||||
except Exception as e:
|
||||
pytest.fail(f"Test failed with error: {str(e)}")
|
||||
|
||||
|
||||
# @pytest.mark.anyio
|
||||
# async def test_sse_read_timeout(test_client):
|
||||
# """Test that SSE client properly handles read timeouts."""
|
||||
# with pytest.raises(ReadTimeout):
|
||||
# async with sse_client(
|
||||
# "http://testserver/sse",
|
||||
# headers={"Host": "testserver"},
|
||||
# timeout=5,
|
||||
# sse_read_timeout=2,
|
||||
# client=test_client,
|
||||
# ) as (read_stream, write_stream):
|
||||
# async with read_stream:
|
||||
# # This should timeout since no messages are being sent
|
||||
# await read_stream.__anext__()
|
||||
|
||||
|
||||
# @pytest.mark.anyio
|
||||
# async def test_sse_connection_error(test_client):
|
||||
# """Test SSE client behavior with connection errors."""
|
||||
# with pytest.raises(httpx.HTTPError):
|
||||
# async with sse_client(
|
||||
# "http://testserver/nonexistent",
|
||||
# headers={"Host": "testserver"},
|
||||
# timeout=5,
|
||||
# client=test_client,
|
||||
# ):
|
||||
# pass # Should not reach here
|
||||
|
||||
|
||||
# @pytest.mark.anyio
|
||||
# async def test_sse_multiple_messages(test_client):
|
||||
# """Test sending and receiving multiple SSE messages."""
|
||||
# async with sse_client(
|
||||
# "http://testserver/sse",
|
||||
# headers={"Host": "testserver"},
|
||||
# timeout=5,
|
||||
# sse_read_timeout=5,
|
||||
# client=test_client,
|
||||
# ) as (read_stream, write_stream):
|
||||
# # Send multiple test messages
|
||||
# messages = [
|
||||
# JSONRPCMessage.model_validate({"jsonrpc": "2.0", "method": f"test{i}"})
|
||||
# for i in range(3)
|
||||
# ]
|
||||
|
||||
# for msg in messages:
|
||||
# await write_stream.send(msg)
|
||||
|
||||
# # Receive all echoed messages
|
||||
# received = []
|
||||
# async with read_stream:
|
||||
# for _ in range(len(messages)):
|
||||
# message = await read_stream.__anext__()
|
||||
# assert isinstance(message, JSONRPCMessage)
|
||||
# received.append(message)
|
||||
|
||||
# # Verify all messages were received in order
|
||||
# for sent, received in zip(messages, received):
|
||||
# assert sent.model_dump() == received.model_dump()
|
||||
188
tests/shared/test_sse.py
Normal file
188
tests/shared/test_sse.py
Normal file
@@ -0,0 +1,188 @@
|
||||
# test_sse.py
|
||||
import re
|
||||
import time
|
||||
import json
|
||||
import anyio
|
||||
import pytest
|
||||
import httpx
|
||||
from typing import AsyncGenerator
|
||||
from starlette.applications import Starlette
|
||||
from starlette.routing import Mount, Route
|
||||
|
||||
from mcp.server import Server
|
||||
from mcp.server.sse import SseServerTransport
|
||||
from mcp.types import TextContent, Tool
|
||||
|
||||
# Test server implementation
|
||||
class TestServer(Server):
|
||||
def __init__(self):
|
||||
super().__init__("test_server")
|
||||
|
||||
@self.list_tools()
|
||||
async def handle_list_tools():
|
||||
return [
|
||||
Tool(
|
||||
name="test_tool",
|
||||
description="A test tool",
|
||||
inputSchema={"type": "object", "properties": {}}
|
||||
)
|
||||
]
|
||||
|
||||
@self.call_tool()
|
||||
async def handle_call_tool(name: str, args: dict):
|
||||
return [TextContent(type="text", text=f"Called {name}")]
|
||||
|
||||
import threading
|
||||
import uvicorn
|
||||
import pytest
|
||||
|
||||
|
||||
# Test fixtures
|
||||
@pytest.fixture
|
||||
async def server_app()-> Starlette:
|
||||
"""Create test Starlette app with SSE transport"""
|
||||
sse = SseServerTransport("/messages/")
|
||||
server = TestServer()
|
||||
|
||||
async def handle_sse(request):
|
||||
async with sse.connect_sse(
|
||||
request.scope, request.receive, request._send
|
||||
) as streams:
|
||||
await server.run(
|
||||
streams[0],
|
||||
streams[1],
|
||||
server.create_initialization_options()
|
||||
)
|
||||
|
||||
app = Starlette(routes=[
|
||||
Route("/sse", endpoint=handle_sse),
|
||||
Mount("/messages/", app=sse.handle_post_message),
|
||||
])
|
||||
|
||||
return app
|
||||
|
||||
@pytest.fixture()
|
||||
def server(server_app: Starlette):
|
||||
server = uvicorn.Server(config=uvicorn.Config(app=server_app, host="127.0.0.1", port=8765, log_level="error"))
|
||||
server_thread = threading.Thread( target=server.run, daemon=True )
|
||||
print('starting server')
|
||||
server_thread.start()
|
||||
# Give server time to start
|
||||
while not server.started:
|
||||
print('waiting for server to start')
|
||||
time.sleep(0.5)
|
||||
yield
|
||||
print('killing server')
|
||||
server_thread.join(timeout=0.1)
|
||||
|
||||
@pytest.fixture()
|
||||
async def client(server) -> AsyncGenerator[httpx.AsyncClient, None]:
|
||||
"""Create test client"""
|
||||
async with httpx.AsyncClient(base_url="http://127.0.0.1:8765") as client:
|
||||
yield client
|
||||
|
||||
# Tests
|
||||
@pytest.mark.anyio
|
||||
async def test_sse_connection(client: httpx.AsyncClient):
|
||||
"""Test SSE connection establishment"""
|
||||
async with anyio.create_task_group() as tg:
|
||||
async def connection_test():
|
||||
async with client.stream("GET", "/sse") as response:
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
|
||||
|
||||
line_number = 0
|
||||
async for line in response.aiter_lines():
|
||||
if line_number == 0:
|
||||
assert line == "event: endpoint"
|
||||
elif line_number == 1:
|
||||
assert line.startswith("data: /messages/?session_id=")
|
||||
else:
|
||||
return
|
||||
line_number += 1
|
||||
|
||||
# Add timeout to prevent test from hanging if it fails
|
||||
with anyio.fail_after(3):
|
||||
await connection_test()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_message_exchange(client: httpx.AsyncClient):
|
||||
"""Test full message exchange flow"""
|
||||
# Connect to SSE endpoint
|
||||
session_id = None
|
||||
endpoint_url = None
|
||||
|
||||
async with client.stream("GET", "/sse") as sse_response:
|
||||
assert sse_response.status_code == 200
|
||||
|
||||
# Get endpoint URL and session ID
|
||||
async for line in sse_response.aiter_lines():
|
||||
if line.startswith("data: "):
|
||||
endpoint_url = json.loads(line[6:])
|
||||
session_id = endpoint_url.split("session_id=")[1]
|
||||
break
|
||||
|
||||
assert endpoint_url and session_id
|
||||
|
||||
# Send initialize request
|
||||
init_request = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"clientInfo": {
|
||||
"name": "test_client",
|
||||
"version": "1.0"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
response = await client.post(
|
||||
endpoint_url,
|
||||
json=init_request
|
||||
)
|
||||
assert response.status_code == 202
|
||||
|
||||
# Get initialize response from SSE stream
|
||||
async for line in sse_response.aiter_lines():
|
||||
if line.startswith("event: message"):
|
||||
data_line = next(sse_response.aiter_lines())
|
||||
response = json.loads(data_line[6:]) # Strip "data: " prefix
|
||||
assert response["jsonrpc"] == "2.0"
|
||||
assert response["id"] == 1
|
||||
assert "result" in response
|
||||
break
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_invalid_session(client: httpx.AsyncClient):
|
||||
"""Test sending message with invalid session ID"""
|
||||
response = await client.post(
|
||||
"/messages/?session_id=invalid",
|
||||
json={"jsonrpc": "2.0", "method": "ping"}
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_connection_cleanup(server_app):
|
||||
"""Test that resources are cleaned up when client disconnects"""
|
||||
sse = next(
|
||||
route.app for route in server_app.routes
|
||||
if isinstance(route, Mount) and route.path == "/messages/"
|
||||
).transport
|
||||
|
||||
async with httpx.AsyncClient(app=server_app, base_url="http://test") as client:
|
||||
# Connect and get session ID
|
||||
async with client.stream("GET", "/sse") as response:
|
||||
for line in response.iter_lines():
|
||||
if line.startswith("data: "):
|
||||
endpoint_url = json.loads(line[6:])
|
||||
session_id = endpoint_url.split("session_id=")[1]
|
||||
break
|
||||
|
||||
assert len(sse._read_stream_writers) == 1
|
||||
|
||||
# After connection closes, writer should be cleaned up
|
||||
await anyio.sleep(0.1) # Give cleanup a moment
|
||||
assert len(sse._read_stream_writers) == 0
|
||||
Reference in New Issue
Block a user