mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-18 14:34:27 +01:00
Fix uncaught exception in MCP server (#967)
This commit is contained in:
@@ -333,90 +333,107 @@ class BaseSession(
|
||||
self._read_stream,
|
||||
self._write_stream,
|
||||
):
|
||||
async for message in self._read_stream:
|
||||
if isinstance(message, Exception):
|
||||
await self._handle_incoming(message)
|
||||
elif isinstance(message.message.root, JSONRPCRequest):
|
||||
try:
|
||||
validated_request = self._receive_request_type.model_validate(
|
||||
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
)
|
||||
responder = RequestResponder(
|
||||
request_id=message.message.root.id,
|
||||
request_meta=validated_request.root.params.meta if validated_request.root.params else None,
|
||||
request=validated_request,
|
||||
session=self,
|
||||
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
|
||||
message_metadata=message.metadata,
|
||||
)
|
||||
self._in_flight[responder.request_id] = responder
|
||||
await self._received_request(responder)
|
||||
try:
|
||||
async for message in self._read_stream:
|
||||
if isinstance(message, Exception):
|
||||
await self._handle_incoming(message)
|
||||
elif isinstance(message.message.root, JSONRPCRequest):
|
||||
try:
|
||||
validated_request = self._receive_request_type.model_validate(
|
||||
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
)
|
||||
responder = RequestResponder(
|
||||
request_id=message.message.root.id,
|
||||
request_meta=validated_request.root.params.meta
|
||||
if validated_request.root.params
|
||||
else None,
|
||||
request=validated_request,
|
||||
session=self,
|
||||
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
|
||||
message_metadata=message.metadata,
|
||||
)
|
||||
self._in_flight[responder.request_id] = responder
|
||||
await self._received_request(responder)
|
||||
|
||||
if not responder._completed: # type: ignore[reportPrivateUsage]
|
||||
await self._handle_incoming(responder)
|
||||
except Exception as e:
|
||||
# For request validation errors, send a proper JSON-RPC error
|
||||
# response instead of crashing the server
|
||||
logging.warning(f"Failed to validate request: {e}")
|
||||
logging.debug(f"Message that failed validation: {message.message.root}")
|
||||
error_response = JSONRPCError(
|
||||
jsonrpc="2.0",
|
||||
id=message.message.root.id,
|
||||
error=ErrorData(
|
||||
code=INVALID_PARAMS,
|
||||
message="Invalid request parameters",
|
||||
data="",
|
||||
),
|
||||
)
|
||||
session_message = SessionMessage(message=JSONRPCMessage(error_response))
|
||||
await self._write_stream.send(session_message)
|
||||
if not responder._completed: # type: ignore[reportPrivateUsage]
|
||||
await self._handle_incoming(responder)
|
||||
except Exception as e:
|
||||
# For request validation errors, send a proper JSON-RPC error
|
||||
# response instead of crashing the server
|
||||
logging.warning(f"Failed to validate request: {e}")
|
||||
logging.debug(f"Message that failed validation: {message.message.root}")
|
||||
error_response = JSONRPCError(
|
||||
jsonrpc="2.0",
|
||||
id=message.message.root.id,
|
||||
error=ErrorData(
|
||||
code=INVALID_PARAMS,
|
||||
message="Invalid request parameters",
|
||||
data="",
|
||||
),
|
||||
)
|
||||
session_message = SessionMessage(message=JSONRPCMessage(error_response))
|
||||
await self._write_stream.send(session_message)
|
||||
|
||||
elif isinstance(message.message.root, JSONRPCNotification):
|
||||
try:
|
||||
notification = self._receive_notification_type.model_validate(
|
||||
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
)
|
||||
# Handle cancellation notifications
|
||||
if isinstance(notification.root, CancelledNotification):
|
||||
cancelled_id = notification.root.params.requestId
|
||||
if cancelled_id in self._in_flight:
|
||||
await self._in_flight[cancelled_id].cancel()
|
||||
elif isinstance(message.message.root, JSONRPCNotification):
|
||||
try:
|
||||
notification = self._receive_notification_type.model_validate(
|
||||
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
)
|
||||
# Handle cancellation notifications
|
||||
if isinstance(notification.root, CancelledNotification):
|
||||
cancelled_id = notification.root.params.requestId
|
||||
if cancelled_id in self._in_flight:
|
||||
await self._in_flight[cancelled_id].cancel()
|
||||
else:
|
||||
# Handle progress notifications callback
|
||||
if isinstance(notification.root, ProgressNotification):
|
||||
progress_token = notification.root.params.progressToken
|
||||
# If there is a progress callback for this token,
|
||||
# call it with the progress information
|
||||
if progress_token in self._progress_callbacks:
|
||||
callback = self._progress_callbacks[progress_token]
|
||||
await callback(
|
||||
notification.root.params.progress,
|
||||
notification.root.params.total,
|
||||
notification.root.params.message,
|
||||
)
|
||||
await self._received_notification(notification)
|
||||
await self._handle_incoming(notification)
|
||||
except Exception as e:
|
||||
# For other validation errors, log and continue
|
||||
logging.warning(
|
||||
f"Failed to validate notification: {e}. " f"Message was: {message.message.root}"
|
||||
)
|
||||
else: # Response or error
|
||||
stream = self._response_streams.pop(message.message.root.id, None)
|
||||
if stream:
|
||||
await stream.send(message.message.root)
|
||||
else:
|
||||
# Handle progress notifications callback
|
||||
if isinstance(notification.root, ProgressNotification):
|
||||
progress_token = notification.root.params.progressToken
|
||||
# If there is a progress callback for this token,
|
||||
# call it with the progress information
|
||||
if progress_token in self._progress_callbacks:
|
||||
callback = self._progress_callbacks[progress_token]
|
||||
await callback(
|
||||
notification.root.params.progress,
|
||||
notification.root.params.total,
|
||||
notification.root.params.message,
|
||||
)
|
||||
await self._received_notification(notification)
|
||||
await self._handle_incoming(notification)
|
||||
except Exception as e:
|
||||
# For other validation errors, log and continue
|
||||
logging.warning(
|
||||
f"Failed to validate notification: {e}. " f"Message was: {message.message.root}"
|
||||
)
|
||||
else: # Response or error
|
||||
stream = self._response_streams.pop(message.message.root.id, None)
|
||||
if stream:
|
||||
await stream.send(message.message.root)
|
||||
else:
|
||||
await self._handle_incoming(
|
||||
RuntimeError("Received response with an unknown " f"request ID: {message}")
|
||||
)
|
||||
await self._handle_incoming(
|
||||
RuntimeError("Received response with an unknown " f"request ID: {message}")
|
||||
)
|
||||
|
||||
# after the read stream is closed, we need to send errors
|
||||
# to any pending requests
|
||||
for id, stream in self._response_streams.items():
|
||||
error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed")
|
||||
await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error))
|
||||
await stream.aclose()
|
||||
self._response_streams.clear()
|
||||
except anyio.ClosedResourceError:
|
||||
# This is expected when the client disconnects abruptly.
|
||||
# Without this handler, the exception would propagate up and
|
||||
# crash the server's task group.
|
||||
logging.debug("Read stream closed by client")
|
||||
except Exception as e:
|
||||
# Other exceptions are not expected and should be logged. We purposefully
|
||||
# catch all exceptions here to avoid crashing the server.
|
||||
logging.exception(f"Unhandled exception in receive loop: {e}")
|
||||
finally:
|
||||
# after the read stream is closed, we need to send errors
|
||||
# to any pending requests
|
||||
for id, stream in self._response_streams.items():
|
||||
error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed")
|
||||
try:
|
||||
await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error))
|
||||
await stream.aclose()
|
||||
except Exception:
|
||||
# Stream might already be closed
|
||||
pass
|
||||
self._response_streams.clear()
|
||||
|
||||
async def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None:
|
||||
"""
|
||||
|
||||
@@ -1521,3 +1521,40 @@ def test_server_backwards_compatibility_no_protocol_version(basic_server, basic_
|
||||
)
|
||||
assert response.status_code == 200 # Should succeed for backwards compatibility
|
||||
assert response.headers.get("Content-Type") == "text/event-stream"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_client_crash_handled(basic_server, basic_server_url):
|
||||
"""Test that cases where the client crashes are handled gracefully."""
|
||||
|
||||
# Simulate bad client that crashes after init
|
||||
async def bad_client():
|
||||
"""Client that triggers ClosedResourceError"""
|
||||
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
|
||||
read_stream,
|
||||
write_stream,
|
||||
_,
|
||||
):
|
||||
async with ClientSession(read_stream, write_stream) as session:
|
||||
await session.initialize()
|
||||
raise Exception("client crash")
|
||||
|
||||
# Run bad client a few times to trigger the crash
|
||||
for _ in range(3):
|
||||
try:
|
||||
await bad_client()
|
||||
except Exception:
|
||||
pass
|
||||
await anyio.sleep(0.1)
|
||||
|
||||
# Try a good client, it should still be able to connect and list tools
|
||||
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
|
||||
read_stream,
|
||||
write_stream,
|
||||
_,
|
||||
):
|
||||
async with ClientSession(read_stream, write_stream) as session:
|
||||
result = await session.initialize()
|
||||
assert isinstance(result, InitializeResult)
|
||||
tools = await session.list_tools()
|
||||
assert tools.tools
|
||||
|
||||
Reference in New Issue
Block a user