fix: enforce context manager usage for RequestResponder

This commit is contained in:
David Soria Parra
2025-02-04 19:29:12 +00:00
parent 08cfbe522a
commit 733db0c9cf
2 changed files with 49 additions and 11 deletions

View File

@@ -126,6 +126,7 @@ class ServerSession(
case types.InitializeRequest(params=params): case types.InitializeRequest(params=params):
self._initialization_state = InitializationState.Initializing self._initialization_state = InitializationState.Initializing
self._client_params = params self._client_params = params
with responder:
await responder.respond( await responder.respond(
types.ServerResult( types.ServerResult(
types.InitializeResult( types.InitializeResult(

View File

@@ -40,6 +40,21 @@ RequestId = str | int
class RequestResponder(Generic[ReceiveRequestT, SendResultT]): class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
"""Handles responding to MCP requests and manages request lifecycle.
This class MUST be used as a context manager to ensure proper cleanup and
cancellation handling:
Example:
with request_responder as resp:
await resp.respond(result)
The context manager ensures:
1. Proper cancellation scope setup and cleanup
2. Request completion tracking
3. Cleanup of in-flight requests
"""
def __init__( def __init__(
self, self,
request_id: RequestId, request_id: RequestId,
@@ -55,19 +70,36 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
self._completed = False self._completed = False
self._cancel_scope = anyio.CancelScope() self._cancel_scope = anyio.CancelScope()
self._on_complete = on_complete self._on_complete = on_complete
self._entered = False # Track if we're in a context manager
def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]": def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]":
"""Enter the context manager, enabling request cancellation tracking."""
self._entered = True
self._cancel_scope = anyio.CancelScope()
self._cancel_scope.__enter__() self._cancel_scope.__enter__()
return self return self
def __exit__(self, exc_type, exc_val, exc_tb) -> None: def __exit__(self, exc_type, exc_val, exc_tb) -> None:
"""Exit the context manager, performing cleanup and notifying completion."""
try: try:
if self._completed: if self._completed:
self._on_complete(self) self._on_complete(self)
finally: finally:
self._entered = False
if not self._cancel_scope:
raise RuntimeError("No active cancel scope")
self._cancel_scope.__exit__(exc_type, exc_val, exc_tb) self._cancel_scope.__exit__(exc_type, exc_val, exc_tb)
async def respond(self, response: SendResultT | ErrorData) -> None: async def respond(self, response: SendResultT | ErrorData) -> None:
"""Send a response for this request.
Must be called within a context manager block.
Raises:
RuntimeError: If not used within a context manager
AssertionError: If request was already responded to
"""
if not self._entered:
raise RuntimeError("RequestResponder must be used as a context manager")
assert not self._completed, "Request already responded to" assert not self._completed, "Request already responded to"
if not self.cancelled: if not self.cancelled:
@@ -79,6 +111,11 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
async def cancel(self) -> None: async def cancel(self) -> None:
"""Cancel this request and mark it as completed.""" """Cancel this request and mark it as completed."""
if not self._entered:
raise RuntimeError("RequestResponder must be used as a context manager")
if not self._cancel_scope:
raise RuntimeError("No active cancel scope")
self._cancel_scope.cancel() self._cancel_scope.cancel()
self._completed = True # Mark as completed so it's removed from in_flight self._completed = True # Mark as completed so it's removed from in_flight
# Send an error response to indicate cancellation # Send an error response to indicate cancellation