feat: add request cancellation and in-flight request tracking

This commit adds support for request cancellation and tracking of
in-flight requests in the MCP protocol implementation. The key
architectural changes are:

1. Request Lifecycle Management:
   - Added _in_flight dictionary to BaseSession to track active requests
   - Requests are tracked from receipt until completion/cancellation
   - Added proper cleanup via on_complete callback

2. Cancellation Support:
   - Added CancelledNotification handling in _receive_loop
   - Implemented cancel() method in RequestResponder
   - Uses anyio.CancelScope for robust cancellation
   - Sends error response on cancellation

3. Request Context:
   - Added request_ctx ContextVar for request context
   - Ensures proper cleanup after request handling
   - Maintains request state throughout lifecycle

4. Error Handling:
   - Improved error propagation for cancelled requests
   - Added proper cleanup of cancelled requests
   - Maintains consistency of in-flight tracking

This change enables clients to cancel long-running requests and
servers to properly clean up resources when requests are cancelled.

Github-Issue:#88
This commit is contained in:
David Soria Parra
2025-01-23 20:10:02 +00:00
parent 888bdd3c34
commit 827e494df4
3 changed files with 190 additions and 13 deletions

View File

@@ -453,7 +453,12 @@ class Server:
logger.debug(f"Received message: {message}") logger.debug(f"Received message: {message}")
match message: match message:
case RequestResponder(request=types.ClientRequest(root=req)): case (
RequestResponder(
request=types.ClientRequest(root=req)
) as responder
):
with responder:
await self._handle_request( await self._handle_request(
message, req, session, raise_exceptions message, req, session, raise_exceptions
) )

View File

