mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
Merge pull request #167 from modelcontextprotocol/davidsp/88-v2
feat: add request cancellation and cleanup
This commit is contained in:
@@ -453,7 +453,12 @@ class Server:
|
|||||||
logger.debug(f"Received message: {message}")
|
logger.debug(f"Received message: {message}")
|
||||||
|
|
||||||
match message:
|
match message:
|
||||||
case RequestResponder(request=types.ClientRequest(root=req)):
|
case (
|
||||||
|
RequestResponder(
|
||||||
|
request=types.ClientRequest(root=req)
|
||||||
|
) as responder
|
||||||
|
):
|
||||||
|
with responder:
|
||||||
await self._handle_request(
|
await self._handle_request(
|
||||||
message, req, session, raise_exceptions
|
message, req, session, raise_exceptions
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -126,6 +126,7 @@ class ServerSession(
|
|||||||
case types.InitializeRequest(params=params):
|
case types.InitializeRequest(params=params):
|
||||||
self._initialization_state = InitializationState.Initializing
|
self._initialization_state = InitializationState.Initializing
|
||||||
self._client_params = params
|
self._client_params = params
|
||||||
|
with responder:
|
||||||
await responder.respond(
|
await responder.respond(
|
||||||
types.ServerResult(
|
types.ServerResult(
|
||||||
types.InitializeResult(
|
types.InitializeResult(
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
|
import logging
|
||||||
from contextlib import AbstractAsyncContextManager
|
from contextlib import AbstractAsyncContextManager
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from typing import Generic, TypeVar
|
from typing import Any, Callable, Generic, TypeVar
|
||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
import anyio.lowlevel
|
import anyio.lowlevel
|
||||||
@@ -10,6 +11,7 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
from mcp.shared.exceptions import McpError
|
from mcp.shared.exceptions import McpError
|
||||||
from mcp.types import (
|
from mcp.types import (
|
||||||
|
CancelledNotification,
|
||||||
ClientNotification,
|
ClientNotification,
|
||||||
ClientRequest,
|
ClientRequest,
|
||||||
ClientResult,
|
ClientResult,
|
||||||
@@ -38,27 +40,98 @@ RequestId = str | int
|
|||||||
|
|
||||||
|
|
||||||
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
||||||
|
"""Handles responding to MCP requests and manages request lifecycle.
|
||||||
|
|
||||||
|
This class MUST be used as a context manager to ensure proper cleanup and
|
||||||
|
cancellation handling:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
with request_responder as resp:
|
||||||
|
await resp.respond(result)
|
||||||
|
|
||||||
|
The context manager ensures:
|
||||||
|
1. Proper cancellation scope setup and cleanup
|
||||||
|
2. Request completion tracking
|
||||||
|
3. Cleanup of in-flight requests
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
request_id: RequestId,
|
request_id: RequestId,
|
||||||
request_meta: RequestParams.Meta | None,
|
request_meta: RequestParams.Meta | None,
|
||||||
request: ReceiveRequestT,
|
request: ReceiveRequestT,
|
||||||
session: "BaseSession",
|
session: "BaseSession",
|
||||||
|
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
self.request_id = request_id
|
self.request_id = request_id
|
||||||
self.request_meta = request_meta
|
self.request_meta = request_meta
|
||||||
self.request = request
|
self.request = request
|
||||||
self._session = session
|
self._session = session
|
||||||
self._responded = False
|
self._completed = False
|
||||||
|
self._cancel_scope = anyio.CancelScope()
|
||||||
|
self._on_complete = on_complete
|
||||||
|
self._entered = False # Track if we're in a context manager
|
||||||
|
|
||||||
|
def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]":
|
||||||
|
"""Enter the context manager, enabling request cancellation tracking."""
|
||||||
|
self._entered = True
|
||||||
|
self._cancel_scope = anyio.CancelScope()
|
||||||
|
self._cancel_scope.__enter__()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
|
||||||
|
"""Exit the context manager, performing cleanup and notifying completion."""
|
||||||
|
try:
|
||||||
|
if self._completed:
|
||||||
|
self._on_complete(self)
|
||||||
|
finally:
|
||||||
|
self._entered = False
|
||||||
|
if not self._cancel_scope:
|
||||||
|
raise RuntimeError("No active cancel scope")
|
||||||
|
self._cancel_scope.__exit__(exc_type, exc_val, exc_tb)
|
||||||
|
|
||||||
async def respond(self, response: SendResultT | ErrorData) -> None:
|
async def respond(self, response: SendResultT | ErrorData) -> None:
|
||||||
assert not self._responded, "Request already responded to"
|
"""Send a response for this request.
|
||||||
self._responded = True
|
|
||||||
|
Must be called within a context manager block.
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If not used within a context manager
|
||||||
|
AssertionError: If request was already responded to
|
||||||
|
"""
|
||||||
|
if not self._entered:
|
||||||
|
raise RuntimeError("RequestResponder must be used as a context manager")
|
||||||
|
assert not self._completed, "Request already responded to"
|
||||||
|
|
||||||
|
if not self.cancelled:
|
||||||
|
self._completed = True
|
||||||
|
|
||||||
await self._session._send_response(
|
await self._session._send_response(
|
||||||
request_id=self.request_id, response=response
|
request_id=self.request_id, response=response
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def cancel(self) -> None:
|
||||||
|
"""Cancel this request and mark it as completed."""
|
||||||
|
if not self._entered:
|
||||||
|
raise RuntimeError("RequestResponder must be used as a context manager")
|
||||||
|
if not self._cancel_scope:
|
||||||
|
raise RuntimeError("No active cancel scope")
|
||||||
|
|
||||||
|
self._cancel_scope.cancel()
|
||||||
|
self._completed = True # Mark as completed so it's removed from in_flight
|
||||||
|
# Send an error response to indicate cancellation
|
||||||
|
await self._session._send_response(
|
||||||
|
request_id=self.request_id,
|
||||||
|
response=ErrorData(code=0, message="Request cancelled", data=None),
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def in_flight(self) -> bool:
|
||||||
|
return not self._completed and not self.cancelled
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cancelled(self) -> bool:
|
||||||
|
return self._cancel_scope is not None and self._cancel_scope.cancel_called
|
||||||
|
|
||||||
|
|
||||||
class BaseSession(
|
class BaseSession(
|
||||||
AbstractAsyncContextManager,
|
AbstractAsyncContextManager,
|
||||||
@@ -82,6 +155,7 @@ class BaseSession(
|
|||||||
RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]
|
RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]
|
||||||
]
|
]
|
||||||
_request_id: int
|
_request_id: int
|
||||||
|
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -99,6 +173,7 @@ class BaseSession(
|
|||||||
self._receive_request_type = receive_request_type
|
self._receive_request_type = receive_request_type
|
||||||
self._receive_notification_type = receive_notification_type
|
self._receive_notification_type = receive_notification_type
|
||||||
self._read_timeout_seconds = read_timeout_seconds
|
self._read_timeout_seconds = read_timeout_seconds
|
||||||
|
self._in_flight = {}
|
||||||
|
|
||||||
self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
|
self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
|
||||||
anyio.create_memory_object_stream[
|
anyio.create_memory_object_stream[
|
||||||
@@ -219,6 +294,7 @@ class BaseSession(
|
|||||||
by_alias=True, mode="json", exclude_none=True
|
by_alias=True, mode="json", exclude_none=True
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
responder = RequestResponder(
|
responder = RequestResponder(
|
||||||
request_id=message.root.id,
|
request_id=message.root.id,
|
||||||
request_meta=validated_request.root.params.meta
|
request_meta=validated_request.root.params.meta
|
||||||
@@ -226,20 +302,37 @@ class BaseSession(
|
|||||||
else None,
|
else None,
|
||||||
request=validated_request,
|
request=validated_request,
|
||||||
session=self,
|
session=self,
|
||||||
|
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._in_flight[responder.request_id] = responder
|
||||||
await self._received_request(responder)
|
await self._received_request(responder)
|
||||||
if not responder._responded:
|
if not responder._completed:
|
||||||
await self._incoming_message_stream_writer.send(responder)
|
await self._incoming_message_stream_writer.send(responder)
|
||||||
|
|
||||||
elif isinstance(message.root, JSONRPCNotification):
|
elif isinstance(message.root, JSONRPCNotification):
|
||||||
|
try:
|
||||||
notification = self._receive_notification_type.model_validate(
|
notification = self._receive_notification_type.model_validate(
|
||||||
message.root.model_dump(
|
message.root.model_dump(
|
||||||
by_alias=True, mode="json", exclude_none=True
|
by_alias=True, mode="json", exclude_none=True
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
# Handle cancellation notifications
|
||||||
|
if isinstance(notification.root, CancelledNotification):
|
||||||
|
cancelled_id = notification.root.params.requestId
|
||||||
|
if cancelled_id in self._in_flight:
|
||||||
|
await self._in_flight[cancelled_id].cancel()
|
||||||
|
else:
|
||||||
await self._received_notification(notification)
|
await self._received_notification(notification)
|
||||||
await self._incoming_message_stream_writer.send(notification)
|
await self._incoming_message_stream_writer.send(
|
||||||
|
notification
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
# For other validation errors, log and continue
|
||||||
|
logging.warning(
|
||||||
|
f"Failed to validate notification: {e}. "
|
||||||
|
f"Message was: {message.root}"
|
||||||
|
)
|
||||||
else: # Response or error
|
else: # Response or error
|
||||||
stream = self._response_streams.pop(message.root.id, None)
|
stream = self._response_streams.pop(message.root.id, None)
|
||||||
if stream:
|
if stream:
|
||||||
|
|||||||
111
tests/issues/test_88_random_error.py
Normal file
111
tests/issues/test_88_random_error.py
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
"""Test to reproduce issue #88: Random error thrown on response."""
|
||||||
|
|
||||||
|
from datetime import timedelta
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Sequence
|
||||||
|
|
||||||
|
import anyio
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from mcp.client.session import ClientSession
|
||||||
|
from mcp.server.lowlevel import Server
|
||||||
|
from mcp.shared.exceptions import McpError
|
||||||
|
from mcp.types import (
|
||||||
|
EmbeddedResource,
|
||||||
|
ImageContent,
|
||||||
|
TextContent,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_notification_validation_error(tmp_path: Path):
|
||||||
|
"""Test that timeouts are handled gracefully and don't break the server.
|
||||||
|
|
||||||
|
This test verifies that when a client request times out:
|
||||||
|
1. The server task stays alive
|
||||||
|
2. The server can still handle new requests
|
||||||
|
3. The client can make new requests
|
||||||
|
4. No resources are leaked
|
||||||
|
"""
|
||||||
|
|
||||||
|
server = Server(name="test")
|
||||||
|
request_count = 0
|
||||||
|
slow_request_started = anyio.Event()
|
||||||
|
slow_request_complete = anyio.Event()
|
||||||
|
|
||||||
|
@server.call_tool()
|
||||||
|
async def slow_tool(
|
||||||
|
name: str, arg
|
||||||
|
) -> Sequence[TextContent | ImageContent | EmbeddedResource]:
|
||||||
|
nonlocal request_count
|
||||||
|
request_count += 1
|
||||||
|
|
||||||
|
if name == "slow":
|
||||||
|
# Signal that slow request has started
|
||||||
|
slow_request_started.set()
|
||||||
|
# Long enough to ensure timeout
|
||||||
|
await anyio.sleep(0.2)
|
||||||
|
# Signal completion
|
||||||
|
slow_request_complete.set()
|
||||||
|
return [TextContent(type="text", text=f"slow {request_count}")]
|
||||||
|
elif name == "fast":
|
||||||
|
# Fast enough to complete before timeout
|
||||||
|
await anyio.sleep(0.01)
|
||||||
|
return [TextContent(type="text", text=f"fast {request_count}")]
|
||||||
|
return [TextContent(type="text", text=f"unknown {request_count}")]
|
||||||
|
|
||||||
|
async def server_handler(read_stream, write_stream):
|
||||||
|
await server.run(
|
||||||
|
read_stream,
|
||||||
|
write_stream,
|
||||||
|
server.create_initialization_options(),
|
||||||
|
raise_exceptions=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def client(read_stream, write_stream):
|
||||||
|
# Use a timeout that's:
|
||||||
|
# - Long enough for fast operations (>10ms)
|
||||||
|
# - Short enough for slow operations (<200ms)
|
||||||
|
# - Not too short to avoid flakiness
|
||||||
|
async with ClientSession(
|
||||||
|
read_stream, write_stream, read_timeout_seconds=timedelta(milliseconds=50)
|
||||||
|
) as session:
|
||||||
|
await session.initialize()
|
||||||
|
|
||||||
|
# First call should work (fast operation)
|
||||||
|
result = await session.call_tool("fast")
|
||||||
|
assert result.content == [TextContent(type="text", text="fast 1")]
|
||||||
|
assert not slow_request_complete.is_set()
|
||||||
|
|
||||||
|
# Second call should timeout (slow operation)
|
||||||
|
with pytest.raises(McpError) as exc_info:
|
||||||
|
await session.call_tool("slow")
|
||||||
|
assert "Timed out while waiting" in str(exc_info.value)
|
||||||
|
|
||||||
|
# Wait for slow request to complete in the background
|
||||||
|
with anyio.fail_after(1): # Timeout after 1 second
|
||||||
|
await slow_request_complete.wait()
|
||||||
|
|
||||||
|
# Third call should work (fast operation),
|
||||||
|
# proving server is still responsive
|
||||||
|
result = await session.call_tool("fast")
|
||||||
|
assert result.content == [TextContent(type="text", text="fast 3")]
|
||||||
|
|
||||||
|
# Run server and client in separate task groups to avoid cancellation
|
||||||
|
server_writer, server_reader = anyio.create_memory_object_stream(1)
|
||||||
|
client_writer, client_reader = anyio.create_memory_object_stream(1)
|
||||||
|
|
||||||
|
server_ready = anyio.Event()
|
||||||
|
|
||||||
|
async def wrapped_server_handler(read_stream, write_stream):
|
||||||
|
server_ready.set()
|
||||||
|
await server_handler(read_stream, write_stream)
|
||||||
|
|
||||||
|
async with anyio.create_task_group() as tg:
|
||||||
|
tg.start_soon(wrapped_server_handler, server_reader, client_writer)
|
||||||
|
# Wait for server to start and initialize
|
||||||
|
with anyio.fail_after(1): # Timeout after 1 second
|
||||||
|
await server_ready.wait()
|
||||||
|
# Run client in a separate task to avoid cancellation
|
||||||
|
async with anyio.create_task_group() as client_tg:
|
||||||
|
client_tg.start_soon(client, client_reader, server_writer)
|
||||||
126
tests/shared/test_session.py
Normal file
126
tests/shared/test_session.py
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
|
import anyio
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import mcp.types as types
|
||||||
|
from mcp.client.session import ClientSession
|
||||||
|
from mcp.server.lowlevel.server import Server
|
||||||
|
from mcp.shared.exceptions import McpError
|
||||||
|
from mcp.shared.memory import create_connected_server_and_client_session
|
||||||
|
from mcp.types import (
|
||||||
|
CancelledNotification,
|
||||||
|
CancelledNotificationParams,
|
||||||
|
ClientNotification,
|
||||||
|
ClientRequest,
|
||||||
|
EmptyResult,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mcp_server() -> Server:
|
||||||
|
return Server(name="test server")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def client_connected_to_server(
|
||||||
|
mcp_server: Server,
|
||||||
|
) -> AsyncGenerator[ClientSession, None]:
|
||||||
|
async with create_connected_server_and_client_session(mcp_server) as client_session:
|
||||||
|
yield client_session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_in_flight_requests_cleared_after_completion(
|
||||||
|
client_connected_to_server: ClientSession,
|
||||||
|
):
|
||||||
|
"""Verify that _in_flight is empty after all requests complete."""
|
||||||
|
# Send a request and wait for response
|
||||||
|
response = await client_connected_to_server.send_ping()
|
||||||
|
assert isinstance(response, EmptyResult)
|
||||||
|
|
||||||
|
# Verify _in_flight is empty
|
||||||
|
assert len(client_connected_to_server._in_flight) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_request_cancellation():
|
||||||
|
"""Test that requests can be cancelled while in-flight."""
|
||||||
|
# The tool is already registered in the fixture
|
||||||
|
|
||||||
|
ev_tool_called = anyio.Event()
|
||||||
|
ev_cancelled = anyio.Event()
|
||||||
|
request_id = None
|
||||||
|
|
||||||
|
# Start the request in a separate task so we can cancel it
|
||||||
|
def make_server() -> Server:
|
||||||
|
server = Server(name="TestSessionServer")
|
||||||
|
|
||||||
|
# Register the tool handler
|
||||||
|
@server.call_tool()
|
||||||
|
async def handle_call_tool(name: str, arguments: dict | None) -> list:
|
||||||
|
nonlocal request_id, ev_tool_called
|
||||||
|
if name == "slow_tool":
|
||||||
|
request_id = server.request_context.request_id
|
||||||
|
ev_tool_called.set()
|
||||||
|
await anyio.sleep(10) # Long enough to ensure we can cancel
|
||||||
|
return []
|
||||||
|
raise ValueError(f"Unknown tool: {name}")
|
||||||
|
|
||||||
|
# Register the tool so it shows up in list_tools
|
||||||
|
@server.list_tools()
|
||||||
|
async def handle_list_tools() -> list[types.Tool]:
|
||||||
|
return [
|
||||||
|
types.Tool(
|
||||||
|
name="slow_tool",
|
||||||
|
description="A slow tool that takes 10 seconds to complete",
|
||||||
|
inputSchema={},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
return server
|
||||||
|
|
||||||
|
async def make_request(client_session):
|
||||||
|
nonlocal ev_cancelled
|
||||||
|
try:
|
||||||
|
await client_session.send_request(
|
||||||
|
ClientRequest(
|
||||||
|
types.CallToolRequest(
|
||||||
|
method="tools/call",
|
||||||
|
params=types.CallToolRequestParams(
|
||||||
|
name="slow_tool", arguments={}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
types.CallToolResult,
|
||||||
|
)
|
||||||
|
pytest.fail("Request should have been cancelled")
|
||||||
|
except McpError as e:
|
||||||
|
# Expected - request was cancelled
|
||||||
|
assert "Request cancelled" in str(e)
|
||||||
|
ev_cancelled.set()
|
||||||
|
|
||||||
|
async with create_connected_server_and_client_session(
|
||||||
|
make_server()
|
||||||
|
) as client_session:
|
||||||
|
async with anyio.create_task_group() as tg:
|
||||||
|
tg.start_soon(make_request, client_session)
|
||||||
|
|
||||||
|
# Wait for the request to be in-flight
|
||||||
|
with anyio.fail_after(1): # Timeout after 1 second
|
||||||
|
await ev_tool_called.wait()
|
||||||
|
|
||||||
|
# Send cancellation notification
|
||||||
|
assert request_id is not None
|
||||||
|
await client_session.send_notification(
|
||||||
|
ClientNotification(
|
||||||
|
CancelledNotification(
|
||||||
|
method="notifications/cancelled",
|
||||||
|
params=CancelledNotificationParams(requestId=request_id),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Give cancellation time to process
|
||||||
|
with anyio.fail_after(1):
|
||||||
|
await ev_cancelled.wait()
|
||||||
Reference in New Issue
Block a user