Merge pull request #167 from modelcontextprotocol/davidsp/88-v2

feat: add request cancellation and cleanup
This commit is contained in:
David Soria Parra
2025-02-04 20:23:05 +00:00
committed by GitHub
5 changed files with 365 additions and 29 deletions

View File

@@ -453,10 +453,15 @@ class Server:
logger.debug(f"Received message: {message}")
match message:
case RequestResponder(request=types.ClientRequest(root=req)):
await self._handle_request(
message, req, session, raise_exceptions
)
case (
RequestResponder(
request=types.ClientRequest(root=req)
) as responder
):
with responder:
await self._handle_request(
message, req, session, raise_exceptions
)
case types.ClientNotification(root=notify):
await self._handle_notification(notify)

View File

@@ -126,19 +126,20 @@ class ServerSession(
case types.InitializeRequest(params=params):
self._initialization_state = InitializationState.Initializing
self._client_params = params
await responder.respond(
types.ServerResult(
types.InitializeResult(
protocolVersion=types.LATEST_PROTOCOL_VERSION,
capabilities=self._init_options.capabilities,
serverInfo=types.Implementation(
name=self._init_options.server_name,
version=self._init_options.server_version,
),
instructions=self._init_options.instructions,
with responder:
await responder.respond(
types.ServerResult(
types.InitializeResult(
protocolVersion=types.LATEST_PROTOCOL_VERSION,
capabilities=self._init_options.capabilities,
serverInfo=types.Implementation(
name=self._init_options.server_name,
version=self._init_options.server_version,
),
instructions=self._init_options.instructions,
)
)
)
)
case _:
if self._initialization_state != InitializationState.Initialized:
raise RuntimeError(

View File

@@ -1,6 +1,7 @@
import logging
from contextlib import AbstractAsyncContextManager
from datetime import timedelta
from typing import Generic, TypeVar
from typing import Any, Callable, Generic, TypeVar
import anyio
import anyio.lowlevel
@@ -10,6 +11,7 @@ from pydantic import BaseModel
from mcp.shared.exceptions import McpError
from mcp.types import (
CancelledNotification,
ClientNotification,
ClientRequest,
ClientResult,
@@ -38,27 +40,98 @@ RequestId = str | int
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
"""Handles responding to MCP requests and manages request lifecycle.
This class MUST be used as a context manager to ensure proper cleanup and
cancellation handling:
Example:
with request_responder as resp:
await resp.respond(result)
The context manager ensures:
1. Proper cancellation scope setup and cleanup
2. Request completion tracking
3. Cleanup of in-flight requests
"""
def __init__(
self,
request_id: RequestId,
request_meta: RequestParams.Meta | None,
request: ReceiveRequestT,
session: "BaseSession",
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
) -> None:
self.request_id = request_id
self.request_meta = request_meta
self.request = request
self._session = session
self._responded = False
self._completed = False
self._cancel_scope = anyio.CancelScope()
self._on_complete = on_complete
self._entered = False # Track if we're in a context manager
def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]":
"""Enter the context manager, enabling request cancellation tracking."""
self._entered = True
self._cancel_scope = anyio.CancelScope()
self._cancel_scope.__enter__()
return self
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
"""Exit the context manager, performing cleanup and notifying completion."""
try:
if self._completed:
self._on_complete(self)
finally:
self._entered = False
if not self._cancel_scope:
raise RuntimeError("No active cancel scope")
self._cancel_scope.__exit__(exc_type, exc_val, exc_tb)
async def respond(self, response: SendResultT | ErrorData) -> None:
assert not self._responded, "Request already responded to"
self._responded = True
"""Send a response for this request.
Must be called within a context manager block.
Raises:
RuntimeError: If not used within a context manager
AssertionError: If request was already responded to
"""
if not self._entered:
raise RuntimeError("RequestResponder must be used as a context manager")
assert not self._completed, "Request already responded to"
if not self.cancelled:
self._completed = True
await self._session._send_response(
request_id=self.request_id, response=response
)
async def cancel(self) -> None:
"""Cancel this request and mark it as completed."""
if not self._entered:
raise RuntimeError("RequestResponder must be used as a context manager")
if not self._cancel_scope:
raise RuntimeError("No active cancel scope")
self._cancel_scope.cancel()
self._completed = True # Mark as completed so it's removed from in_flight
# Send an error response to indicate cancellation
await self._session._send_response(
request_id=self.request_id, response=response
request_id=self.request_id,
response=ErrorData(code=0, message="Request cancelled", data=None),
)
@property
def in_flight(self) -> bool:
return not self._completed and not self.cancelled
@property
def cancelled(self) -> bool:
return self._cancel_scope is not None and self._cancel_scope.cancel_called
class BaseSession(
AbstractAsyncContextManager,
@@ -82,6 +155,7 @@ class BaseSession(
RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]
]
_request_id: int
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
def __init__(
self,
@@ -99,6 +173,7 @@ class BaseSession(
self._receive_request_type = receive_request_type
self._receive_notification_type = receive_notification_type
self._read_timeout_seconds = read_timeout_seconds
self._in_flight = {}
self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
anyio.create_memory_object_stream[
@@ -219,6 +294,7 @@ class BaseSession(
by_alias=True, mode="json", exclude_none=True
)
)
responder = RequestResponder(
request_id=message.root.id,
request_meta=validated_request.root.params.meta
@@ -226,20 +302,37 @@ class BaseSession(
else None,
request=validated_request,
session=self,
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
)
self._in_flight[responder.request_id] = responder
await self._received_request(responder)
if not responder._responded:
if not responder._completed:
await self._incoming_message_stream_writer.send(responder)
elif isinstance(message.root, JSONRPCNotification):
notification = self._receive_notification_type.model_validate(
message.root.model_dump(
by_alias=True, mode="json", exclude_none=True
)
)
await self._received_notification(notification)
await self._incoming_message_stream_writer.send(notification)
elif isinstance(message.root, JSONRPCNotification):
try:
notification = self._receive_notification_type.model_validate(
message.root.model_dump(
by_alias=True, mode="json", exclude_none=True
)
)
# Handle cancellation notifications
if isinstance(notification.root, CancelledNotification):
cancelled_id = notification.root.params.requestId
if cancelled_id in self._in_flight:
await self._in_flight[cancelled_id].cancel()
else:
await self._received_notification(notification)
await self._incoming_message_stream_writer.send(
notification
)
except Exception as e:
# For other validation errors, log and continue
logging.warning(
f"Failed to validate notification: {e}. "
f"Message was: {message.root}"
)
else: # Response or error
stream = self._response_streams.pop(message.root.id, None)
if stream:

View File

@@ -0,0 +1,111 @@
"""Test to reproduce issue #88: Random error thrown on response."""
from datetime import timedelta
from pathlib import Path
from typing import Sequence
import anyio
import pytest
from mcp.client.session import ClientSession
from mcp.server.lowlevel import Server
from mcp.shared.exceptions import McpError
from mcp.types import (
EmbeddedResource,
ImageContent,
TextContent,
)
@pytest.mark.anyio
async def test_notification_validation_error(tmp_path: Path):
"""Test that timeouts are handled gracefully and don't break the server.
This test verifies that when a client request times out:
1. The server task stays alive
2. The server can still handle new requests
3. The client can make new requests
4. No resources are leaked
"""
server = Server(name="test")
request_count = 0
slow_request_started = anyio.Event()
slow_request_complete = anyio.Event()
@server.call_tool()
async def slow_tool(
name: str, arg
) -> Sequence[TextContent | ImageContent | EmbeddedResource]:
nonlocal request_count
request_count += 1
if name == "slow":
# Signal that slow request has started
slow_request_started.set()
# Long enough to ensure timeout
await anyio.sleep(0.2)
# Signal completion
slow_request_complete.set()
return [TextContent(type="text", text=f"slow {request_count}")]
elif name == "fast":
# Fast enough to complete before timeout
await anyio.sleep(0.01)
return [TextContent(type="text", text=f"fast {request_count}")]
return [TextContent(type="text", text=f"unknown {request_count}")]
async def server_handler(read_stream, write_stream):
await server.run(
read_stream,
write_stream,
server.create_initialization_options(),
raise_exceptions=True,
)
async def client(read_stream, write_stream):
# Use a timeout that's:
# - Long enough for fast operations (>10ms)
# - Short enough for slow operations (<200ms)
# - Not too short to avoid flakiness
async with ClientSession(
read_stream, write_stream, read_timeout_seconds=timedelta(milliseconds=50)
) as session:
await session.initialize()
# First call should work (fast operation)
result = await session.call_tool("fast")
assert result.content == [TextContent(type="text", text="fast 1")]
assert not slow_request_complete.is_set()
# Second call should timeout (slow operation)
with pytest.raises(McpError) as exc_info:
await session.call_tool("slow")
assert "Timed out while waiting" in str(exc_info.value)
# Wait for slow request to complete in the background
with anyio.fail_after(1): # Timeout after 1 second
await slow_request_complete.wait()
# Third call should work (fast operation),
# proving server is still responsive
result = await session.call_tool("fast")
assert result.content == [TextContent(type="text", text="fast 3")]
# Run server and client in separate task groups to avoid cancellation
server_writer, server_reader = anyio.create_memory_object_stream(1)
client_writer, client_reader = anyio.create_memory_object_stream(1)
server_ready = anyio.Event()
async def wrapped_server_handler(read_stream, write_stream):
server_ready.set()
await server_handler(read_stream, write_stream)
async with anyio.create_task_group() as tg:
tg.start_soon(wrapped_server_handler, server_reader, client_writer)
# Wait for server to start and initialize
with anyio.fail_after(1): # Timeout after 1 second
await server_ready.wait()
# Run client in a separate task to avoid cancellation
async with anyio.create_task_group() as client_tg:
client_tg.start_soon(client, client_reader, server_writer)

View File

@@ -0,0 +1,126 @@
from typing import AsyncGenerator
import anyio
import pytest
import mcp.types as types
from mcp.client.session import ClientSession
from mcp.server.lowlevel.server import Server
from mcp.shared.exceptions import McpError
from mcp.shared.memory import create_connected_server_and_client_session
from mcp.types import (
CancelledNotification,
CancelledNotificationParams,
ClientNotification,
ClientRequest,
EmptyResult,
)
@pytest.fixture
def mcp_server() -> Server:
return Server(name="test server")
@pytest.fixture
async def client_connected_to_server(
mcp_server: Server,
) -> AsyncGenerator[ClientSession, None]:
async with create_connected_server_and_client_session(mcp_server) as client_session:
yield client_session
@pytest.mark.anyio
async def test_in_flight_requests_cleared_after_completion(
client_connected_to_server: ClientSession,
):
"""Verify that _in_flight is empty after all requests complete."""
# Send a request and wait for response
response = await client_connected_to_server.send_ping()
assert isinstance(response, EmptyResult)
# Verify _in_flight is empty
assert len(client_connected_to_server._in_flight) == 0
@pytest.mark.anyio
async def test_request_cancellation():
"""Test that requests can be cancelled while in-flight."""
# The tool is already registered in the fixture
ev_tool_called = anyio.Event()
ev_cancelled = anyio.Event()
request_id = None
# Start the request in a separate task so we can cancel it
def make_server() -> Server:
server = Server(name="TestSessionServer")
# Register the tool handler
@server.call_tool()
async def handle_call_tool(name: str, arguments: dict | None) -> list:
nonlocal request_id, ev_tool_called
if name == "slow_tool":
request_id = server.request_context.request_id
ev_tool_called.set()
await anyio.sleep(10) # Long enough to ensure we can cancel
return []
raise ValueError(f"Unknown tool: {name}")
# Register the tool so it shows up in list_tools
@server.list_tools()
async def handle_list_tools() -> list[types.Tool]:
return [
types.Tool(
name="slow_tool",
description="A slow tool that takes 10 seconds to complete",
inputSchema={},
)
]
return server
async def make_request(client_session):
nonlocal ev_cancelled
try:
await client_session.send_request(
ClientRequest(
types.CallToolRequest(
method="tools/call",
params=types.CallToolRequestParams(
name="slow_tool", arguments={}
),
)
),
types.CallToolResult,
)
pytest.fail("Request should have been cancelled")
except McpError as e:
# Expected - request was cancelled
assert "Request cancelled" in str(e)
ev_cancelled.set()
async with create_connected_server_and_client_session(
make_server()
) as client_session:
async with anyio.create_task_group() as tg:
tg.start_soon(make_request, client_session)
# Wait for the request to be in-flight
with anyio.fail_after(1): # Timeout after 1 second
await ev_tool_called.wait()
# Send cancellation notification
assert request_id is not None
await client_session.send_notification(
ClientNotification(
CancelledNotification(
method="notifications/cancelled",
params=CancelledNotificationParams(requestId=request_id),
)
)
)
# Give cancellation time to process
with anyio.fail_after(1):
await ev_cancelled.wait()