mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
formatting
This commit is contained in:
@@ -1,4 +1,3 @@
|
|||||||
# test_sse.py
|
|
||||||
import re
|
import re
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import socket
|
import socket
|
||||||
@@ -21,20 +20,30 @@ from mcp.client.session import ClientSession
|
|||||||
from mcp.client.sse import sse_client
|
from mcp.client.sse import sse_client
|
||||||
from mcp.server import Server
|
from mcp.server import Server
|
||||||
from mcp.server.sse import SseServerTransport
|
from mcp.server.sse import SseServerTransport
|
||||||
from mcp.types import EmptyResult, ErrorData, InitializeResult, TextContent, TextResourceContents, Tool
|
from mcp.types import (
|
||||||
|
EmptyResult,
|
||||||
|
ErrorData,
|
||||||
|
InitializeResult,
|
||||||
|
TextContent,
|
||||||
|
TextResourceContents,
|
||||||
|
Tool,
|
||||||
|
)
|
||||||
|
|
||||||
SERVER_NAME = "test_server_for_SSE"
|
SERVER_NAME = "test_server_for_SSE"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def server_port() -> int:
|
def server_port() -> int:
|
||||||
with socket.socket() as s:
|
with socket.socket() as s:
|
||||||
s.bind(('127.0.0.1', 0))
|
s.bind(("127.0.0.1", 0))
|
||||||
return s.getsockname()[1]
|
return s.getsockname()[1]
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def server_url(server_port: int) -> str:
|
def server_url(server_port: int) -> str:
|
||||||
return f"http://127.0.0.1:{server_port}"
|
return f"http://127.0.0.1:{server_port}"
|
||||||
|
|
||||||
|
|
||||||
# Test server implementation
|
# Test server implementation
|
||||||
class TestServer(Server):
|
class TestServer(Server):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -45,7 +54,11 @@ class TestServer(Server):
|
|||||||
if uri.scheme == "foobar":
|
if uri.scheme == "foobar":
|
||||||
return f"Read {uri.host}"
|
return f"Read {uri.host}"
|
||||||
# TODO: make this an error
|
# TODO: make this an error
|
||||||
raise McpError(error=ErrorData(code=404, message="OOPS! no resource with that URI was found"))
|
raise McpError(
|
||||||
|
error=ErrorData(
|
||||||
|
code=404, message="OOPS! no resource with that URI was found"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
@self.list_tools()
|
@self.list_tools()
|
||||||
async def handle_list_tools():
|
async def handle_list_tools():
|
||||||
@@ -53,7 +66,7 @@ class TestServer(Server):
|
|||||||
Tool(
|
Tool(
|
||||||
name="test_tool",
|
name="test_tool",
|
||||||
description="A test tool",
|
description="A test tool",
|
||||||
inputSchema={"type": "object", "properties": {}}
|
inputSchema={"type": "object", "properties": {}},
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -62,7 +75,6 @@ class TestServer(Server):
|
|||||||
return [TextContent(type="text", text=f"Called {name}")]
|
return [TextContent(type="text", text=f"Called {name}")]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Test fixtures
|
# Test fixtures
|
||||||
def make_server_app() -> Starlette:
|
def make_server_app() -> Starlette:
|
||||||
"""Create test Starlette app with SSE transport"""
|
"""Create test Starlette app with SSE transport"""
|
||||||
@@ -74,80 +86,97 @@ def make_server_app()-> Starlette:
|
|||||||
request.scope, request.receive, request._send
|
request.scope, request.receive, request._send
|
||||||
) as streams:
|
) as streams:
|
||||||
await server.run(
|
await server.run(
|
||||||
streams[0],
|
streams[0], streams[1], server.create_initialization_options()
|
||||||
streams[1],
|
|
||||||
server.create_initialization_options()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
app = Starlette(routes=[
|
app = Starlette(
|
||||||
|
routes=[
|
||||||
Route("/sse", endpoint=handle_sse),
|
Route("/sse", endpoint=handle_sse),
|
||||||
Mount("/messages/", app=sse.handle_post_message),
|
Mount("/messages/", app=sse.handle_post_message),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def space_around_test():
|
def space_around_test():
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
yield
|
yield
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
|
||||||
def run_server(server_port: int):
|
def run_server(server_port: int):
|
||||||
app = make_server_app()
|
app = make_server_app()
|
||||||
server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error"))
|
server = uvicorn.Server(
|
||||||
print(f'starting server on {server_port}')
|
config=uvicorn.Config(
|
||||||
|
app=app, host="127.0.0.1", port=server_port, log_level="error"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
print(f"starting server on {server_port}")
|
||||||
server.run()
|
server.run()
|
||||||
|
|
||||||
# Give server time to start
|
# Give server time to start
|
||||||
while not server.started:
|
while not server.started:
|
||||||
print('waiting for server to start')
|
print("waiting for server to start")
|
||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def server(server_port: int):
|
def server(server_port: int):
|
||||||
proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True)
|
proc = multiprocessing.Process(
|
||||||
print('starting process')
|
target=run_server, kwargs={"server_port": server_port}, daemon=True
|
||||||
|
)
|
||||||
|
print("starting process")
|
||||||
proc.start()
|
proc.start()
|
||||||
|
|
||||||
# Wait for server to be running
|
# Wait for server to be running
|
||||||
max_attempts = 20
|
max_attempts = 20
|
||||||
attempt = 0
|
attempt = 0
|
||||||
print('waiting for server to start')
|
print("waiting for server to start")
|
||||||
while attempt < max_attempts:
|
while attempt < max_attempts:
|
||||||
try:
|
try:
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
s.connect(('127.0.0.1', server_port))
|
s.connect(("127.0.0.1", server_port))
|
||||||
break
|
break
|
||||||
except ConnectionRefusedError:
|
except ConnectionRefusedError:
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
attempt += 1
|
attempt += 1
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("Server failed to start after {} attempts".format(max_attempts))
|
raise RuntimeError(
|
||||||
|
"Server failed to start after {} attempts".format(max_attempts)
|
||||||
|
)
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
print('killing server')
|
print("killing server")
|
||||||
# Signal the server to stop
|
# Signal the server to stop
|
||||||
proc.kill()
|
proc.kill()
|
||||||
proc.join(timeout=2)
|
proc.join(timeout=2)
|
||||||
if proc.is_alive():
|
if proc.is_alive():
|
||||||
print("server process failed to terminate")
|
print("server process failed to terminate")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
async def http_client(server, server_url) -> AsyncGenerator[httpx.AsyncClient, None]:
|
async def http_client(server, server_url) -> AsyncGenerator[httpx.AsyncClient, None]:
|
||||||
"""Create test client"""
|
"""Create test client"""
|
||||||
async with httpx.AsyncClient(base_url=server_url) as client:
|
async with httpx.AsyncClient(base_url=server_url) as client:
|
||||||
yield client
|
yield client
|
||||||
|
|
||||||
|
|
||||||
# Tests
|
# Tests
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_raw_sse_connection(http_client: httpx.AsyncClient):
|
async def test_raw_sse_connection(http_client: httpx.AsyncClient):
|
||||||
"""Test the SSE connection establishment simply with an HTTP client."""
|
"""Test the SSE connection establishment simply with an HTTP client."""
|
||||||
async with anyio.create_task_group() as tg:
|
async with anyio.create_task_group() as tg:
|
||||||
|
|
||||||
async def connection_test():
|
async def connection_test():
|
||||||
async with http_client.stream("GET", "/sse") as response:
|
async with http_client.stream("GET", "/sse") as response:
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
|
assert (
|
||||||
|
response.headers["content-type"]
|
||||||
|
== "text/event-stream; charset=utf-8"
|
||||||
|
)
|
||||||
|
|
||||||
line_number = 0
|
line_number = 0
|
||||||
async for line in response.aiter_lines():
|
async for line in response.aiter_lines():
|
||||||
@@ -177,23 +206,32 @@ async def test_sse_client_basic_connection(server, server_url):
|
|||||||
ping_result = await session.send_ping()
|
ping_result = await session.send_ping()
|
||||||
assert isinstance(ping_result, EmptyResult)
|
assert isinstance(ping_result, EmptyResult)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def initialized_sse_client_session(server, server_url: str) -> AsyncGenerator[ClientSession, None]:
|
async def initialized_sse_client_session(
|
||||||
|
server, server_url: str
|
||||||
|
) -> AsyncGenerator[ClientSession, None]:
|
||||||
async with sse_client(server_url + "/sse") as streams:
|
async with sse_client(server_url + "/sse") as streams:
|
||||||
async with ClientSession(*streams) as session:
|
async with ClientSession(*streams) as session:
|
||||||
await session.initialize()
|
await session.initialize()
|
||||||
yield session
|
yield session
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_sse_client_happy_request_and_response(initialized_sse_client_session: ClientSession):
|
async def test_sse_client_happy_request_and_response(
|
||||||
|
initialized_sse_client_session: ClientSession,
|
||||||
|
):
|
||||||
session = initialized_sse_client_session
|
session = initialized_sse_client_session
|
||||||
response = await session.read_resource(uri=AnyUrl("foobar://should-work"))
|
response = await session.read_resource(uri=AnyUrl("foobar://should-work"))
|
||||||
assert len(response.contents) == 1
|
assert len(response.contents) == 1
|
||||||
assert isinstance(response.contents[0], TextResourceContents)
|
assert isinstance(response.contents[0], TextResourceContents)
|
||||||
assert response.contents[0].text == "Read should-work"
|
assert response.contents[0].text == "Read should-work"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_sse_client_exception_handling(initialized_sse_client_session: ClientSession):
|
async def test_sse_client_exception_handling(
|
||||||
|
initialized_sse_client_session: ClientSession,
|
||||||
|
):
|
||||||
session = initialized_sse_client_session
|
session = initialized_sse_client_session
|
||||||
with pytest.raises(McpError, match="OOPS! no resource with that URI was found"):
|
with pytest.raises(McpError, match="OOPS! no resource with that URI was found"):
|
||||||
await session.read_resource(uri=AnyUrl("xxx://will-not-work"))
|
await session.read_resource(uri=AnyUrl("xxx://will-not-work"))
|
||||||
|
|||||||
Reference in New Issue
Block a user