mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-18 14:34:27 +01:00
467 lines
16 KiB
Python
467 lines
16 KiB
Python
import json
|
|
import multiprocessing
|
|
import socket
|
|
import time
|
|
from collections.abc import AsyncGenerator, Generator
|
|
|
|
import anyio
|
|
import httpx
|
|
import pytest
|
|
import uvicorn
|
|
from inline_snapshot import snapshot
|
|
from pydantic import AnyUrl
|
|
from starlette.applications import Starlette
|
|
from starlette.requests import Request
|
|
from starlette.responses import Response
|
|
from starlette.routing import Mount, Route
|
|
|
|
import mcp.types as types
|
|
from mcp.client.session import ClientSession
|
|
from mcp.client.sse import sse_client
|
|
from mcp.server import Server
|
|
from mcp.server.sse import SseServerTransport
|
|
from mcp.shared.exceptions import McpError
|
|
from mcp.types import (
|
|
EmptyResult,
|
|
ErrorData,
|
|
InitializeResult,
|
|
ReadResourceResult,
|
|
TextContent,
|
|
TextResourceContents,
|
|
Tool,
|
|
)
|
|
|
|
# Name given to the ServerTest instance; the connection tests compare it with
# the serverInfo.name reported in the initialize result.
SERVER_NAME = "test_server_for_SSE"
|
|
|
|
|
|
@pytest.fixture
def server_port() -> int:
    """Return a TCP port number on 127.0.0.1 chosen by the operating system.

    The probe socket is bound to port 0 and closed again before returning, so
    the port is only known to have been free at the moment of the call.
    """
    probe = socket.socket()
    try:
        probe.bind(("127.0.0.1", 0))
        address = probe.getsockname()
        return address[1]
    finally:
        probe.close()
|
|
|
|
|
|
@pytest.fixture
def server_url(server_port: int) -> str:
    """Return the base HTTP URL of the local test server for ``server_port``."""
    host = "127.0.0.1"
    return f"http://{host}:{server_port}"
|
|
|
|
|
|
# Test server implementation
|
|
class ServerTest(Server):
    """MCP server used as the backend for the SSE tests.

    On construction it registers three handlers: a resource reader, a tool
    listing with a single entry, and a tool-call handler that echoes the name.
    """

    def __init__(self):
        super().__init__(SERVER_NAME)

        @self.read_resource()
        async def handle_read_resource(uri: AnyUrl) -> str | bytes:
            """Answer ``foobar://`` at once and ``slow://`` after a delay; reject the rest."""
            scheme = uri.scheme
            if scheme == "slow":
                # Simulate a slow resource
                await anyio.sleep(2.0)
                return f"Slow response from {uri.host}"
            if scheme == "foobar":
                return f"Read {uri.host}"

            # Any other scheme is reported to the client as an MCP error.
            not_found = ErrorData(code=404, message="OOPS! no resource with that URI was found")
            raise McpError(error=not_found)

        @self.list_tools()
        async def handle_list_tools() -> list[Tool]:
            """Advertise the single test tool."""
            test_tool = Tool(
                name="test_tool",
                description="A test tool",
                inputSchema={"type": "object", "properties": {}},
            )
            return [test_tool]

        @self.call_tool()
        async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
            """Reply with a text item naming the tool that was called."""
            reply = TextContent(type="text", text=f"Called {name}")
            return [reply]
|
|
|
|
|
|
# Test fixtures
|
|
def make_server_app() -> Starlette:
    """Create test Starlette app with SSE transport"""
    transport = SseServerTransport("/messages/")
    mcp_server = ServerTest()

    async def handle_sse(request: Request) -> Response:
        """Run the MCP server over the SSE streams opened for this request."""
        async with transport.connect_sse(request.scope, request.receive, request._send) as streams:
            read_stream = streams[0]
            write_stream = streams[1]
            await mcp_server.run(read_stream, write_stream, mcp_server.create_initialization_options())
        return Response()

    # GET /sse opens the event stream; client messages are posted under /messages/.
    routes = [
        Route("/sse", endpoint=handle_sse),
        Mount("/messages/", app=transport.handle_post_message),
    ]
    return Starlette(routes=routes)
