diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index b71b372..d918b98 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -126,19 +126,20 @@ class ServerSession( case types.InitializeRequest(params=params): self._initialization_state = InitializationState.Initializing self._client_params = params - await responder.respond( - types.ServerResult( - types.InitializeResult( - protocolVersion=types.LATEST_PROTOCOL_VERSION, - capabilities=self._init_options.capabilities, - serverInfo=types.Implementation( - name=self._init_options.server_name, - version=self._init_options.server_version, - ), - instructions=self._init_options.instructions, + with responder: + await responder.respond( + types.ServerResult( + types.InitializeResult( + protocolVersion=types.LATEST_PROTOCOL_VERSION, + capabilities=self._init_options.capabilities, + serverInfo=types.Implementation( + name=self._init_options.server_name, + version=self._init_options.server_version, + ), + instructions=self._init_options.instructions, + ) ) ) - ) case _: if self._initialization_state != InitializationState.Initialized: raise RuntimeError( diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index e21bcbc..3d3988c 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -40,6 +40,21 @@ RequestId = str | int class RequestResponder(Generic[ReceiveRequestT, SendResultT]): + """Handles responding to MCP requests and manages request lifecycle. + + This class MUST be used as a context manager to ensure proper cleanup and + cancellation handling: + + Example: + with request_responder as resp: + await resp.respond(result) + + The context manager ensures: + 1. Proper cancellation scope setup and cleanup + 2. Request completion tracking + 3. Cleanup of in-flight requests + """ + def __init__( self, request_id: RequestId, @@ -55,19 +70,36 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): self._completed = False self._cancel_scope = anyio.CancelScope() self._on_complete = on_complete + self._entered = False # Track if we're in a context manager def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]": + """Enter the context manager, enabling request cancellation tracking.""" + self._entered = True + self._cancel_scope = anyio.CancelScope() self._cancel_scope.__enter__() return self def __exit__(self, exc_type, exc_val, exc_tb) -> None: + """Exit the context manager, performing cleanup and notifying completion.""" try: if self._completed: self._on_complete(self) finally: + self._entered = False + if not self._cancel_scope: + raise RuntimeError("No active cancel scope") self._cancel_scope.__exit__(exc_type, exc_val, exc_tb) async def respond(self, response: SendResultT | ErrorData) -> None: + """Send a response for this request. + + Must be called within a context manager block. + Raises: + RuntimeError: If not used within a context manager + AssertionError: If request was already responded to + """ + if not self._entered: + raise RuntimeError("RequestResponder must be used as a context manager") assert not self._completed, "Request already responded to" if not self.cancelled: @@ -79,6 +111,11 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): async def cancel(self) -> None: """Cancel this request and mark it as completed.""" + if not self._entered: + raise RuntimeError("RequestResponder must be used as a context manager") + if not self._cancel_scope: + raise RuntimeError("No active cancel scope") + self._cancel_scope.cancel() self._completed = True # Mark as completed so it's removed from in_flight # Send an error response to indicate cancellation