mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
passing SSE client test
This commit is contained in:
@@ -3,20 +3,36 @@ import re
|
|||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
import anyio
|
import anyio
|
||||||
|
from pydantic import AnyUrl
|
||||||
|
from pydantic_core import Url
|
||||||
import pytest
|
import pytest
|
||||||
import httpx
|
import httpx
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
from starlette.applications import Starlette
|
from starlette.applications import Starlette
|
||||||
from starlette.routing import Mount, Route
|
from starlette.routing import Mount, Route
|
||||||
|
|
||||||
|
from mcp.client.session import ClientSession
|
||||||
|
from mcp.client.sse import sse_client
|
||||||
from mcp.server import Server
|
from mcp.server import Server
|
||||||
from mcp.server.sse import SseServerTransport
|
from mcp.server.sse import SseServerTransport
|
||||||
from mcp.types import TextContent, Tool
|
from mcp.types import EmptyResult, InitializeResult, TextContent, TextResourceContents, Tool
|
||||||
|
|
||||||
|
SERVER_URL = "http://127.0.0.1:8765"
|
||||||
|
SERVER_SSE_URL = f"{SERVER_URL}/sse"
|
||||||
|
|
||||||
|
SERVER_NAME = "test_server_for_SSE"
|
||||||
|
|
||||||
# Test server implementation
|
# Test server implementation
|
||||||
class TestServer(Server):
|
class TestServer(Server):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__("test_server")
|
super().__init__(SERVER_NAME)
|
||||||
|
|
||||||
|
@self.read_resource()
|
||||||
|
async def handle_read_resource(uri: AnyUrl) -> str | bytes:
|
||||||
|
if uri.scheme == "foobar":
|
||||||
|
return f"Read {uri.host}"
|
||||||
|
# TODO: make this an error
|
||||||
|
return "NOT FOUND"
|
||||||
|
|
||||||
@self.list_tools()
|
@self.list_tools()
|
||||||
async def handle_list_tools():
|
async def handle_list_tools():
|
||||||
@@ -76,18 +92,18 @@ def server(server_app: Starlette):
|
|||||||
server_thread.join(timeout=0.1)
|
server_thread.join(timeout=0.1)
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
async def client(server) -> AsyncGenerator[httpx.AsyncClient, None]:
|
async def http_client(server) -> AsyncGenerator[httpx.AsyncClient, None]:
|
||||||
"""Create test client"""
|
"""Create test client"""
|
||||||
async with httpx.AsyncClient(base_url="http://127.0.0.1:8765") as client:
|
async with httpx.AsyncClient(base_url=SERVER_URL) as client:
|
||||||
yield client
|
yield client
|
||||||
|
|
||||||
# Tests
|
# Tests
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_sse_connection(client: httpx.AsyncClient):
|
async def test_raw_sse_connection(http_client: httpx.AsyncClient):
|
||||||
"""Test SSE connection establishment"""
|
"""Test the SSE connection establishment simply with an HTTP client."""
|
||||||
async with anyio.create_task_group() as tg:
|
async with anyio.create_task_group() as tg:
|
||||||
async def connection_test():
|
async def connection_test():
|
||||||
async with client.stream("GET", "/sse") as response:
|
async with http_client.stream("GET", "/sse") as response:
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
|
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
|
||||||
|
|
||||||
@@ -105,84 +121,33 @@ async def test_sse_connection(client: httpx.AsyncClient):
|
|||||||
with anyio.fail_after(3):
|
with anyio.fail_after(3):
|
||||||
await connection_test()
|
await connection_test()
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_message_exchange(client: httpx.AsyncClient):
|
|
||||||
"""Test full message exchange flow"""
|
|
||||||
# Connect to SSE endpoint
|
|
||||||
session_id = None
|
|
||||||
endpoint_url = None
|
|
||||||
|
|
||||||
async with client.stream("GET", "/sse") as sse_response:
|
|
||||||
assert sse_response.status_code == 200
|
|
||||||
|
|
||||||
# Get endpoint URL and session ID
|
|
||||||
async for line in sse_response.aiter_lines():
|
|
||||||
if line.startswith("data: "):
|
|
||||||
endpoint_url = json.loads(line[6:])
|
|
||||||
session_id = endpoint_url.split("session_id=")[1]
|
|
||||||
break
|
|
||||||
|
|
||||||
assert endpoint_url and session_id
|
|
||||||
|
|
||||||
# Send initialize request
|
|
||||||
init_request = {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"id": 1,
|
|
||||||
"method": "initialize",
|
|
||||||
"params": {
|
|
||||||
"protocolVersion": "2024-11-05",
|
|
||||||
"capabilities": {},
|
|
||||||
"clientInfo": {
|
|
||||||
"name": "test_client",
|
|
||||||
"version": "1.0"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
response = await client.post(
|
|
||||||
endpoint_url,
|
|
||||||
json=init_request
|
|
||||||
)
|
|
||||||
assert response.status_code == 202
|
|
||||||
|
|
||||||
# Get initialize response from SSE stream
|
|
||||||
async for line in sse_response.aiter_lines():
|
|
||||||
if line.startswith("event: message"):
|
|
||||||
data_line = next(sse_response.aiter_lines())
|
|
||||||
response = json.loads(data_line[6:]) # Strip "data: " prefix
|
|
||||||
assert response["jsonrpc"] == "2.0"
|
|
||||||
assert response["id"] == 1
|
|
||||||
assert "result" in response
|
|
||||||
break
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_invalid_session(client: httpx.AsyncClient):
|
async def test_sse_client_basic_connection(server):
|
||||||
"""Test sending message with invalid session ID"""
|
async with sse_client(SERVER_SSE_URL) as streams:
|
||||||
response = await client.post(
|
async with ClientSession(*streams) as session:
|
||||||
"/messages/?session_id=invalid",
|
# Test initialization
|
||||||
json={"jsonrpc": "2.0", "method": "ping"}
|
result = await session.initialize()
|
||||||
)
|
assert isinstance(result, InitializeResult)
|
||||||
assert response.status_code == 400
|
assert result.serverInfo.name == SERVER_NAME
|
||||||
|
|
||||||
|
# Test ping
|
||||||
|
ping_result = await session.send_ping()
|
||||||
|
assert isinstance(ping_result, EmptyResult)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def initialized_sse_client_session(server) -> AsyncGenerator[ClientSession, None]:
|
||||||
|
async with sse_client(SERVER_SSE_URL) as streams:
|
||||||
|
async with ClientSession(*streams) as session:
|
||||||
|
await session.initialize()
|
||||||
|
yield session
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_connection_cleanup(server_app):
|
async def test_sse_client_request_and_response(initialized_sse_client_session: ClientSession):
|
||||||
"""Test that resources are cleaned up when client disconnects"""
|
session = initialized_sse_client_session
|
||||||
sse = next(
|
# TODO: expect raise
|
||||||
route.app for route in server_app.routes
|
await session.read_resource(uri=AnyUrl("xxx://will-not-work"))
|
||||||
if isinstance(route, Mount) and route.path == "/messages/"
|
response = await session.read_resource(uri=AnyUrl("foobar://should-work"))
|
||||||
).transport
|
assert len(response.contents) == 1
|
||||||
|
assert isinstance(response.contents[0], TextResourceContents)
|
||||||
async with httpx.AsyncClient(app=server_app, base_url="http://test") as client:
|
assert response.contents[0].text == "Read should-work"
|
||||||
# Connect and get session ID
|
|
||||||
async with client.stream("GET", "/sse") as response:
|
|
||||||
for line in response.iter_lines():
|
|
||||||
if line.startswith("data: "):
|
|
||||||
endpoint_url = json.loads(line[6:])
|
|
||||||
session_id = endpoint_url.split("session_id=")[1]
|
|
||||||
break
|
|
||||||
|
|
||||||
assert len(sse._read_stream_writers) == 1
|
|
||||||
|
|
||||||
# After connection closes, writer should be cleaned up
|
|
||||||
await anyio.sleep(0.1) # Give cleanup a moment
|
|
||||||
assert len(sse._read_stream_writers) == 0
|
|
||||||
|
|||||||
Reference in New Issue
Block a user