Fix uncaught exception in MCP server (#967)

This commit is contained in:
David Dworken
2025-06-16 17:04:51 -07:00
committed by GitHub
parent 1eb1bba83c
commit 7b420656de
2 changed files with 134 additions and 80 deletions

View File

@@ -333,90 +333,107 @@ class BaseSession(
self._read_stream,
self._write_stream,
):
async for message in self._read_stream:
if isinstance(message, Exception):
await self._handle_incoming(message)
elif isinstance(message.message.root, JSONRPCRequest):
try:
validated_request = self._receive_request_type.model_validate(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
)
responder = RequestResponder(
request_id=message.message.root.id,
request_meta=validated_request.root.params.meta if validated_request.root.params else None,
request=validated_request,
session=self,
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
message_metadata=message.metadata,
)
self._in_flight[responder.request_id] = responder
await self._received_request(responder)
try:
async for message in self._read_stream:
if isinstance(message, Exception):
await self._handle_incoming(message)
elif isinstance(message.message.root, JSONRPCRequest):
try:
validated_request = self._receive_request_type.model_validate(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
)
responder = RequestResponder(
request_id=message.message.root.id,
request_meta=validated_request.root.params.meta
if validated_request.root.params
else None,
request=validated_request,
session=self,
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
message_metadata=message.metadata,
)
self._in_flight[responder.request_id] = responder
await self._received_request(responder)
if not responder._completed: # type: ignore[reportPrivateUsage]
await self._handle_incoming(responder)
except Exception as e:
# For request validation errors, send a proper JSON-RPC error
# response instead of crashing the server
logging.warning(f"Failed to validate request: {e}")
logging.debug(f"Message that failed validation: {message.message.root}")
error_response = JSONRPCError(
jsonrpc="2.0",
id=message.message.root.id,
error=ErrorData(
code=INVALID_PARAMS,
message="Invalid request parameters",
data="",
),
)
session_message = SessionMessage(message=JSONRPCMessage(error_response))
await self._write_stream.send(session_message)
if not responder._completed: # type: ignore[reportPrivateUsage]
await self._handle_incoming(responder)
except Exception as e:
# For request validation errors, send a proper JSON-RPC error
# response instead of crashing the server
logging.warning(f"Failed to validate request: {e}")
logging.debug(f"Message that failed validation: {message.message.root}")
error_response = JSONRPCError(
jsonrpc="2.0",
id=message.message.root.id,
error=ErrorData(
code=INVALID_PARAMS,
message="Invalid request parameters",
data="",
),
)
session_message = SessionMessage(message=JSONRPCMessage(error_response))
await self._write_stream.send(session_message)
elif isinstance(message.message.root, JSONRPCNotification):
try:
notification = self._receive_notification_type.model_validate(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
)
# Handle cancellation notifications
if isinstance(notification.root, CancelledNotification):
cancelled_id = notification.root.params.requestId
if cancelled_id in self._in_flight:
await self._in_flight[cancelled_id].cancel()
elif isinstance(message.message.root, JSONRPCNotification):
try:
notification = self._receive_notification_type.model_validate(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
)
# Handle cancellation notifications
if isinstance(notification.root, CancelledNotification):
cancelled_id = notification.root.params.requestId
if cancelled_id in self._in_flight:
await self._in_flight[cancelled_id].cancel()
else:
# Handle progress notifications callback
if isinstance(notification.root, ProgressNotification):
progress_token = notification.root.params.progressToken
# If there is a progress callback for this token,
# call it with the progress information
if progress_token in self._progress_callbacks:
callback = self._progress_callbacks[progress_token]
await callback(
notification.root.params.progress,
notification.root.params.total,
notification.root.params.message,
)
await self._received_notification(notification)
await self._handle_incoming(notification)
except Exception as e:
# For other validation errors, log and continue
logging.warning(
f"Failed to validate notification: {e}. " f"Message was: {message.message.root}"
)
else: # Response or error
stream = self._response_streams.pop(message.message.root.id, None)
if stream:
await stream.send(message.message.root)
else:
# Handle progress notifications callback
if isinstance(notification.root, ProgressNotification):
progress_token = notification.root.params.progressToken
# If there is a progress callback for this token,
# call it with the progress information
if progress_token in self._progress_callbacks:
callback = self._progress_callbacks[progress_token]
await callback(
notification.root.params.progress,
notification.root.params.total,
notification.root.params.message,
)
await self._received_notification(notification)
await self._handle_incoming(notification)
except Exception as e:
# For other validation errors, log and continue
logging.warning(
f"Failed to validate notification: {e}. " f"Message was: {message.message.root}"
)
else: # Response or error
stream = self._response_streams.pop(message.message.root.id, None)
if stream:
await stream.send(message.message.root)
else:
await self._handle_incoming(
RuntimeError("Received response with an unknown " f"request ID: {message}")
)
await self._handle_incoming(
RuntimeError("Received response with an unknown " f"request ID: {message}")
)
# after the read stream is closed, we need to send errors
# to any pending requests
for id, stream in self._response_streams.items():
error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed")
await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error))
await stream.aclose()
self._response_streams.clear()
except anyio.ClosedResourceError:
# This is expected when the client disconnects abruptly.
# Without this handler, the exception would propagate up and
# crash the server's task group.
logging.debug("Read stream closed by client")
except Exception as e:
# Other exceptions are not expected and should be logged. We purposefully
# catch all exceptions here to avoid crashing the server.
logging.exception(f"Unhandled exception in receive loop: {e}")
finally:
# after the read stream is closed, we need to send errors
# to any pending requests
for id, stream in self._response_streams.items():
error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed")
try:
await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error))
await stream.aclose()
except Exception:
# Stream might already be closed
pass
self._response_streams.clear()
async def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None:
"""

View File

@@ -1521,3 +1521,40 @@ def test_server_backwards_compatibility_no_protocol_version(basic_server, basic_
)
assert response.status_code == 200 # Should succeed for backwards compatibility
assert response.headers.get("Content-Type") == "text/event-stream"
@pytest.mark.anyio
async def test_client_crash_handled(basic_server, basic_server_url):
"""Test that cases where the client crashes are handled gracefully."""
# Simulate bad client that crashes after init
async def bad_client():
"""Client that triggers ClosedResourceError"""
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
read_stream,
write_stream,
_,
):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
raise Exception("client crash")
# Run bad client a few times to trigger the crash
for _ in range(3):
try:
await bad_client()
except Exception:
pass
await anyio.sleep(0.1)
# Try a good client, it should still be able to connect and list tools
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
read_stream,
write_stream,
_,
):
async with ClientSession(read_stream, write_stream) as session:
result = await session.initialize()
assert isinstance(result, InitializeResult)
tools = await session.list_tools()
assert tools.tools