|
|
|
|
|
|
def run_server(server_port: int) -> None:
    """Serve the test app with uvicorn on 127.0.0.1.

    Used as the target of the child process started by the ``server`` fixture.

    Args:
        server_port: TCP port to listen on.
    """
    app = make_server_app()
    config = uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")
    server = uvicorn.Server(config=config)
    print(f"starting server on {server_port}")
    # run() blocks for the whole lifetime of the server, so polling
    # ``server.started`` after it can never help: the loop would either be
    # skipped or spin forever if startup never completed. Readiness is
    # detected by the parent process, which polls the port.
    server.run()
|
|
|
|
|
|
@pytest.fixture()
def server(server_port: int) -> Generator[None, None, None]:
    """Run the test server in a child process for the duration of a test.

    Starts ``run_server`` in a daemon process, waits until the port accepts
    TCP connections, yields, and kills the process afterwards.

    Raises:
        RuntimeError: if the port still refuses connections after 20 attempts
            spaced 0.1 s apart. The child process is killed before raising.
    """
    proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True)
    print("starting process")
    proc.start()

    # Wait for server to be running
    max_attempts = 20
    print("waiting for server to start")
    for _ in range(max_attempts):
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.connect(("127.0.0.1", server_port))
        except ConnectionRefusedError:
            time.sleep(0.1)
        else:
            break
    else:
        # Teardown below is never reached on this path, so stop the child
        # here instead of leaving it running.
        proc.kill()
        proc.join(timeout=2)
        raise RuntimeError(f"Server failed to start after {max_attempts} attempts")

    try:
        yield
    finally:
        print("killing server")
        # Signal the server to stop
        proc.kill()
        proc.join(timeout=2)
        if proc.is_alive():
            print("server process failed to terminate")
|
|
|
|
|
|
@pytest.fixture()
async def http_client(server, server_url) -> AsyncGenerator[httpx.AsyncClient, None]:
    """Create test client"""
    client = httpx.AsyncClient(base_url=server_url)
    # Requesting ``server`` ensures the server process is up before the client is used.
    async with client:
        yield client
|
|
|
|
|
|
# Tests
|
|
@pytest.mark.anyio
async def test_raw_sse_connection(http_client: httpx.AsyncClient) -> None:
    """Test the SSE connection establishment simply with an HTTP client."""
    endpoint_data_prefix = "data: /messages/?session_id="

    async with anyio.create_task_group():

        async def connection_test() -> None:
            """Check the status, content type and first two lines of the stream."""
            async with http_client.stream("GET", "/sse") as response:
                assert response.status_code == 200
                assert response.headers["content-type"] == "text/event-stream; charset=utf-8"

                lines_seen = 0
                async for line in response.aiter_lines():
                    if lines_seen >= 2:
                        # Only the first two lines are checked.
                        return
                    if lines_seen == 0:
                        assert line == "event: endpoint"
                    else:
                        assert line.startswith(endpoint_data_prefix)
                    lines_seen += 1

        # Add timeout to prevent test from hanging if it fails
        with anyio.fail_after(3):
            await connection_test()
|
|
|
|
|
|
@pytest.mark.anyio
async def test_sse_client_basic_connection(server: None, server_url: str) -> None:
    """Connect with the SSE client, then check initialization and ping."""
    sse_endpoint = server_url + "/sse"
    async with sse_client(sse_endpoint) as streams:
        async with ClientSession(*streams) as session:
            # Test initialization
            init_result = await session.initialize()
            assert isinstance(init_result, InitializeResult)
            assert init_result.serverInfo.name == SERVER_NAME

            # Test ping
            pong = await session.send_ping()
            assert isinstance(pong, EmptyResult)
