mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
feat: add request cancellation and in-flight request tracking
This commit adds support for request cancellation and tracking of in-flight requests in the MCP protocol implementation. The key architectural changes are: 1. Request Lifecycle Management: - Added _in_flight dictionary to BaseSession to track active requests - Requests are tracked from receipt until completion/cancellation - Added proper cleanup via on_complete callback 2. Cancellation Support: - Added CancelledNotification handling in _receive_loop - Implemented cancel() method in RequestResponder - Uses anyio.CancelScope for robust cancellation - Sends error response on cancellation 3. Request Context: - Added request_ctx ContextVar for request context - Ensures proper cleanup after request handling - Maintains request state throughout lifecycle 4. Error Handling: - Improved error propagation for cancelled requests - Added proper cleanup of cancelled requests - Maintains consistency of in-flight tracking This change enables clients to cancel long-running requests and servers to properly clean up resources when requests are cancelled. Github-Issue:#88
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from contextlib import AbstractAsyncContextManager
|
||||
from datetime import timedelta
|
||||
from typing import Generic, TypeVar
|
||||
from typing import Any, Callable, Generic, TypeVar
|
||||
|
||||
import anyio
|
||||
import anyio.lowlevel
|
||||
@@ -10,6 +10,7 @@ from pydantic import BaseModel
|
||||
|
||||
from mcp.shared.exceptions import McpError
|
||||
from mcp.types import (
|
||||
CancelledNotification,
|
||||
ClientNotification,
|
||||
ClientRequest,
|
||||
ClientResult,
|
||||
@@ -44,21 +45,55 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
||||
request_meta: RequestParams.Meta | None,
|
||||
request: ReceiveRequestT,
|
||||
session: "BaseSession",
|
||||
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
|
||||
) -> None:
|
||||
self.request_id = request_id
|
||||
self.request_meta = request_meta
|
||||
self.request = request
|
||||
self._session = session
|
||||
self._responded = False
|
||||
self._completed = False
|
||||
self._cancel_scope = anyio.CancelScope()
|
||||
self._on_complete = on_complete
|
||||
|
||||
def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]":
|
||||
self._cancel_scope.__enter__()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
|
||||
try:
|
||||
if self._completed:
|
||||
self._on_complete(self)
|
||||
finally:
|
||||
self._cancel_scope.__exit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
async def respond(self, response: SendResultT | ErrorData) -> None:
|
||||
assert not self._responded, "Request already responded to"
|
||||
self._responded = True
|
||||
assert not self._completed, "Request already responded to"
|
||||
|
||||
if not self.cancelled:
|
||||
self._completed = True
|
||||
|
||||
await self._session._send_response(
|
||||
request_id=self.request_id, response=response
|
||||
)
|
||||
|
||||
async def cancel(self) -> None:
|
||||
"""Cancel this request and mark it as completed."""
|
||||
self._cancel_scope.cancel()
|
||||
self._completed = True # Mark as completed so it's removed from in_flight
|
||||
# Send an error response to indicate cancellation
|
||||
await self._session._send_response(
|
||||
request_id=self.request_id, response=response
|
||||
request_id=self.request_id,
|
||||
response=ErrorData(code=0, message="Request cancelled", data=None),
|
||||
)
|
||||
|
||||
@property
|
||||
def in_flight(self) -> bool:
|
||||
return not self._completed and not self.cancelled
|
||||
|
||||
@property
|
||||
def cancelled(self) -> bool:
|
||||
return self._cancel_scope is not None and self._cancel_scope.cancel_called
|
||||
|
||||
|
||||
class BaseSession(
|
||||
AbstractAsyncContextManager,
|
||||
@@ -82,6 +117,7 @@ class BaseSession(
|
||||
RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]
|
||||
]
|
||||
_request_id: int
|
||||
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -99,6 +135,7 @@ class BaseSession(
|
||||
self._receive_request_type = receive_request_type
|
||||
self._receive_notification_type = receive_notification_type
|
||||
self._read_timeout_seconds = read_timeout_seconds
|
||||
self._in_flight = {}
|
||||
|
||||
self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
|
||||
anyio.create_memory_object_stream[
|
||||
@@ -219,6 +256,7 @@ class BaseSession(
|
||||
by_alias=True, mode="json", exclude_none=True
|
||||
)
|
||||
)
|
||||
|
||||
responder = RequestResponder(
|
||||
request_id=message.root.id,
|
||||
request_meta=validated_request.root.params.meta
|
||||
@@ -226,20 +264,28 @@ class BaseSession(
|
||||
else None,
|
||||
request=validated_request,
|
||||
session=self,
|
||||
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
|
||||
)
|
||||
|
||||
self._in_flight[responder.request_id] = responder
|
||||
await self._received_request(responder)
|
||||
if not responder._responded:
|
||||
if not responder._completed:
|
||||
await self._incoming_message_stream_writer.send(responder)
|
||||
|
||||
elif isinstance(message.root, JSONRPCNotification):
|
||||
notification = self._receive_notification_type.model_validate(
|
||||
message.root.model_dump(
|
||||
by_alias=True, mode="json", exclude_none=True
|
||||
)
|
||||
)
|
||||
|
||||
await self._received_notification(notification)
|
||||
await self._incoming_message_stream_writer.send(notification)
|
||||
# Handle cancellation notifications
|
||||
if isinstance(notification.root, CancelledNotification):
|
||||
cancelled_id = notification.root.params.requestId
|
||||
if cancelled_id in self._in_flight:
|
||||
await self._in_flight[cancelled_id].cancel()
|
||||
else:
|
||||
await self._received_notification(notification)
|
||||
await self._incoming_message_stream_writer.send(notification)
|
||||
else: # Response or error
|
||||
stream = self._response_streams.pop(message.root.id, None)
|
||||
if stream:
|
||||
|
||||
Reference in New Issue
Block a user