@@ -1,6 +1,6 @@
from contextlib import AbstractAsyncContextManager from contextlib import AbstractAsyncContextManager
from datetime import timedelta from datetime import timedelta
from typing import Generic, TypeVar from typing import Any, Callable, Generic, TypeVar
import anyio import anyio
import anyio.lowlevel import anyio.lowlevel
@@ -10,6 +10,7 @@ from pydantic import BaseModel
from mcp.shared.exceptions import McpError from mcp.shared.exceptions import McpError
from mcp.types import ( from mcp.types import (
CancelledNotification,
ClientNotification, ClientNotification,
ClientRequest, ClientRequest,
ClientResult, ClientResult,
@@ -44,21 +45,55 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
request_meta: RequestParams.Meta | None, request_meta: RequestParams.Meta | None,
request: ReceiveRequestT, request: ReceiveRequestT,
session: "BaseSession", session: "BaseSession",
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
) -> None: ) -> None:
self.request_id = request_id self.request_id = request_id
self.request_meta = request_meta self.request_meta = request_meta
self.request = request self.request = request
self._session = session self._session = session
self._responded = False self._completed = False
self._cancel_scope = anyio.CancelScope()
self._on_complete = on_complete
def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]":
self._cancel_scope.__enter__()
return self
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
try:
if self._completed:
self._on_complete(self)
finally:
self._cancel_scope.__exit__(exc_type, exc_val, exc_tb)
async def respond(self, response: SendResultT | ErrorData) -> None: async def respond(self, response: SendResultT | ErrorData) -> None:
assert not self._responded, "Request already responded to" assert not self._completed, "Request already responded to"
self._responded = True
if not self.cancelled:
self._completed = True
await self._session._send_response( await self._session._send_response(
request_id=self.request_id, response=response request_id=self.request_id, response=response
) )
async def cancel(self) -> None:
"""Cancel this request and mark it as completed."""
self._cancel_scope.cancel()
self._completed = True # Mark as completed so it's removed from in_flight
# Send an error response to indicate cancellation
await self._session._send_response(
request_id=self.request_id,
response=ErrorData(code=0, message="Request cancelled", data=None),
)
@property
def in_flight(self) -> bool:
return not self._completed and not self.cancelled
@property
def cancelled(self) -> bool:
return self._cancel_scope is not None and self._cancel_scope.cancel_called
class BaseSession( class BaseSession(
AbstractAsyncContextManager, AbstractAsyncContextManager,
@@ -82,6 +117,7 @@ class BaseSession(
RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError] RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]
] ]
_request_id: int _request_id: int
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
def __init__( def __init__(
self, self,
@@ -99,6 +135,7 @@ class BaseSession(
self._receive_request_type = receive_request_type self._receive_request_type = receive_request_type
self._receive_notification_type = receive_notification_type self._receive_notification_type = receive_notification_type
self._read_timeout_seconds = read_timeout_seconds self._read_timeout_seconds = read_timeout_seconds
self._in_flight = {}
self._incoming_message_stream_writer, self._incoming_message_stream_reader = ( self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
anyio.create_memory_object_stream[ anyio.create_memory_object_stream[
@@ -219,6 +256,7 @@ class BaseSession(
by_alias=True, mode="json", exclude_none=True by_alias=True, mode="json", exclude_none=True
) )
) )
responder = RequestResponder( responder = RequestResponder(
request_id=message.root.id, request_id=message.root.id,
request_meta=validated_request.root.params.meta request_meta=validated_request.root.params.meta
@@ -226,18 +264,26 @@ class BaseSession(
else None, else None,
request=validated_request, request=validated_request,
session=self, session=self,
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
) )
self._in_flight[responder.request_id] = responder
await self._received_request(responder) await self._received_request(responder)
if not responder._responded: if not responder._completed:
await self._incoming_message_stream_writer.send(responder) await self._incoming_message_stream_writer.send(responder)
elif isinstance(message.root, JSONRPCNotification): elif isinstance(message.root, JSONRPCNotification):
notification = self._receive_notification_type.model_validate( notification = self._receive_notification_type.model_validate(
message.root.model_dump( message.root.model_dump(
by_alias=True, mode="json", exclude_none=True by_alias=True, mode="json", exclude_none=True
) )
) )
# Handle cancellation notifications
if isinstance(notification.root, CancelledNotification):
cancelled_id = notification.root.params.requestId
if cancelled_id in self._in_flight:
await self._in_flight[cancelled_id].cancel()
else:
await self._received_notification(notification) await self._received_notification(notification)
await self._incoming_message_stream_writer.send(notification) await self._incoming_message_stream_writer.send(notification)
else: # Response or error else: # Response or error

View File

@@ -0,0 +1,126 @@
from typing import AsyncGenerator
import anyio
import pytest
import mcp.types as types
from mcp.client.session import ClientSession
from mcp.server.lowlevel.server import Server
from mcp.shared.exceptions import McpError
from mcp.shared.memory import create_connected_server_and_client_session
from mcp.types import (
CancelledNotification,
CancelledNotificationParams,
ClientNotification,
ClientRequest,
EmptyResult,
)
@pytest.fixture
def mcp_server() -> Server:
return Server(name="test server")
@pytest.fixture
async def client_connected_to_server(
mcp_server: Server,
) -> AsyncGenerator[ClientSession, None]:
async with create_connected_server_and_client_session(mcp_server) as client_session:
yield client_session
@pytest.mark.anyio
async def test_in_flight_requests_cleared_after_completion(
client_connected_to_server: ClientSession,
):
"""Verify that _in_flight is empty after all requests complete."""
# Send a request and wait for response
response = await client_connected_to_server.send_ping()
assert isinstance(response, EmptyResult)
# Verify _in_flight is empty
assert len(client_connected_to_server._in_flight) == 0
@pytest.mark.anyio
async def test_request_cancellation():
"""Test that requests can be cancelled while in-flight."""
# The tool is already registered in the fixture
ev_tool_called = anyio.Event()
ev_cancelled = anyio.Event()
request_id = None
# Start the request in a separate task so we can cancel it
def make_server() -> Server:
server = Server(name="TestSessionServer")
# Register the tool handler
@server.call_tool()
async def handle_call_tool(name: str, arguments: dict | None) -> list:
nonlocal request_id, ev_tool_called
if name == "slow_tool":
request_id = server.request_context.request_id
ev_tool_called.set()
await anyio.sleep(10) # Long enough to ensure we can cancel
return []
raise ValueError(f"Unknown tool: {name}")
# Register the tool so it shows up in list_tools
@server.list_tools()
async def handle_list_tools() -> list[types.Tool]:
return [
types.Tool(
name="slow_tool",
description="A slow tool that takes 10 seconds to complete",
inputSchema={},
)
]
return server
async def make_request(client_session):
nonlocal ev_cancelled
try:
await client_session.send_request(
ClientRequest(
types.CallToolRequest(
method="tools/call",
params=types.CallToolRequestParams(
name="slow_tool", arguments={}
),
)
),
types.CallToolResult,
)
pytest.fail("Request should have been cancelled")
except McpError as e:
# Expected - request was cancelled
assert "Request cancelled" in str(e)
ev_cancelled.set()
async with create_connected_server_and_client_session(
make_server()
) as client_session:
async with anyio.create_task_group() as tg:
tg.start_soon(make_request, client_session)
# Wait for the request to be in-flight
with anyio.fail_after(1): # Timeout after 1 second
await ev_tool_called.wait()
# Send cancellation notification
assert request_id is not None
await client_session.send_notification(
ClientNotification(
CancelledNotification(
method="notifications/cancelled",
params=CancelledNotificationParams(requestId=request_id),
)
)
)
# Give cancellation time to process
with anyio.fail_after(1):
await ev_cancelled.wait()