mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 06:54:18 +01:00
passing SSE client test
This commit is contained in:
@@ -3,20 +3,36 @@ import re
|
||||
import time
|
||||
import json
|
||||
import anyio
|
||||
from pydantic import AnyUrl
|
||||
from pydantic_core import Url
|
||||
import pytest
|
||||
import httpx
|
||||
from typing import AsyncGenerator
|
||||
from starlette.applications import Starlette
|
||||
from starlette.routing import Mount, Route
|
||||
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.server import Server
|
||||
from mcp.server.sse import SseServerTransport
|
||||
from mcp.types import TextContent, Tool
|
||||
from mcp.types import EmptyResult, InitializeResult, TextContent, TextResourceContents, Tool
|
||||
|
||||
SERVER_URL = "http://127.0.0.1:8765"
|
||||
SERVER_SSE_URL = f"{SERVER_URL}/sse"
|
||||
|
||||
SERVER_NAME = "test_server_for_SSE"
|
||||
|
||||
# Test server implementation
|
||||
class TestServer(Server):
|
||||
def __init__(self):
|
||||
super().__init__("test_server")
|
||||
super().__init__(SERVER_NAME)
|
||||
|
||||
@self.read_resource()
|
||||
async def handle_read_resource(uri: AnyUrl) -> str | bytes:
|
||||
if uri.scheme == "foobar":
|
||||
return f"Read {uri.host}"
|
||||
# TODO: make this an error
|
||||
return "NOT FOUND"
|
||||
|
||||
@self.list_tools()
|
||||
async def handle_list_tools():
|
||||
@@ -76,18 +92,18 @@ def server(server_app: Starlette):
|
||||
server_thread.join(timeout=0.1)
|
||||
|
||||
@pytest.fixture()
|
||||
async def client(server) -> AsyncGenerator[httpx.AsyncClient, None]:
|
||||
async def http_client(server) -> AsyncGenerator[httpx.AsyncClient, None]:
|
||||
"""Create test client"""
|
||||
async with httpx.AsyncClient(base_url="http://127.0.0.1:8765") as client:
|
||||
async with httpx.AsyncClient(base_url=SERVER_URL) as client:
|
||||
yield client
|
||||
|
||||
# Tests
|
||||
@pytest.mark.anyio
|
||||
async def test_sse_connection(client: httpx.AsyncClient):
|
||||
"""Test SSE connection establishment"""
|
||||
async def test_raw_sse_connection(http_client: httpx.AsyncClient):
|
||||
"""Test the SSE connection establishment simply with an HTTP client."""
|
||||
async with anyio.create_task_group() as tg:
|
||||
async def connection_test():
|
||||
async with client.stream("GET", "/sse") as response:
|
||||
async with http_client.stream("GET", "/sse") as response:
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
|
||||
|
||||
@@ -105,84 +121,33 @@ async def test_sse_connection(client: httpx.AsyncClient):
|
||||
with anyio.fail_after(3):
|
||||
await connection_test()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_message_exchange(client: httpx.AsyncClient):
|
||||
"""Test full message exchange flow"""
|
||||
# Connect to SSE endpoint
|
||||
session_id = None
|
||||
endpoint_url = None
|
||||
|
||||
async with client.stream("GET", "/sse") as sse_response:
|
||||
assert sse_response.status_code == 200
|
||||
|
||||
# Get endpoint URL and session ID
|
||||
async for line in sse_response.aiter_lines():
|
||||
if line.startswith("data: "):
|
||||
endpoint_url = json.loads(line[6:])
|
||||
session_id = endpoint_url.split("session_id=")[1]
|
||||
break
|
||||
|
||||
assert endpoint_url and session_id
|
||||
|
||||
# Send initialize request
|
||||
init_request = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"clientInfo": {
|
||||
"name": "test_client",
|
||||
"version": "1.0"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
response = await client.post(
|
||||
endpoint_url,
|
||||
json=init_request
|
||||
)
|
||||
assert response.status_code == 202
|
||||
|
||||
# Get initialize response from SSE stream
|
||||
async for line in sse_response.aiter_lines():
|
||||
if line.startswith("event: message"):
|
||||
data_line = next(sse_response.aiter_lines())
|
||||
response = json.loads(data_line[6:]) # Strip "data: " prefix
|
||||
assert response["jsonrpc"] == "2.0"
|
||||
assert response["id"] == 1
|
||||
assert "result" in response
|
||||
break
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_invalid_session(client: httpx.AsyncClient):
|
||||
"""Test sending message with invalid session ID"""
|
||||
response = await client.post(
|
||||
"/messages/?session_id=invalid",
|
||||
json={"jsonrpc": "2.0", "method": "ping"}
|
||||
)
|
||||
assert response.status_code == 400
|
||||
async def test_sse_client_basic_connection(server):
|
||||
async with sse_client(SERVER_SSE_URL) as streams:
|
||||
async with ClientSession(*streams) as session:
|
||||
# Test initialization
|
||||
result = await session.initialize()
|
||||
assert isinstance(result, InitializeResult)
|
||||
assert result.serverInfo.name == SERVER_NAME
|
||||
|
||||
# Test ping
|
||||
ping_result = await session.send_ping()
|
||||
assert isinstance(ping_result, EmptyResult)
|
||||
|
||||
@pytest.fixture
|
||||
async def initialized_sse_client_session(server) -> AsyncGenerator[ClientSession, None]:
|
||||
async with sse_client(SERVER_SSE_URL) as streams:
|
||||
async with ClientSession(*streams) as session:
|
||||
await session.initialize()
|
||||
yield session
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_connection_cleanup(server_app):
|
||||
"""Test that resources are cleaned up when client disconnects"""
|
||||
sse = next(
|
||||
route.app for route in server_app.routes
|
||||
if isinstance(route, Mount) and route.path == "/messages/"
|
||||
).transport
|
||||
|
||||
async with httpx.AsyncClient(app=server_app, base_url="http://test") as client:
|
||||
# Connect and get session ID
|
||||
async with client.stream("GET", "/sse") as response:
|
||||
for line in response.iter_lines():
|
||||
if line.startswith("data: "):
|
||||
endpoint_url = json.loads(line[6:])
|
||||
session_id = endpoint_url.split("session_id=")[1]
|
||||
break
|
||||
|
||||
assert len(sse._read_stream_writers) == 1
|
||||
|
||||
# After connection closes, writer should be cleaned up
|
||||
await anyio.sleep(0.1) # Give cleanup a moment
|
||||
assert len(sse._read_stream_writers) == 0
|
||||
async def test_sse_client_request_and_response(initialized_sse_client_session: ClientSession):
|
||||
session = initialized_sse_client_session
|
||||
# TODO: expect raise
|
||||
await session.read_resource(uri=AnyUrl("xxx://will-not-work"))
|
||||
response = await session.read_resource(uri=AnyUrl("foobar://should-work"))
|
||||
assert len(response.contents) == 1
|
||||
assert isinstance(response.contents[0], TextResourceContents)
|
||||
assert response.contents[0].text == "Read should-work"
|
||||
|
||||
Reference in New Issue
Block a user