|
|
|
|
|
|
@pytest.fixture
async def initialized_sse_client_session(server, server_url: str) -> AsyncGenerator[ClientSession, None]:
    """Yield a ClientSession that is connected over SSE and already initialized.

    The SSE client is opened with ``sse_read_timeout=0.5``.
    """
    sse_endpoint = server_url + "/sse"
    async with sse_client(sse_endpoint, sse_read_timeout=0.5) as streams:
        async with ClientSession(*streams) as client_session:
            await client_session.initialize()
            yield client_session
|
|
|
|
|
|
@pytest.mark.anyio
async def test_sse_client_happy_request_and_response(
    initialized_sse_client_session: ClientSession,
) -> None:
    """Read a ``foobar://`` resource and check the single text item returned."""
    result = await initialized_sse_client_session.read_resource(uri=AnyUrl("foobar://should-work"))
    contents = result.contents
    assert len(contents) == 1
    assert isinstance(contents[0], TextResourceContents)
    assert contents[0].text == "Read should-work"
|
|
|
|
|
|
@pytest.mark.anyio
async def test_sse_client_exception_handling(
    initialized_sse_client_session: ClientSession,
) -> None:
    """Reading a URI with an unsupported scheme surfaces the server's McpError."""
    unknown_uri = AnyUrl("xxx://will-not-work")
    with pytest.raises(McpError, match="OOPS! no resource with that URI was found"):
        await initialized_sse_client_session.read_resource(uri=unknown_uri)
|
|
|
|
|
|
@pytest.mark.anyio
@pytest.mark.skip("this test highlights a possible bug in SSE read timeout exception handling")
async def test_sse_client_timeout(
    initialized_sse_client_session: ClientSession,
) -> None:
    """A ``slow://`` read is expected to fail with "Read timed out" within 3 seconds.

    Currently skipped; see the skip reason on the decorator.
    """
    client = initialized_sse_client_session

    # sanity check that normal, fast responses are working
    fast_result = await client.read_resource(uri=AnyUrl("foobar://1"))
    assert isinstance(fast_result, ReadResourceResult)

    with anyio.move_on_after(3):
        with pytest.raises(McpError, match="Read timed out"):
            await client.read_resource(uri=AnyUrl("slow://2"))
        # we should receive an error here
        return

    # Only reached when the 3 second scope was cancelled before any error arrived.
    pytest.fail("the client should have timed out and returned an error already")
|
|
|
|
|
|
def run_mounted_server(server_port: int) -> None:
    """Serve the test app mounted under ``/mounted_app`` with uvicorn on 127.0.0.1.

    Used as the target of the child process started by the ``mounted_server``
    fixture.

    Args:
        server_port: TCP port to listen on.
    """
    app = make_server_app()
    main_app = Starlette(routes=[Mount("/mounted_app", app=app)])
    config = uvicorn.Config(app=main_app, host="127.0.0.1", port=server_port, log_level="error")
    server = uvicorn.Server(config=config)
    print(f"starting server on {server_port}")
    # run() blocks for the whole lifetime of the server, so polling
    # ``server.started`` after it can never help: the loop would either be
    # skipped or spin forever if startup never completed. Readiness is
    # detected by the parent process, which polls the port.
    server.run()
|
|
|
|
|
|
@pytest.fixture()
def mounted_server(server_port: int) -> Generator[None, None, None]:
    """Run the mounted test server in a child process for the duration of a test.

    Starts ``run_mounted_server`` in a daemon process, waits until the port
    accepts TCP connections, yields, and kills the process afterwards.

    Raises:
        RuntimeError: if the port still refuses connections after 20 attempts
            spaced 0.1 s apart. The child process is killed before raising.
    """
    proc = multiprocessing.Process(target=run_mounted_server, kwargs={"server_port": server_port}, daemon=True)
    print("starting process")
    proc.start()

    # Wait for server to be running
    max_attempts = 20
    print("waiting for server to start")
    for _ in range(max_attempts):
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.connect(("127.0.0.1", server_port))
        except ConnectionRefusedError:
            time.sleep(0.1)
        else:
            break
    else:
        # Teardown below is never reached on this path, so stop the child
        # here instead of leaving it running.
        proc.kill()
        proc.join(timeout=2)
        raise RuntimeError(f"Server failed to start after {max_attempts} attempts")

    try:
        yield
    finally:
        print("killing server")
        # Signal the server to stop
        proc.kill()
        proc.join(timeout=2)
        if proc.is_alive():
            print("server process failed to terminate")
