mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-18 22:44:20 +01:00
Fix uncaught exception in MCP server (#967)
This commit is contained in:
@@ -333,90 +333,107 @@ class BaseSession(
|
|||||||
self._read_stream,
|
self._read_stream,
|
||||||
self._write_stream,
|
self._write_stream,
|
||||||
):
|
):
|
||||||
async for message in self._read_stream:
|
try:
|
||||||
if isinstance(message, Exception):
|
async for message in self._read_stream:
|
||||||
await self._handle_incoming(message)
|
if isinstance(message, Exception):
|
||||||
elif isinstance(message.message.root, JSONRPCRequest):
|
await self._handle_incoming(message)
|
||||||
try:
|
elif isinstance(message.message.root, JSONRPCRequest):
|
||||||
validated_request = self._receive_request_type.model_validate(
|
try:
|
||||||
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
validated_request = self._receive_request_type.model_validate(
|
||||||
)
|
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||||
responder = RequestResponder(
|
)
|
||||||
request_id=message.message.root.id,
|
responder = RequestResponder(
|
||||||
request_meta=validated_request.root.params.meta if validated_request.root.params else None,
|
request_id=message.message.root.id,
|
||||||
request=validated_request,
|
request_meta=validated_request.root.params.meta
|
||||||
session=self,
|
if validated_request.root.params
|
||||||
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
|
else None,
|
||||||
message_metadata=message.metadata,
|
request=validated_request,
|
||||||
)
|
session=self,
|
||||||
self._in_flight[responder.request_id] = responder
|
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
|
||||||
await self._received_request(responder)
|
message_metadata=message.metadata,
|
||||||
|
)
|
||||||
|
self._in_flight[responder.request_id] = responder
|
||||||
|
await self._received_request(responder)
|
||||||
|
|
||||||
if not responder._completed: # type: ignore[reportPrivateUsage]
|
if not responder._completed: # type: ignore[reportPrivateUsage]
|
||||||
await self._handle_incoming(responder)
|
await self._handle_incoming(responder)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# For request validation errors, send a proper JSON-RPC error
|
# For request validation errors, send a proper JSON-RPC error
|
||||||
# response instead of crashing the server
|
# response instead of crashing the server
|
||||||
logging.warning(f"Failed to validate request: {e}")
|
logging.warning(f"Failed to validate request: {e}")
|
||||||
logging.debug(f"Message that failed validation: {message.message.root}")
|
logging.debug(f"Message that failed validation: {message.message.root}")
|
||||||
error_response = JSONRPCError(
|
error_response = JSONRPCError(
|
||||||
jsonrpc="2.0",
|
jsonrpc="2.0",
|
||||||
id=message.message.root.id,
|
id=message.message.root.id,
|
||||||
error=ErrorData(
|
error=ErrorData(
|
||||||
code=INVALID_PARAMS,
|
code=INVALID_PARAMS,
|
||||||
message="Invalid request parameters",
|
message="Invalid request parameters",
|
||||||
data="",
|
data="",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
session_message = SessionMessage(message=JSONRPCMessage(error_response))
|
session_message = SessionMessage(message=JSONRPCMessage(error_response))
|
||||||
await self._write_stream.send(session_message)
|
await self._write_stream.send(session_message)
|
||||||
|
|
||||||
elif isinstance(message.message.root, JSONRPCNotification):
|
elif isinstance(message.message.root, JSONRPCNotification):
|
||||||
try:
|
try:
|
||||||
notification = self._receive_notification_type.model_validate(
|
notification = self._receive_notification_type.model_validate(
|
||||||
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||||
)
|
)
|
||||||
# Handle cancellation notifications
|
# Handle cancellation notifications
|
||||||
if isinstance(notification.root, CancelledNotification):
|
if isinstance(notification.root, CancelledNotification):
|
||||||
cancelled_id = notification.root.params.requestId
|
cancelled_id = notification.root.params.requestId
|
||||||
if cancelled_id in self._in_flight:
|
if cancelled_id in self._in_flight:
|
||||||
await self._in_flight[cancelled_id].cancel()
|
await self._in_flight[cancelled_id].cancel()
|
||||||
|
else:
|
||||||
|
# Handle progress notifications callback
|
||||||
|
if isinstance(notification.root, ProgressNotification):
|
||||||
|
progress_token = notification.root.params.progressToken
|
||||||
|
# If there is a progress callback for this token,
|
||||||
|
# call it with the progress information
|
||||||
|
if progress_token in self._progress_callbacks:
|
||||||
|
callback = self._progress_callbacks[progress_token]
|
||||||
|
await callback(
|
||||||
|
notification.root.params.progress,
|
||||||
|
notification.root.params.total,
|
||||||
|
notification.root.params.message,
|
||||||
|
)
|
||||||
|
await self._received_notification(notification)
|
||||||
|
await self._handle_incoming(notification)
|
||||||
|
except Exception as e:
|
||||||
|
# For other validation errors, log and continue
|
||||||
|
logging.warning(
|
||||||
|
f"Failed to validate notification: {e}. " f"Message was: {message.message.root}"
|
||||||
|
)
|
||||||
|
else: # Response or error
|
||||||
|
stream = self._response_streams.pop(message.message.root.id, None)
|
||||||
|
if stream:
|
||||||
|
await stream.send(message.message.root)
|
||||||
else:
|
else:
|
||||||
# Handle progress notifications callback
|
await self._handle_incoming(
|
||||||
if isinstance(notification.root, ProgressNotification):
|
RuntimeError("Received response with an unknown " f"request ID: {message}")
|
||||||
progress_token = notification.root.params.progressToken
|
)
|
||||||
# If there is a progress callback for this token,
|
|
||||||
# call it with the progress information
|
|
||||||
if progress_token in self._progress_callbacks:
|
|
||||||
callback = self._progress_callbacks[progress_token]
|
|
||||||
await callback(
|
|
||||||
notification.root.params.progress,
|
|
||||||
notification.root.params.total,
|
|
||||||
notification.root.params.message,
|
|
||||||
)
|
|
||||||
await self._received_notification(notification)
|
|
||||||
await self._handle_incoming(notification)
|
|
||||||
except Exception as e:
|
|
||||||
# For other validation errors, log and continue
|
|
||||||
logging.warning(
|
|
||||||
f"Failed to validate notification: {e}. " f"Message was: {message.message.root}"
|
|
||||||
)
|
|
||||||
else: # Response or error
|
|
||||||
stream = self._response_streams.pop(message.message.root.id, None)
|
|
||||||
if stream:
|
|
||||||
await stream.send(message.message.root)
|
|
||||||
else:
|
|
||||||
await self._handle_incoming(
|
|
||||||
RuntimeError("Received response with an unknown " f"request ID: {message}")
|
|
||||||
)
|
|
||||||
|
|
||||||
# after the read stream is closed, we need to send errors
|
except anyio.ClosedResourceError:
|
||||||
# to any pending requests
|
# This is expected when the client disconnects abruptly.
|
||||||
for id, stream in self._response_streams.items():
|
# Without this handler, the exception would propagate up and
|
||||||
error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed")
|
# crash the server's task group.
|
||||||
await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error))
|
logging.debug("Read stream closed by client")
|
||||||
await stream.aclose()
|
except Exception as e:
|
||||||
self._response_streams.clear()
|
# Other exceptions are not expected and should be logged. We purposefully
|
||||||
|
# catch all exceptions here to avoid crashing the server.
|
||||||
|
logging.exception(f"Unhandled exception in receive loop: {e}")
|
||||||
|
finally:
|
||||||
|
# after the read stream is closed, we need to send errors
|
||||||
|
# to any pending requests
|
||||||
|
for id, stream in self._response_streams.items():
|
||||||
|
error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed")
|
||||||
|
try:
|
||||||
|
await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error))
|
||||||
|
await stream.aclose()
|
||||||
|
except Exception:
|
||||||
|
# Stream might already be closed
|
||||||
|
pass
|
||||||
|
self._response_streams.clear()
|
||||||
|
|
||||||
async def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None:
|
async def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1521,3 +1521,40 @@ def test_server_backwards_compatibility_no_protocol_version(basic_server, basic_
|
|||||||
)
|
)
|
||||||
assert response.status_code == 200 # Should succeed for backwards compatibility
|
assert response.status_code == 200 # Should succeed for backwards compatibility
|
||||||
assert response.headers.get("Content-Type") == "text/event-stream"
|
assert response.headers.get("Content-Type") == "text/event-stream"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_client_crash_handled(basic_server, basic_server_url):
|
||||||
|
"""Test that cases where the client crashes are handled gracefully."""
|
||||||
|
|
||||||
|
# Simulate bad client that crashes after init
|
||||||
|
async def bad_client():
|
||||||
|
"""Client that triggers ClosedResourceError"""
|
||||||
|
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
|
||||||
|
read_stream,
|
||||||
|
write_stream,
|
||||||
|
_,
|
||||||
|
):
|
||||||
|
async with ClientSession(read_stream, write_stream) as session:
|
||||||
|
await session.initialize()
|
||||||
|
raise Exception("client crash")
|
||||||
|
|
||||||
|
# Run bad client a few times to trigger the crash
|
||||||
|
for _ in range(3):
|
||||||
|
try:
|
||||||
|
await bad_client()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
await anyio.sleep(0.1)
|
||||||
|
|
||||||
|
# Try a good client, it should still be able to connect and list tools
|
||||||
|
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
|
||||||
|
read_stream,
|
||||||
|
write_stream,
|
||||||
|
_,
|
||||||
|
):
|
||||||
|
async with ClientSession(read_stream, write_stream) as session:
|
||||||
|
result = await session.initialize()
|
||||||
|
assert isinstance(result, InitializeResult)
|
||||||
|
tools = await session.list_tools()
|
||||||
|
assert tools.tools
|
||||||
|
|||||||
Reference in New Issue
Block a user