mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
feat: add request cancellation and in-flight request tracking
This commit adds support for request cancellation and tracking of in-flight requests in the MCP protocol implementation. The key architectural changes are: 1. Request Lifecycle Management: - Added _in_flight dictionary to BaseSession to track active requests - Requests are tracked from receipt until completion/cancellation - Added proper cleanup via on_complete callback 2. Cancellation Support: - Added CancelledNotification handling in _receive_loop - Implemented cancel() method in RequestResponder - Uses anyio.CancelScope for robust cancellation - Sends error response on cancellation 3. Request Context: - Added request_ctx ContextVar for request context - Ensures proper cleanup after request handling - Maintains request state throughout lifecycle 4. Error Handling: - Improved error propagation for cancelled requests - Added proper cleanup of cancelled requests - Maintains consistency of in-flight tracking This change enables clients to cancel long-running requests and servers to properly clean up resources when requests are cancelled. Github-Issue:#88
This commit is contained in:
126
tests/shared/test_session.py
Normal file
126
tests/shared/test_session.py
Normal file
@@ -0,0 +1,126 @@
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import anyio
|
||||
import pytest
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.server.lowlevel.server import Server
|
||||
from mcp.shared.exceptions import McpError
|
||||
from mcp.shared.memory import create_connected_server_and_client_session
|
||||
from mcp.types import (
|
||||
CancelledNotification,
|
||||
CancelledNotificationParams,
|
||||
ClientNotification,
|
||||
ClientRequest,
|
||||
EmptyResult,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mcp_server() -> Server:
|
||||
return Server(name="test server")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def client_connected_to_server(
|
||||
mcp_server: Server,
|
||||
) -> AsyncGenerator[ClientSession, None]:
|
||||
async with create_connected_server_and_client_session(mcp_server) as client_session:
|
||||
yield client_session
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_in_flight_requests_cleared_after_completion(
|
||||
client_connected_to_server: ClientSession,
|
||||
):
|
||||
"""Verify that _in_flight is empty after all requests complete."""
|
||||
# Send a request and wait for response
|
||||
response = await client_connected_to_server.send_ping()
|
||||
assert isinstance(response, EmptyResult)
|
||||
|
||||
# Verify _in_flight is empty
|
||||
assert len(client_connected_to_server._in_flight) == 0
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_request_cancellation():
|
||||
"""Test that requests can be cancelled while in-flight."""
|
||||
# The tool is already registered in the fixture
|
||||
|
||||
ev_tool_called = anyio.Event()
|
||||
ev_cancelled = anyio.Event()
|
||||
request_id = None
|
||||
|
||||
# Start the request in a separate task so we can cancel it
|
||||
def make_server() -> Server:
|
||||
server = Server(name="TestSessionServer")
|
||||
|
||||
# Register the tool handler
|
||||
@server.call_tool()
|
||||
async def handle_call_tool(name: str, arguments: dict | None) -> list:
|
||||
nonlocal request_id, ev_tool_called
|
||||
if name == "slow_tool":
|
||||
request_id = server.request_context.request_id
|
||||
ev_tool_called.set()
|
||||
await anyio.sleep(10) # Long enough to ensure we can cancel
|
||||
return []
|
||||
raise ValueError(f"Unknown tool: {name}")
|
||||
|
||||
# Register the tool so it shows up in list_tools
|
||||
@server.list_tools()
|
||||
async def handle_list_tools() -> list[types.Tool]:
|
||||
return [
|
||||
types.Tool(
|
||||
name="slow_tool",
|
||||
description="A slow tool that takes 10 seconds to complete",
|
||||
inputSchema={},
|
||||
)
|
||||
]
|
||||
|
||||
return server
|
||||
|
||||
async def make_request(client_session):
|
||||
nonlocal ev_cancelled
|
||||
try:
|
||||
await client_session.send_request(
|
||||
ClientRequest(
|
||||
types.CallToolRequest(
|
||||
method="tools/call",
|
||||
params=types.CallToolRequestParams(
|
||||
name="slow_tool", arguments={}
|
||||
),
|
||||
)
|
||||
),
|
||||
types.CallToolResult,
|
||||
)
|
||||
pytest.fail("Request should have been cancelled")
|
||||
except McpError as e:
|
||||
# Expected - request was cancelled
|
||||
assert "Request cancelled" in str(e)
|
||||
ev_cancelled.set()
|
||||
|
||||
async with create_connected_server_and_client_session(
|
||||
make_server()
|
||||
) as client_session:
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(make_request, client_session)
|
||||
|
||||
# Wait for the request to be in-flight
|
||||
with anyio.fail_after(1): # Timeout after 1 second
|
||||
await ev_tool_called.wait()
|
||||
|
||||
# Send cancellation notification
|
||||
assert request_id is not None
|
||||
await client_session.send_notification(
|
||||
ClientNotification(
|
||||
CancelledNotification(
|
||||
method="notifications/cancelled",
|
||||
params=CancelledNotificationParams(requestId=request_id),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Give cancellation time to process
|
||||
with anyio.fail_after(1):
|
||||
await ev_cancelled.wait()
|
||||
Reference in New Issue
Block a user