Merge pull request #167 from modelcontextprotocol/davidsp/88-v2

feat: add request cancellation and cleanup
This commit is contained in:
David Soria Parra
2025-02-04 20:23:05 +00:00
committed by GitHub
5 changed files with 365 additions and 29 deletions

View File

@@ -453,10 +453,15 @@ class Server:
logger.debug(f"Received message: {message}")
match message:
case RequestResponder(request=types.ClientRequest(root=req)):
await self._handle_request(
message, req, session, raise_exceptions
)
case (
RequestResponder(
request=types.ClientRequest(root=req)
) as responder
):
with responder:
await self._handle_request(
message, req, session, raise_exceptions
)
case types.ClientNotification(root=notify):
await self._handle_notification(notify)

View File

@@ -126,19 +126,20 @@ class ServerSession(
case types.InitializeRequest(params=params):
self._initialization_state = InitializationState.Initializing
self._client_params = params
await responder.respond(
types.ServerResult(
types.InitializeResult(
protocolVersion=types.LATEST_PROTOCOL_VERSION,
capabilities=self._init_options.capabilities,
serverInfo=types.Implementation(
name=self._init_options.server_name,
version=self._init_options.server_version,
),
instructions=self._init_options.instructions,
with responder:
await responder.respond(
types.ServerResult(
types.InitializeResult(
protocolVersion=types.LATEST_PROTOCOL_VERSION,
capabilities=self._init_options.capabilities,
serverInfo=types.Implementation(
name=self._init_options.server_name,
version=self._init_options.server_version,
),
instructions=self._init_options.instructions,
)
)
)
)
case _:
if self._initialization_state != InitializationState.Initialized:
raise RuntimeError(

View File

@@ -1,6 +1,7 @@
import logging
from contextlib import AbstractAsyncContextManager
from datetime import timedelta
from typing import Generic, TypeVar
from typing import Any, Callable, Generic, TypeVar
import anyio
import anyio.lowlevel
@@ -10,6 +11,7 @@ from pydantic import BaseModel
from mcp.shared.exceptions import McpError
from mcp.types import (
CancelledNotification,
ClientNotification,
ClientRequest,
ClientResult,
@@ -38,27 +40,98 @@ RequestId = str | int
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
"""Handles responding to MCP requests and manages request lifecycle.
This class MUST be used as a context manager to ensure proper cleanup and
cancellation handling:
Example:
with request_responder as resp:
await resp.respond(result)
The context manager ensures:
1. Proper cancellation scope setup and cleanup
2. Request completion tracking
3. Cleanup of in-flight requests
"""
def __init__(
self,
request_id: RequestId,
request_meta: RequestParams.Meta | None,
request: ReceiveRequestT,
session: "BaseSession",
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
) -> None:
self.request_id = request_id
self.request_meta = request_meta
self.request = request
self._session = session
self._responded = False
self._completed = False
self._cancel_scope = anyio.CancelScope()
self._on_complete = on_complete
self._entered = False # Track if we're in a context manager
def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]":
"""Enter the context manager, enabling request cancellation tracking."""
self._entered = True
self._cancel_scope = anyio.CancelScope()
self._cancel_scope.__enter__()
return self
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
"""Exit the context manager, performing cleanup and notifying completion."""
try:
if self._completed:
self._on_complete(self)
finally:
self._entered = False
if not self._cancel_scope:
raise RuntimeError("No active cancel scope")
self._cancel_scope.__exit__(exc_type, exc_val, exc_tb)
async def respond(self, response: SendResultT | ErrorData) -> None:
assert not self._responded, "Request already responded to"
self._responded = True
"""Send a response for this request.
Must be called within a context manager block.
Raises:
RuntimeError: If not used within a context manager
AssertionError: If request was already responded to
"""
if not self._entered:
raise RuntimeError("RequestResponder must be used as a context manager")
assert not self._completed, "Request already responded to"
if not self.cancelled:
self._completed = True
await self._session._send_response(
request_id=self.request_id, response=response
)
async def cancel(self) -> None:
"""Cancel this request and mark it as completed."""
if not self._entered:
raise RuntimeError("RequestResponder must be used as a context manager")
if not self._cancel_scope:
raise RuntimeError("No active cancel scope")
self._cancel_scope.cancel()
self._completed = True # Mark as completed so it's removed from in_flight
# Send an error response to indicate cancellation
await self._session._send_response(
request_id=self.request_id, response=response
request_id=self.request_id,
response=ErrorData(code=0, message="Request cancelled", data=None),
)
@property
def in_flight(self) -> bool:
return not self._completed and not self.cancelled
@property
def cancelled(self) -> bool:
return self._cancel_scope is not None and self._cancel_scope.cancel_called
class BaseSession(
AbstractAsyncContextManager,
@@ -82,6 +155,7 @@ class BaseSession(
RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]
]
_request_id: int
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
def __init__(
self,
@@ -99,6 +173,7 @@ class BaseSession(
self._receive_request_type = receive_request_type
self._receive_notification_type = receive_notification_type
self._read_timeout_seconds = read_timeout_seconds
self._in_flight = {}
self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
anyio.create_memory_object_stream[
@@ -219,6 +294,7 @@ class BaseSession(
by_alias=True, mode="json", exclude_none=True
)
)
responder = RequestResponder(
request_id=message.root.id,
request_meta=validated_request.root.params.meta
@@ -226,20 +302,37 @@ class BaseSession(
else None,
request=validated_request,
session=self,
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
)
self._in_flight[responder.request_id] = responder
await self._received_request(responder)
if not responder._responded:
if not responder._completed:
await self._incoming_message_stream_writer.send(responder)
elif isinstance(message.root, JSONRPCNotification):
notification = self._receive_notification_type.model_validate(
message.root.model_dump(
by_alias=True, mode="json", exclude_none=True
)
)
await self._received_notification(notification)
await self._incoming_message_stream_writer.send(notification)
elif isinstance(message.root, JSONRPCNotification):
try:
notification = self._receive_notification_type.model_validate(
message.root.model_dump(
by_alias=True, mode="json", exclude_none=True
)
)
# Handle cancellation notifications
if isinstance(notification.root, CancelledNotification):
cancelled_id = notification.root.params.requestId
if cancelled_id in self._in_flight:
await self._in_flight[cancelled_id].cancel()
else:
await self._received_notification(notification)
await self._incoming_message_stream_writer.send(
notification
)
except Exception as e:
# For other validation errors, log and continue
logging.warning(
f"Failed to validate notification: {e}. "
f"Message was: {message.root}"
)
else: # Response or error
stream = self._response_streams.pop(message.root.id, None)
if stream: