mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
fix: enforce context manager usage for RequestResponder
This commit is contained in:
@@ -126,19 +126,20 @@ class ServerSession(
|
|||||||
case types.InitializeRequest(params=params):
|
case types.InitializeRequest(params=params):
|
||||||
self._initialization_state = InitializationState.Initializing
|
self._initialization_state = InitializationState.Initializing
|
||||||
self._client_params = params
|
self._client_params = params
|
||||||
await responder.respond(
|
with responder:
|
||||||
types.ServerResult(
|
await responder.respond(
|
||||||
types.InitializeResult(
|
types.ServerResult(
|
||||||
protocolVersion=types.LATEST_PROTOCOL_VERSION,
|
types.InitializeResult(
|
||||||
capabilities=self._init_options.capabilities,
|
protocolVersion=types.LATEST_PROTOCOL_VERSION,
|
||||||
serverInfo=types.Implementation(
|
capabilities=self._init_options.capabilities,
|
||||||
name=self._init_options.server_name,
|
serverInfo=types.Implementation(
|
||||||
version=self._init_options.server_version,
|
name=self._init_options.server_name,
|
||||||
),
|
version=self._init_options.server_version,
|
||||||
instructions=self._init_options.instructions,
|
),
|
||||||
|
instructions=self._init_options.instructions,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
case _:
|
case _:
|
||||||
if self._initialization_state != InitializationState.Initialized:
|
if self._initialization_state != InitializationState.Initialized:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
|
|||||||
@@ -40,6 +40,21 @@ RequestId = str | int
|
|||||||
|
|
||||||
|
|
||||||
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
||||||
|
"""Handles responding to MCP requests and manages request lifecycle.
|
||||||
|
|
||||||
|
This class MUST be used as a context manager to ensure proper cleanup and
|
||||||
|
cancellation handling:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
with request_responder as resp:
|
||||||
|
await resp.respond(result)
|
||||||
|
|
||||||
|
The context manager ensures:
|
||||||
|
1. Proper cancellation scope setup and cleanup
|
||||||
|
2. Request completion tracking
|
||||||
|
3. Cleanup of in-flight requests
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
request_id: RequestId,
|
request_id: RequestId,
|
||||||
@@ -55,19 +70,36 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
|||||||
self._completed = False
|
self._completed = False
|
||||||
self._cancel_scope = anyio.CancelScope()
|
self._cancel_scope = anyio.CancelScope()
|
||||||
self._on_complete = on_complete
|
self._on_complete = on_complete
|
||||||
|
self._entered = False # Track if we're in a context manager
|
||||||
|
|
||||||
def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]":
|
def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]":
|
||||||
|
"""Enter the context manager, enabling request cancellation tracking."""
|
||||||
|
self._entered = True
|
||||||
|
self._cancel_scope = anyio.CancelScope()
|
||||||
self._cancel_scope.__enter__()
|
self._cancel_scope.__enter__()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
|
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
|
||||||
|
"""Exit the context manager, performing cleanup and notifying completion."""
|
||||||
try:
|
try:
|
||||||
if self._completed:
|
if self._completed:
|
||||||
self._on_complete(self)
|
self._on_complete(self)
|
||||||
finally:
|
finally:
|
||||||
|
self._entered = False
|
||||||
|
if not self._cancel_scope:
|
||||||
|
raise RuntimeError("No active cancel scope")
|
||||||
self._cancel_scope.__exit__(exc_type, exc_val, exc_tb)
|
self._cancel_scope.__exit__(exc_type, exc_val, exc_tb)
|
||||||
|
|
||||||
async def respond(self, response: SendResultT | ErrorData) -> None:
|
async def respond(self, response: SendResultT | ErrorData) -> None:
|
||||||
|
"""Send a response for this request.
|
||||||
|
|
||||||
|
Must be called within a context manager block.
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If not used within a context manager
|
||||||
|
AssertionError: If request was already responded to
|
||||||
|
"""
|
||||||
|
if not self._entered:
|
||||||
|
raise RuntimeError("RequestResponder must be used as a context manager")
|
||||||
assert not self._completed, "Request already responded to"
|
assert not self._completed, "Request already responded to"
|
||||||
|
|
||||||
if not self.cancelled:
|
if not self.cancelled:
|
||||||
@@ -79,6 +111,11 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
|||||||
|
|
||||||
async def cancel(self) -> None:
|
async def cancel(self) -> None:
|
||||||
"""Cancel this request and mark it as completed."""
|
"""Cancel this request and mark it as completed."""
|
||||||
|
if not self._entered:
|
||||||
|
raise RuntimeError("RequestResponder must be used as a context manager")
|
||||||
|
if not self._cancel_scope:
|
||||||
|
raise RuntimeError("No active cancel scope")
|
||||||
|
|
||||||
self._cancel_scope.cancel()
|
self._cancel_scope.cancel()
|
||||||
self._completed = True # Mark as completed so it's removed from in_flight
|
self._completed = True # Mark as completed so it's removed from in_flight
|
||||||
# Send an error response to indicate cancellation
|
# Send an error response to indicate cancellation
|
||||||
|
|||||||
Reference in New Issue
Block a user