mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-22 16:24:24 +01:00
Use 120 characters instead of 88 (#856)
This commit is contained in:
committed by
GitHub
parent
f7265f7b91
commit
543961968c
@@ -22,12 +22,8 @@ from mcp.shared.session import (
|
||||
async def test_bidirectional_progress_notifications():
|
||||
"""Test that both client and server can send progress notifications."""
|
||||
# Create memory streams for client/server
|
||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
||||
SessionMessage
|
||||
](5)
|
||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
||||
SessionMessage
|
||||
](5)
|
||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](5)
|
||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](5)
|
||||
|
||||
# Run a server session so we can send progress updates in tool
|
||||
async def run_server():
|
||||
@@ -134,9 +130,7 @@ async def test_bidirectional_progress_notifications():
|
||||
|
||||
# Client message handler to store progress notifications
|
||||
async def handle_client_message(
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult]
|
||||
| types.ServerNotification
|
||||
| Exception,
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
||||
) -> None:
|
||||
if isinstance(message, Exception):
|
||||
raise message
|
||||
@@ -172,9 +166,7 @@ async def test_bidirectional_progress_notifications():
|
||||
await client_session.list_tools()
|
||||
|
||||
# Call test_tool with progress token
|
||||
await client_session.call_tool(
|
||||
"test_tool", {"_meta": {"progressToken": client_progress_token}}
|
||||
)
|
||||
await client_session.call_tool("test_tool", {"_meta": {"progressToken": client_progress_token}})
|
||||
|
||||
# Send progress notifications from client to server
|
||||
await client_session.send_progress_notification(
|
||||
@@ -221,12 +213,8 @@ async def test_bidirectional_progress_notifications():
|
||||
async def test_progress_context_manager():
|
||||
"""Test client using progress context manager for sending progress notifications."""
|
||||
# Create memory streams for client/server
|
||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
||||
SessionMessage
|
||||
](5)
|
||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
||||
SessionMessage
|
||||
](5)
|
||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](5)
|
||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](5)
|
||||
|
||||
# Track progress updates
|
||||
server_progress_updates = []
|
||||
@@ -270,9 +258,7 @@ async def test_progress_context_manager():
|
||||
|
||||
# Client message handler
|
||||
async def handle_client_message(
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult]
|
||||
| types.ServerNotification
|
||||
| Exception,
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
||||
) -> None:
|
||||
if isinstance(message, Exception):
|
||||
raise message
|
||||
|
||||
@@ -90,9 +90,7 @@ async def test_request_cancellation():
|
||||
ClientRequest(
|
||||
types.CallToolRequest(
|
||||
method="tools/call",
|
||||
params=types.CallToolRequestParams(
|
||||
name="slow_tool", arguments={}
|
||||
),
|
||||
params=types.CallToolRequestParams(name="slow_tool", arguments={}),
|
||||
)
|
||||
),
|
||||
types.CallToolResult,
|
||||
@@ -103,9 +101,7 @@ async def test_request_cancellation():
|
||||
assert "Request cancelled" in str(e)
|
||||
ev_cancelled.set()
|
||||
|
||||
async with create_connected_server_and_client_session(
|
||||
make_server()
|
||||
) as client_session:
|
||||
async with create_connected_server_and_client_session(make_server()) as client_session:
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(make_request, client_session)
|
||||
|
||||
|
||||
@@ -60,11 +60,7 @@ class ServerTest(Server):
|
||||
await anyio.sleep(2.0)
|
||||
return f"Slow response from {uri.host}"
|
||||
|
||||
raise McpError(
|
||||
error=ErrorData(
|
||||
code=404, message="OOPS! no resource with that URI was found"
|
||||
)
|
||||
)
|
||||
raise McpError(error=ErrorData(code=404, message="OOPS! no resource with that URI was found"))
|
||||
|
||||
@self.list_tools()
|
||||
async def handle_list_tools() -> list[Tool]:
|
||||
@@ -88,12 +84,8 @@ def make_server_app() -> Starlette:
|
||||
server = ServerTest()
|
||||
|
||||
async def handle_sse(request: Request) -> Response:
|
||||
async with sse.connect_sse(
|
||||
request.scope, request.receive, request._send
|
||||
) as streams:
|
||||
await server.run(
|
||||
streams[0], streams[1], server.create_initialization_options()
|
||||
)
|
||||
async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
|
||||
await server.run(streams[0], streams[1], server.create_initialization_options())
|
||||
return Response()
|
||||
|
||||
app = Starlette(
|
||||
@@ -108,11 +100,7 @@ def make_server_app() -> Starlette:
|
||||
|
||||
def run_server(server_port: int) -> None:
|
||||
app = make_server_app()
|
||||
server = uvicorn.Server(
|
||||
config=uvicorn.Config(
|
||||
app=app, host="127.0.0.1", port=server_port, log_level="error"
|
||||
)
|
||||
)
|
||||
server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error"))
|
||||
print(f"starting server on {server_port}")
|
||||
server.run()
|
||||
|
||||
@@ -124,9 +112,7 @@ def run_server(server_port: int) -> None:
|
||||
|
||||
@pytest.fixture()
|
||||
def server(server_port: int) -> Generator[None, None, None]:
|
||||
proc = multiprocessing.Process(
|
||||
target=run_server, kwargs={"server_port": server_port}, daemon=True
|
||||
)
|
||||
proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True)
|
||||
print("starting process")
|
||||
proc.start()
|
||||
|
||||
@@ -171,10 +157,7 @@ async def test_raw_sse_connection(http_client: httpx.AsyncClient) -> None:
|
||||
async def connection_test() -> None:
|
||||
async with http_client.stream("GET", "/sse") as response:
|
||||
assert response.status_code == 200
|
||||
assert (
|
||||
response.headers["content-type"]
|
||||
== "text/event-stream; charset=utf-8"
|
||||
)
|
||||
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
|
||||
|
||||
line_number = 0
|
||||
async for line in response.aiter_lines():
|
||||
@@ -206,9 +189,7 @@ async def test_sse_client_basic_connection(server: None, server_url: str) -> Non
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def initialized_sse_client_session(
|
||||
server, server_url: str
|
||||
) -> AsyncGenerator[ClientSession, None]:
|
||||
async def initialized_sse_client_session(server, server_url: str) -> AsyncGenerator[ClientSession, None]:
|
||||
async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams:
|
||||
async with ClientSession(*streams) as session:
|
||||
await session.initialize()
|
||||
@@ -236,9 +217,7 @@ async def test_sse_client_exception_handling(
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.skip(
|
||||
"this test highlights a possible bug in SSE read timeout exception handling"
|
||||
)
|
||||
@pytest.mark.skip("this test highlights a possible bug in SSE read timeout exception handling")
|
||||
async def test_sse_client_timeout(
|
||||
initialized_sse_client_session: ClientSession,
|
||||
) -> None:
|
||||
@@ -260,11 +239,7 @@ async def test_sse_client_timeout(
|
||||
def run_mounted_server(server_port: int) -> None:
|
||||
app = make_server_app()
|
||||
main_app = Starlette(routes=[Mount("/mounted_app", app=app)])
|
||||
server = uvicorn.Server(
|
||||
config=uvicorn.Config(
|
||||
app=main_app, host="127.0.0.1", port=server_port, log_level="error"
|
||||
)
|
||||
)
|
||||
server = uvicorn.Server(config=uvicorn.Config(app=main_app, host="127.0.0.1", port=server_port, log_level="error"))
|
||||
print(f"starting server on {server_port}")
|
||||
server.run()
|
||||
|
||||
@@ -276,9 +251,7 @@ def run_mounted_server(server_port: int) -> None:
|
||||
|
||||
@pytest.fixture()
|
||||
def mounted_server(server_port: int) -> Generator[None, None, None]:
|
||||
proc = multiprocessing.Process(
|
||||
target=run_mounted_server, kwargs={"server_port": server_port}, daemon=True
|
||||
)
|
||||
proc = multiprocessing.Process(target=run_mounted_server, kwargs={"server_port": server_port}, daemon=True)
|
||||
print("starting process")
|
||||
proc.start()
|
||||
|
||||
@@ -308,9 +281,7 @@ def mounted_server(server_port: int) -> Generator[None, None, None]:
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_sse_client_basic_connection_mounted_app(
|
||||
mounted_server: None, server_url: str
|
||||
) -> None:
|
||||
async def test_sse_client_basic_connection_mounted_app(mounted_server: None, server_url: str) -> None:
|
||||
async with sse_client(server_url + "/mounted_app/sse") as streams:
|
||||
async with ClientSession(*streams) as session:
|
||||
# Test initialization
|
||||
@@ -372,12 +343,8 @@ def run_context_server(server_port: int) -> None:
|
||||
context_server = RequestContextServer()
|
||||
|
||||
async def handle_sse(request: Request) -> Response:
|
||||
async with sse.connect_sse(
|
||||
request.scope, request.receive, request._send
|
||||
) as streams:
|
||||
await context_server.run(
|
||||
streams[0], streams[1], context_server.create_initialization_options()
|
||||
)
|
||||
async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
|
||||
await context_server.run(streams[0], streams[1], context_server.create_initialization_options())
|
||||
return Response()
|
||||
|
||||
app = Starlette(
|
||||
@@ -387,11 +354,7 @@ def run_context_server(server_port: int) -> None:
|
||||
]
|
||||
)
|
||||
|
||||
server = uvicorn.Server(
|
||||
config=uvicorn.Config(
|
||||
app=app, host="127.0.0.1", port=server_port, log_level="error"
|
||||
)
|
||||
)
|
||||
server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error"))
|
||||
print(f"starting context server on {server_port}")
|
||||
server.run()
|
||||
|
||||
@@ -399,9 +362,7 @@ def run_context_server(server_port: int) -> None:
|
||||
@pytest.fixture()
|
||||
def context_server(server_port: int) -> Generator[None, None, None]:
|
||||
"""Fixture that provides a server with request context capture"""
|
||||
proc = multiprocessing.Process(
|
||||
target=run_context_server, kwargs={"server_port": server_port}, daemon=True
|
||||
)
|
||||
proc = multiprocessing.Process(target=run_context_server, kwargs={"server_port": server_port}, daemon=True)
|
||||
print("starting context server process")
|
||||
proc.start()
|
||||
|
||||
@@ -418,9 +379,7 @@ def context_server(server_port: int) -> Generator[None, None, None]:
|
||||
time.sleep(0.1)
|
||||
attempt += 1
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Context server failed to start after {max_attempts} attempts"
|
||||
)
|
||||
raise RuntimeError(f"Context server failed to start after {max_attempts} attempts")
|
||||
|
||||
yield
|
||||
|
||||
@@ -432,9 +391,7 @@ def context_server(server_port: int) -> Generator[None, None, None]:
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_request_context_propagation(
|
||||
context_server: None, server_url: str
|
||||
) -> None:
|
||||
async def test_request_context_propagation(context_server: None, server_url: str) -> None:
|
||||
"""Test that request context is properly propagated through SSE transport."""
|
||||
# Test with custom headers
|
||||
custom_headers = {
|
||||
@@ -458,11 +415,7 @@ async def test_request_context_propagation(
|
||||
# Parse the JSON response
|
||||
|
||||
assert len(tool_result.content) == 1
|
||||
headers_data = json.loads(
|
||||
tool_result.content[0].text
|
||||
if tool_result.content[0].type == "text"
|
||||
else "{}"
|
||||
)
|
||||
headers_data = json.loads(tool_result.content[0].text if tool_result.content[0].type == "text" else "{}")
|
||||
|
||||
# Verify headers were propagated
|
||||
assert headers_data.get("authorization") == "Bearer test-token"
|
||||
@@ -487,15 +440,11 @@ async def test_request_context_isolation(context_server: None, server_url: str)
|
||||
await session.initialize()
|
||||
|
||||
# Call the tool that echoes context
|
||||
tool_result = await session.call_tool(
|
||||
"echo_context", {"request_id": f"request-{i}"}
|
||||
)
|
||||
tool_result = await session.call_tool("echo_context", {"request_id": f"request-{i}"})
|
||||
|
||||
assert len(tool_result.content) == 1
|
||||
context_data = json.loads(
|
||||
tool_result.content[0].text
|
||||
if tool_result.content[0].type == "text"
|
||||
else "{}"
|
||||
tool_result.content[0].text if tool_result.content[0].type == "text" else "{}"
|
||||
)
|
||||
contexts.append(context_data)
|
||||
|
||||
@@ -514,8 +463,4 @@ def test_sse_message_id_coercion():
|
||||
"""
|
||||
json_message = '{"jsonrpc": "2.0", "id": "123", "method": "ping", "params": null}'
|
||||
msg = types.JSONRPCMessage.model_validate_json(json_message)
|
||||
assert msg == snapshot(
|
||||
types.JSONRPCMessage(
|
||||
root=types.JSONRPCRequest(method="ping", jsonrpc="2.0", id=123)
|
||||
)
|
||||
)
|
||||
assert msg == snapshot(types.JSONRPCMessage(root=types.JSONRPCRequest(method="ping", jsonrpc="2.0", id=123)))
|
||||
|
||||
@@ -72,9 +72,7 @@ class SimpleEventStore(EventStore):
|
||||
self._events: list[tuple[StreamId, EventId, types.JSONRPCMessage]] = []
|
||||
self._event_id_counter = 0
|
||||
|
||||
async def store_event(
|
||||
self, stream_id: StreamId, message: types.JSONRPCMessage
|
||||
) -> EventId:
|
||||
async def store_event(self, stream_id: StreamId, message: types.JSONRPCMessage) -> EventId:
|
||||
"""Store an event and return its ID."""
|
||||
self._event_id_counter += 1
|
||||
event_id = str(self._event_id_counter)
|
||||
@@ -156,9 +154,7 @@ class ServerTest(Server):
|
||||
|
||||
# When the tool is called, send a notification to test GET stream
|
||||
if name == "test_tool_with_standalone_notification":
|
||||
await ctx.session.send_resource_updated(
|
||||
uri=AnyUrl("http://test_resource")
|
||||
)
|
||||
await ctx.session.send_resource_updated(uri=AnyUrl("http://test_resource"))
|
||||
return [TextContent(type="text", text=f"Called {name}")]
|
||||
|
||||
elif name == "long_running_with_checkpoints":
|
||||
@@ -189,9 +185,7 @@ class ServerTest(Server):
|
||||
messages=[
|
||||
types.SamplingMessage(
|
||||
role="user",
|
||||
content=types.TextContent(
|
||||
type="text", text="Server needs client sampling"
|
||||
),
|
||||
content=types.TextContent(type="text", text="Server needs client sampling"),
|
||||
)
|
||||
],
|
||||
max_tokens=100,
|
||||
@@ -199,11 +193,7 @@ class ServerTest(Server):
|
||||
)
|
||||
|
||||
# Return the sampling result in the tool response
|
||||
response = (
|
||||
sampling_result.content.text
|
||||
if sampling_result.content.type == "text"
|
||||
else None
|
||||
)
|
||||
response = sampling_result.content.text if sampling_result.content.type == "text" else None
|
||||
return [
|
||||
TextContent(
|
||||
type="text",
|
||||
@@ -214,9 +204,7 @@ class ServerTest(Server):
|
||||
return [TextContent(type="text", text=f"Called {name}")]
|
||||
|
||||
|
||||
def create_app(
|
||||
is_json_response_enabled=False, event_store: EventStore | None = None
|
||||
) -> Starlette:
|
||||
def create_app(is_json_response_enabled=False, event_store: EventStore | None = None) -> Starlette:
|
||||
"""Create a Starlette application for testing using the session manager.
|
||||
|
||||
Args:
|
||||
@@ -245,9 +233,7 @@ def create_app(
|
||||
return app
|
||||
|
||||
|
||||
def run_server(
|
||||
port: int, is_json_response_enabled=False, event_store: EventStore | None = None
|
||||
) -> None:
|
||||
def run_server(port: int, is_json_response_enabled=False, event_store: EventStore | None = None) -> None:
|
||||
"""Run the test server.
|
||||
|
||||
Args:
|
||||
@@ -300,9 +286,7 @@ def json_server_port() -> int:
|
||||
@pytest.fixture
|
||||
def basic_server(basic_server_port: int) -> Generator[None, None, None]:
|
||||
"""Start a basic server."""
|
||||
proc = multiprocessing.Process(
|
||||
target=run_server, kwargs={"port": basic_server_port}, daemon=True
|
||||
)
|
||||
proc = multiprocessing.Process(target=run_server, kwargs={"port": basic_server_port}, daemon=True)
|
||||
proc.start()
|
||||
|
||||
# Wait for server to be running
|
||||
@@ -778,9 +762,7 @@ async def test_streamablehttp_client_basic_connection(basic_server, basic_server
|
||||
@pytest.mark.anyio
|
||||
async def test_streamablehttp_client_resource_read(initialized_client_session):
|
||||
"""Test client resource read functionality."""
|
||||
response = await initialized_client_session.read_resource(
|
||||
uri=AnyUrl("foobar://test-resource")
|
||||
)
|
||||
response = await initialized_client_session.read_resource(uri=AnyUrl("foobar://test-resource"))
|
||||
assert len(response.contents) == 1
|
||||
assert response.contents[0].uri == AnyUrl("foobar://test-resource")
|
||||
assert response.contents[0].text == "Read test-resource"
|
||||
@@ -805,17 +787,13 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session)
|
||||
async def test_streamablehttp_client_error_handling(initialized_client_session):
|
||||
"""Test error handling in client."""
|
||||
with pytest.raises(McpError) as exc_info:
|
||||
await initialized_client_session.read_resource(
|
||||
uri=AnyUrl("unknown://test-error")
|
||||
)
|
||||
await initialized_client_session.read_resource(uri=AnyUrl("unknown://test-error"))
|
||||
assert exc_info.value.error.code == 0
|
||||
assert "Unknown resource: unknown://test-error" in exc_info.value.error.message
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_streamablehttp_client_session_persistence(
|
||||
basic_server, basic_server_url
|
||||
):
|
||||
async def test_streamablehttp_client_session_persistence(basic_server, basic_server_url):
|
||||
"""Test that session ID persists across requests."""
|
||||
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
|
||||
read_stream,
|
||||
@@ -843,9 +821,7 @@ async def test_streamablehttp_client_session_persistence(
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_streamablehttp_client_json_response(
|
||||
json_response_server, json_server_url
|
||||
):
|
||||
async def test_streamablehttp_client_json_response(json_response_server, json_server_url):
|
||||
"""Test client with JSON response mode."""
|
||||
async with streamablehttp_client(f"{json_server_url}/mcp") as (
|
||||
read_stream,
|
||||
@@ -882,9 +858,7 @@ async def test_streamablehttp_client_get_stream(basic_server, basic_server_url):
|
||||
|
||||
# Define message handler to capture notifications
|
||||
async def message_handler(
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult]
|
||||
| types.ServerNotification
|
||||
| Exception,
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
||||
) -> None:
|
||||
if isinstance(message, types.ServerNotification):
|
||||
notifications_received.append(message)
|
||||
@@ -894,9 +868,7 @@ async def test_streamablehttp_client_get_stream(basic_server, basic_server_url):
|
||||
write_stream,
|
||||
_,
|
||||
):
|
||||
async with ClientSession(
|
||||
read_stream, write_stream, message_handler=message_handler
|
||||
) as session:
|
||||
async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session:
|
||||
# Initialize the session - this triggers the GET stream setup
|
||||
result = await session.initialize()
|
||||
assert isinstance(result, InitializeResult)
|
||||
@@ -914,15 +886,11 @@ async def test_streamablehttp_client_get_stream(basic_server, basic_server_url):
|
||||
assert str(notif.root.params.uri) == "http://test_resource/"
|
||||
resource_update_found = True
|
||||
|
||||
assert (
|
||||
resource_update_found
|
||||
), "ResourceUpdatedNotification not received via GET stream"
|
||||
assert resource_update_found, "ResourceUpdatedNotification not received via GET stream"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_streamablehttp_client_session_termination(
|
||||
basic_server, basic_server_url
|
||||
):
|
||||
async def test_streamablehttp_client_session_termination(basic_server, basic_server_url):
|
||||
"""Test client session termination functionality."""
|
||||
|
||||
captured_session_id = None
|
||||
@@ -963,9 +931,7 @@ async def test_streamablehttp_client_session_termination(
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_streamablehttp_client_session_termination_204(
|
||||
basic_server, basic_server_url, monkeypatch
|
||||
):
|
||||
async def test_streamablehttp_client_session_termination_204(basic_server, basic_server_url, monkeypatch):
|
||||
"""Test client session termination functionality with a 204 response.
|
||||
|
||||
This test patches the httpx client to return a 204 response for DELETEs.
|
||||
@@ -1040,9 +1006,7 @@ async def test_streamablehttp_client_resumption(event_server):
|
||||
tool_started = False
|
||||
|
||||
async def message_handler(
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult]
|
||||
| types.ServerNotification
|
||||
| Exception,
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
||||
) -> None:
|
||||
if isinstance(message, types.ServerNotification):
|
||||
captured_notifications.append(message)
|
||||
@@ -1062,9 +1026,7 @@ async def test_streamablehttp_client_resumption(event_server):
|
||||
write_stream,
|
||||
get_session_id,
|
||||
):
|
||||
async with ClientSession(
|
||||
read_stream, write_stream, message_handler=message_handler
|
||||
) as session:
|
||||
async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session:
|
||||
# Initialize the session
|
||||
result = await session.initialize()
|
||||
assert isinstance(result, InitializeResult)
|
||||
@@ -1082,9 +1044,7 @@ async def test_streamablehttp_client_resumption(event_server):
|
||||
types.ClientRequest(
|
||||
types.CallToolRequest(
|
||||
method="tools/call",
|
||||
params=types.CallToolRequestParams(
|
||||
name="long_running_with_checkpoints", arguments={}
|
||||
),
|
||||
params=types.CallToolRequestParams(name="long_running_with_checkpoints", arguments={}),
|
||||
)
|
||||
),
|
||||
types.CallToolResult,
|
||||
@@ -1114,9 +1074,7 @@ async def test_streamablehttp_client_resumption(event_server):
|
||||
write_stream,
|
||||
_,
|
||||
):
|
||||
async with ClientSession(
|
||||
read_stream, write_stream, message_handler=message_handler
|
||||
) as session:
|
||||
async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session:
|
||||
# Don't initialize - just use the existing session
|
||||
|
||||
# Resume the tool with the resumption token
|
||||
@@ -1129,9 +1087,7 @@ async def test_streamablehttp_client_resumption(event_server):
|
||||
types.ClientRequest(
|
||||
types.CallToolRequest(
|
||||
method="tools/call",
|
||||
params=types.CallToolRequestParams(
|
||||
name="long_running_with_checkpoints", arguments={}
|
||||
),
|
||||
params=types.CallToolRequestParams(name="long_running_with_checkpoints", arguments={}),
|
||||
)
|
||||
),
|
||||
types.CallToolResult,
|
||||
@@ -1149,14 +1105,11 @@ async def test_streamablehttp_client_resumption(event_server):
|
||||
# Should not have the first notification
|
||||
# Check that "Tool started" notification isn't repeated when resuming
|
||||
assert not any(
|
||||
isinstance(n.root, types.LoggingMessageNotification)
|
||||
and n.root.params.data == "Tool started"
|
||||
isinstance(n.root, types.LoggingMessageNotification) and n.root.params.data == "Tool started"
|
||||
for n in captured_notifications
|
||||
)
|
||||
# there is no intersection between pre and post notifications
|
||||
assert not any(
|
||||
n in captured_notifications_pre for n in captured_notifications
|
||||
)
|
||||
assert not any(n in captured_notifications_pre for n in captured_notifications)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -1175,11 +1128,7 @@ async def test_streamablehttp_server_sampling(basic_server, basic_server_url):
|
||||
nonlocal sampling_callback_invoked, captured_message_params
|
||||
sampling_callback_invoked = True
|
||||
captured_message_params = params
|
||||
message_received = (
|
||||
params.messages[0].content.text
|
||||
if params.messages[0].content.type == "text"
|
||||
else None
|
||||
)
|
||||
message_received = params.messages[0].content.text if params.messages[0].content.type == "text" else None
|
||||
|
||||
return types.CreateMessageResult(
|
||||
role="assistant",
|
||||
@@ -1212,19 +1161,13 @@ async def test_streamablehttp_server_sampling(basic_server, basic_server_url):
|
||||
# Verify the tool result contains the expected content
|
||||
assert len(tool_result.content) == 1
|
||||
assert tool_result.content[0].type == "text"
|
||||
assert (
|
||||
"Response from sampling: Received message from server"
|
||||
in tool_result.content[0].text
|
||||
)
|
||||
assert "Response from sampling: Received message from server" in tool_result.content[0].text
|
||||
|
||||
# Verify sampling callback was invoked
|
||||
assert sampling_callback_invoked
|
||||
assert captured_message_params is not None
|
||||
assert len(captured_message_params.messages) == 1
|
||||
assert (
|
||||
captured_message_params.messages[0].content.text
|
||||
== "Server needs client sampling"
|
||||
)
|
||||
assert captured_message_params.messages[0].content.text == "Server needs client sampling"
|
||||
|
||||
|
||||
# Context-aware server implementation for testing request context propagation
|
||||
@@ -1325,9 +1268,7 @@ def run_context_aware_server(port: int):
|
||||
@pytest.fixture
|
||||
def context_aware_server(basic_server_port: int) -> Generator[None, None, None]:
|
||||
"""Start the context-aware server in a separate process."""
|
||||
proc = multiprocessing.Process(
|
||||
target=run_context_aware_server, args=(basic_server_port,), daemon=True
|
||||
)
|
||||
proc = multiprocessing.Process(target=run_context_aware_server, args=(basic_server_port,), daemon=True)
|
||||
proc.start()
|
||||
|
||||
# Wait for server to be running
|
||||
@@ -1342,9 +1283,7 @@ def context_aware_server(basic_server_port: int) -> Generator[None, None, None]:
|
||||
time.sleep(0.1)
|
||||
attempt += 1
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Context-aware server failed to start after {max_attempts} attempts"
|
||||
)
|
||||
raise RuntimeError(f"Context-aware server failed to start after {max_attempts} attempts")
|
||||
|
||||
yield
|
||||
|
||||
@@ -1355,9 +1294,7 @@ def context_aware_server(basic_server_port: int) -> Generator[None, None, None]:
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_streamablehttp_request_context_propagation(
|
||||
context_aware_server: None, basic_server_url: str
|
||||
) -> None:
|
||||
async def test_streamablehttp_request_context_propagation(context_aware_server: None, basic_server_url: str) -> None:
|
||||
"""Test that request context is properly propagated through StreamableHTTP."""
|
||||
custom_headers = {
|
||||
"Authorization": "Bearer test-token",
|
||||
@@ -1365,9 +1302,11 @@ async def test_streamablehttp_request_context_propagation(
|
||||
"X-Trace-Id": "trace-123",
|
||||
}
|
||||
|
||||
async with streamablehttp_client(
|
||||
f"{basic_server_url}/mcp", headers=custom_headers
|
||||
) as (read_stream, write_stream, _):
|
||||
async with streamablehttp_client(f"{basic_server_url}/mcp", headers=custom_headers) as (
|
||||
read_stream,
|
||||
write_stream,
|
||||
_,
|
||||
):
|
||||
async with ClientSession(read_stream, write_stream) as session:
|
||||
result = await session.initialize()
|
||||
assert isinstance(result, InitializeResult)
|
||||
@@ -1388,9 +1327,7 @@ async def test_streamablehttp_request_context_propagation(
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_streamablehttp_request_context_isolation(
|
||||
context_aware_server: None, basic_server_url: str
|
||||
) -> None:
|
||||
async def test_streamablehttp_request_context_isolation(context_aware_server: None, basic_server_url: str) -> None:
|
||||
"""Test that request contexts are isolated between StreamableHTTP clients."""
|
||||
contexts = []
|
||||
|
||||
@@ -1402,16 +1339,12 @@ async def test_streamablehttp_request_context_isolation(
|
||||
"Authorization": f"Bearer token-{i}",
|
||||
}
|
||||
|
||||
async with streamablehttp_client(
|
||||
f"{basic_server_url}/mcp", headers=headers
|
||||
) as (read_stream, write_stream, _):
|
||||
async with streamablehttp_client(f"{basic_server_url}/mcp", headers=headers) as (read_stream, write_stream, _):
|
||||
async with ClientSession(read_stream, write_stream) as session:
|
||||
await session.initialize()
|
||||
|
||||
# Call the tool that echoes context
|
||||
tool_result = await session.call_tool(
|
||||
"echo_context", {"request_id": f"request-{i}"}
|
||||
)
|
||||
tool_result = await session.call_tool("echo_context", {"request_id": f"request-{i}"})
|
||||
|
||||
assert len(tool_result.content) == 1
|
||||
assert isinstance(tool_result.content[0], TextContent)
|
||||
|
||||
@@ -54,11 +54,7 @@ class ServerTest(Server):
|
||||
await anyio.sleep(2.0)
|
||||
return f"Slow response from {uri.host}"
|
||||
|
||||
raise McpError(
|
||||
error=ErrorData(
|
||||
code=404, message="OOPS! no resource with that URI was found"
|
||||
)
|
||||
)
|
||||
raise McpError(error=ErrorData(code=404, message="OOPS! no resource with that URI was found"))
|
||||
|
||||
@self.list_tools()
|
||||
async def handle_list_tools() -> list[Tool]:
|
||||
@@ -81,12 +77,8 @@ def make_server_app() -> Starlette:
|
||||
server = ServerTest()
|
||||
|
||||
async def handle_ws(websocket):
|
||||
async with websocket_server(
|
||||
websocket.scope, websocket.receive, websocket.send
|
||||
) as streams:
|
||||
await server.run(
|
||||
streams[0], streams[1], server.create_initialization_options()
|
||||
)
|
||||
async with websocket_server(websocket.scope, websocket.receive, websocket.send) as streams:
|
||||
await server.run(streams[0], streams[1], server.create_initialization_options())
|
||||
|
||||
app = Starlette(
|
||||
routes=[
|
||||
@@ -99,11 +91,7 @@ def make_server_app() -> Starlette:
|
||||
|
||||
def run_server(server_port: int) -> None:
|
||||
app = make_server_app()
|
||||
server = uvicorn.Server(
|
||||
config=uvicorn.Config(
|
||||
app=app, host="127.0.0.1", port=server_port, log_level="error"
|
||||
)
|
||||
)
|
||||
server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error"))
|
||||
print(f"starting server on {server_port}")
|
||||
server.run()
|
||||
|
||||
@@ -115,9 +103,7 @@ def run_server(server_port: int) -> None:
|
||||
|
||||
@pytest.fixture()
|
||||
def server(server_port: int) -> Generator[None, None, None]:
|
||||
proc = multiprocessing.Process(
|
||||
target=run_server, kwargs={"server_port": server_port}, daemon=True
|
||||
)
|
||||
proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True)
|
||||
print("starting process")
|
||||
proc.start()
|
||||
|
||||
@@ -147,9 +133,7 @@ def server(server_port: int) -> Generator[None, None, None]:
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
async def initialized_ws_client_session(
|
||||
server, server_url: str
|
||||
) -> AsyncGenerator[ClientSession, None]:
|
||||
async def initialized_ws_client_session(server, server_url: str) -> AsyncGenerator[ClientSession, None]:
|
||||
"""Create and initialize a WebSocket client session"""
|
||||
async with websocket_client(server_url + "/ws") as streams:
|
||||
async with ClientSession(*streams) as session:
|
||||
@@ -186,9 +170,7 @@ async def test_ws_client_happy_request_and_response(
|
||||
initialized_ws_client_session: ClientSession,
|
||||
) -> None:
|
||||
"""Test a successful request and response via WebSocket"""
|
||||
result = await initialized_ws_client_session.read_resource(
|
||||
AnyUrl("foobar://example")
|
||||
)
|
||||
result = await initialized_ws_client_session.read_resource(AnyUrl("foobar://example"))
|
||||
assert isinstance(result, ReadResourceResult)
|
||||
assert isinstance(result.contents, list)
|
||||
assert len(result.contents) > 0
|
||||
@@ -218,9 +200,7 @@ async def test_ws_client_timeout(
|
||||
|
||||
# Now test that we can still use the session after a timeout
|
||||
with anyio.fail_after(5): # Longer timeout to allow completion
|
||||
result = await initialized_ws_client_session.read_resource(
|
||||
AnyUrl("foobar://example")
|
||||
)
|
||||
result = await initialized_ws_client_session.read_resource(AnyUrl("foobar://example"))
|
||||
assert isinstance(result, ReadResourceResult)
|
||||
assert isinstance(result.contents, list)
|
||||
assert len(result.contents) > 0
|
||||
|
||||
Reference in New Issue
Block a user