|
|
|
|
|
|
@pytest.mark.anyio
async def test_sse_client_basic_connection_mounted_app(mounted_server: None, server_url: str) -> None:
    """Connect to the app mounted under ``/mounted_app``, then check initialization and ping."""
    sse_endpoint = server_url + "/mounted_app/sse"
    async with sse_client(sse_endpoint) as streams:
        async with ClientSession(*streams) as session:
            # Test initialization
            init_result = await session.initialize()
            assert isinstance(init_result, InitializeResult)
            assert init_result.serverInfo.name == SERVER_NAME

            # Test ping
            pong = await session.send_ping()
            assert isinstance(pong, EmptyResult)
|
|
|
|
|
|
# Test server with request context that returns headers in the response
|
|
class RequestContextServer(Server[object, Request]):
    """MCP server whose tools echo data taken from the current request context.

    ``echo_headers`` returns the request headers as JSON. ``echo_context``
    returns the ``request_id`` argument together with the headers.
    """

    def __init__(self):
        super().__init__("request_context_server")

        @self.call_tool()
        async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
            """Build the text reply for the requested tool."""
            request = self.request_context.request
            # Headers are only available when the context carries a request.
            headers_info = dict(request.headers) if request else {}

            if name == "echo_headers":
                text = json.dumps(headers_info)
            elif name == "echo_context":
                text = json.dumps(
                    {
                        "request_id": args.get("request_id"),
                        "headers": headers_info,
                    }
                )
            else:
                text = f"Called {name}"

            return [TextContent(type="text", text=text)]

        @self.list_tools()
        async def handle_list_tools() -> list[Tool]:
            """Advertise the two echo tools."""
            echo_headers = Tool(
                name="echo_headers",
                description="Echoes request headers",
                inputSchema={"type": "object", "properties": {}},
            )
            echo_context = Tool(
                name="echo_context",
                description="Echoes request context",
                inputSchema={
                    "type": "object",
                    "properties": {"request_id": {"type": "string"}},
                    "required": ["request_id"],
                },
            )
            return [echo_headers, echo_context]
|
|
|
|
|
|
def run_context_server(server_port: int) -> None:
    """Run a server that captures request context"""
    transport = SseServerTransport("/messages/")
    mcp_server = RequestContextServer()

    async def handle_sse(request: Request) -> Response:
        """Run the MCP server over the SSE streams opened for this request."""
        async with transport.connect_sse(request.scope, request.receive, request._send) as streams:
            read_stream = streams[0]
            write_stream = streams[1]
            await mcp_server.run(read_stream, write_stream, mcp_server.create_initialization_options())
        return Response()

    # GET /sse opens the event stream; client messages are posted under /messages/.
    routes = [
        Route("/sse", endpoint=handle_sse),
        Mount("/messages/", app=transport.handle_post_message),
    ]
    app = Starlette(routes=routes)

    config = uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")
    uvicorn_server = uvicorn.Server(config=config)
    print(f"starting context server on {server_port}")
    uvicorn_server.run()
