formatting

This commit is contained in:
Nick Merrill
2025-01-14 11:47:41 -05:00
parent 5097bb7ef8
commit 07e721f63f

View File

@@ -1,4 +1,3 @@
# test_sse.py
import re import re
import multiprocessing import multiprocessing
import socket import socket
@@ -21,20 +20,30 @@ from mcp.client.session import ClientSession
from mcp.client.sse import sse_client from mcp.client.sse import sse_client
from mcp.server import Server from mcp.server import Server
from mcp.server.sse import SseServerTransport from mcp.server.sse import SseServerTransport
from mcp.types import EmptyResult, ErrorData, InitializeResult, TextContent, TextResourceContents, Tool from mcp.types import (
EmptyResult,
ErrorData,
InitializeResult,
TextContent,
TextResourceContents,
Tool,
)
SERVER_NAME = "test_server_for_SSE" SERVER_NAME = "test_server_for_SSE"
@pytest.fixture @pytest.fixture
def server_port() -> int: def server_port() -> int:
with socket.socket() as s: with socket.socket() as s:
s.bind(('127.0.0.1', 0)) s.bind(("127.0.0.1", 0))
return s.getsockname()[1] return s.getsockname()[1]
@pytest.fixture @pytest.fixture
def server_url(server_port: int) -> str: def server_url(server_port: int) -> str:
return f"http://127.0.0.1:{server_port}" return f"http://127.0.0.1:{server_port}"
# Test server implementation # Test server implementation
class TestServer(Server): class TestServer(Server):
def __init__(self): def __init__(self):
@@ -45,7 +54,11 @@ class TestServer(Server):
if uri.scheme == "foobar": if uri.scheme == "foobar":
return f"Read {uri.host}" return f"Read {uri.host}"
# TODO: make this an error # TODO: make this an error
raise McpError(error=ErrorData(code=404, message="OOPS! no resource with that URI was found")) raise McpError(
error=ErrorData(
code=404, message="OOPS! no resource with that URI was found"
)
)
@self.list_tools() @self.list_tools()
async def handle_list_tools(): async def handle_list_tools():
@@ -53,7 +66,7 @@ class TestServer(Server):
Tool( Tool(
name="test_tool", name="test_tool",
description="A test tool", description="A test tool",
inputSchema={"type": "object", "properties": {}} inputSchema={"type": "object", "properties": {}},
) )
] ]
@@ -62,9 +75,8 @@ class TestServer(Server):
return [TextContent(type="text", text=f"Called {name}")] return [TextContent(type="text", text=f"Called {name}")]
# Test fixtures # Test fixtures
def make_server_app()-> Starlette: def make_server_app() -> Starlette:
"""Create test Starlette app with SSE transport""" """Create test Starlette app with SSE transport"""
sse = SseServerTransport("/messages/") sse = SseServerTransport("/messages/")
server = TestServer() server = TestServer()
@@ -74,80 +86,97 @@ def make_server_app()-> Starlette:
request.scope, request.receive, request._send request.scope, request.receive, request._send
) as streams: ) as streams:
await server.run( await server.run(
streams[0], streams[0], streams[1], server.create_initialization_options()
streams[1],
server.create_initialization_options()
) )
app = Starlette(routes=[ app = Starlette(
Route("/sse", endpoint=handle_sse), routes=[
Mount("/messages/", app=sse.handle_post_message), Route("/sse", endpoint=handle_sse),
]) Mount("/messages/", app=sse.handle_post_message),
]
)
return app return app
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def space_around_test(): def space_around_test():
time.sleep(0.1) time.sleep(0.1)
yield yield
time.sleep(0.1) time.sleep(0.1)
def run_server(server_port: int): def run_server(server_port: int):
app = make_server_app() app = make_server_app()
server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) server = uvicorn.Server(
print(f'starting server on {server_port}') config=uvicorn.Config(
app=app, host="127.0.0.1", port=server_port, log_level="error"
)
)
print(f"starting server on {server_port}")
server.run() server.run()
# Give server time to start # Give server time to start
while not server.started: while not server.started:
print('waiting for server to start') print("waiting for server to start")
time.sleep(0.5) time.sleep(0.5)
@pytest.fixture() @pytest.fixture()
def server(server_port: int): def server(server_port: int):
proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True) proc = multiprocessing.Process(
print('starting process') target=run_server, kwargs={"server_port": server_port}, daemon=True
)
print("starting process")
proc.start() proc.start()
# Wait for server to be running # Wait for server to be running
max_attempts = 20 max_attempts = 20
attempt = 0 attempt = 0
print('waiting for server to start') print("waiting for server to start")
while attempt < max_attempts: while attempt < max_attempts:
try: try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.connect(('127.0.0.1', server_port)) s.connect(("127.0.0.1", server_port))
break break
except ConnectionRefusedError: except ConnectionRefusedError:
time.sleep(0.1) time.sleep(0.1)
attempt += 1 attempt += 1
else: else:
raise RuntimeError("Server failed to start after {} attempts".format(max_attempts)) raise RuntimeError(
"Server failed to start after {} attempts".format(max_attempts)
)
yield yield
print('killing server') print("killing server")
# Signal the server to stop # Signal the server to stop
proc.kill() proc.kill()
proc.join(timeout=2) proc.join(timeout=2)
if proc.is_alive(): if proc.is_alive():
print("server process failed to terminate") print("server process failed to terminate")
@pytest.fixture() @pytest.fixture()
async def http_client(server, server_url) -> AsyncGenerator[httpx.AsyncClient, None]: async def http_client(server, server_url) -> AsyncGenerator[httpx.AsyncClient, None]:
"""Create test client""" """Create test client"""
async with httpx.AsyncClient(base_url=server_url) as client: async with httpx.AsyncClient(base_url=server_url) as client:
yield client yield client
# Tests # Tests
@pytest.mark.anyio @pytest.mark.anyio
async def test_raw_sse_connection(http_client: httpx.AsyncClient): async def test_raw_sse_connection(http_client: httpx.AsyncClient):
"""Test the SSE connection establishment simply with an HTTP client.""" """Test the SSE connection establishment simply with an HTTP client."""
async with anyio.create_task_group() as tg: async with anyio.create_task_group() as tg:
async def connection_test(): async def connection_test():
async with http_client.stream("GET", "/sse") as response: async with http_client.stream("GET", "/sse") as response:
assert response.status_code == 200 assert response.status_code == 200
assert response.headers["content-type"] == "text/event-stream; charset=utf-8" assert (
response.headers["content-type"]
== "text/event-stream; charset=utf-8"
)
line_number = 0 line_number = 0
async for line in response.aiter_lines(): async for line in response.aiter_lines():
@@ -177,23 +206,32 @@ async def test_sse_client_basic_connection(server, server_url):
ping_result = await session.send_ping() ping_result = await session.send_ping()
assert isinstance(ping_result, EmptyResult) assert isinstance(ping_result, EmptyResult)
@pytest.fixture @pytest.fixture
async def initialized_sse_client_session(server, server_url: str) -> AsyncGenerator[ClientSession, None]: async def initialized_sse_client_session(
server, server_url: str
) -> AsyncGenerator[ClientSession, None]:
async with sse_client(server_url + "/sse") as streams: async with sse_client(server_url + "/sse") as streams:
async with ClientSession(*streams) as session: async with ClientSession(*streams) as session:
await session.initialize() await session.initialize()
yield session yield session
@pytest.mark.anyio @pytest.mark.anyio
async def test_sse_client_happy_request_and_response(initialized_sse_client_session: ClientSession): async def test_sse_client_happy_request_and_response(
initialized_sse_client_session: ClientSession,
):
session = initialized_sse_client_session session = initialized_sse_client_session
response = await session.read_resource(uri=AnyUrl("foobar://should-work")) response = await session.read_resource(uri=AnyUrl("foobar://should-work"))
assert len(response.contents) == 1 assert len(response.contents) == 1
assert isinstance(response.contents[0], TextResourceContents) assert isinstance(response.contents[0], TextResourceContents)
assert response.contents[0].text == "Read should-work" assert response.contents[0].text == "Read should-work"
@pytest.mark.anyio @pytest.mark.anyio
async def test_sse_client_exception_handling(initialized_sse_client_session: ClientSession): async def test_sse_client_exception_handling(
initialized_sse_client_session: ClientSession,
):
session = initialized_sse_client_session session = initialized_sse_client_session
with pytest.raises(McpError, match="OOPS! no resource with that URI was found"): with pytest.raises(McpError, match="OOPS! no resource with that URI was found"):
await session.read_resource(uri=AnyUrl("xxx://will-not-work")) await session.read_resource(uri=AnyUrl("xxx://will-not-work"))