|
|
|
|
|
|
@pytest.fixture()
def context_server(server_port: int) -> Generator[None, None, None]:
    """Fixture that provides a server with request context capture

    Starts ``run_context_server`` in a daemon process, waits until the port
    accepts TCP connections, yields, and kills the process afterwards.

    Raises:
        RuntimeError: if the port still refuses connections after 20 attempts
            spaced 0.1 s apart. The child process is killed before raising.
    """
    proc = multiprocessing.Process(target=run_context_server, kwargs={"server_port": server_port}, daemon=True)
    print("starting context server process")
    proc.start()

    # Wait for server to be running
    max_attempts = 20
    print("waiting for context server to start")
    for _ in range(max_attempts):
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.connect(("127.0.0.1", server_port))
        except ConnectionRefusedError:
            time.sleep(0.1)
        else:
            break
    else:
        # Teardown below is never reached on this path, so stop the child
        # here instead of leaving it running.
        proc.kill()
        proc.join(timeout=2)
        raise RuntimeError(f"Context server failed to start after {max_attempts} attempts")

    try:
        yield
    finally:
        print("killing context server")
        proc.kill()
        proc.join(timeout=2)
        if proc.is_alive():
            print("context server process failed to terminate")
|
|
|
|
|
|
@pytest.mark.anyio
async def test_request_context_propagation(context_server: None, server_url: str) -> None:
    """Test that request context is properly propagated through SSE transport."""
    # Test with custom headers
    custom_headers = {
        "Authorization": "Bearer test-token",
        "X-Custom-Header": "test-value",
        "X-Trace-Id": "trace-123",
    }
    # The echoed headers are looked up by their lower-case names.
    expected_echo = {
        "authorization": "Bearer test-token",
        "x-custom-header": "test-value",
        "x-trace-id": "trace-123",
    }

    async with sse_client(server_url + "/sse", headers=custom_headers) as (read_stream, write_stream):
        async with ClientSession(read_stream, write_stream) as session:
            # Initialize the session
            init_result = await session.initialize()
            assert isinstance(init_result, InitializeResult)

            # Call the tool that echoes headers back
            tool_result = await session.call_tool("echo_headers", {})

            # Parse the JSON response
            assert len(tool_result.content) == 1
            first_item = tool_result.content[0]
            payload = first_item.text if first_item.type == "text" else "{}"
            headers_data = json.loads(payload)

            # Verify headers were propagated
            for header_name, header_value in expected_echo.items():
                assert headers_data.get(header_name) == header_value
|
|
|
|
|
|
@pytest.mark.anyio
async def test_request_context_isolation(context_server: None, server_url: str) -> None:
    """Test that request contexts are isolated between different SSE clients."""
    client_count = 3
    contexts = []

    # Create multiple clients with different headers
    for index in range(client_count):
        request_id = f"request-{index}"
        headers = {"X-Request-Id": request_id, "X-Custom-Value": f"value-{index}"}

        async with sse_client(server_url + "/sse", headers=headers) as (read_stream, write_stream):
            async with ClientSession(read_stream, write_stream) as session:
                await session.initialize()

                # Call the tool that echoes context
                tool_result = await session.call_tool("echo_context", {"request_id": request_id})

                assert len(tool_result.content) == 1
                first_item = tool_result.content[0]
                payload = first_item.text if first_item.type == "text" else "{}"
                contexts.append(json.loads(payload))

    # Verify each request had its own context
    assert len(contexts) == 3
    for index, ctx in enumerate(contexts):
        echoed_headers = ctx["headers"]
        assert ctx["request_id"] == f"request-{index}"
        assert echoed_headers.get("x-request-id") == f"request-{index}"
        assert echoed_headers.get("x-custom-value") == f"value-{index}"
|
|
|
|
|
|
def test_sse_message_id_coercion():
    """Test that string message IDs that look like integers are parsed as integers.

    See <https://github.com/modelcontextprotocol/python-sdk/pull/851> for more details.
    """
    # The id is sent as the JSON string "123"; the snapshot expects the integer 123.
    raw_message = '{"jsonrpc": "2.0", "id": "123", "method": "ping", "params": null}'
    parsed = types.JSONRPCMessage.model_validate_json(raw_message)
    assert parsed == snapshot(types.JSONRPCMessage(root=types.JSONRPCRequest(method="ping", jsonrpc="2.0", id=123)))
|