From 08cfbe522aae48365f74147b20636e8bd715174d Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Tue, 4 Feb 2025 13:58:44 +0000 Subject: [PATCH] fix: improve error handling and request cancellation for issue #88 --- src/mcp/shared/session.py | 34 ++++++++++++++++++---------- tests/issues/test_88_random_error.py | 27 +++++++++++++++------- 2 files changed, 41 insertions(+), 20 deletions(-) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index ddfa909..e21bcbc 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -1,3 +1,4 @@ +import logging from contextlib import AbstractAsyncContextManager from datetime import timedelta from typing import Any, Callable, Generic, TypeVar @@ -273,19 +274,28 @@ class BaseSession( await self._incoming_message_stream_writer.send(responder) elif isinstance(message.root, JSONRPCNotification): - notification = self._receive_notification_type.model_validate( - message.root.model_dump( - by_alias=True, mode="json", exclude_none=True + try: + notification = self._receive_notification_type.model_validate( + message.root.model_dump( + by_alias=True, mode="json", exclude_none=True + ) + ) + # Handle cancellation notifications + if isinstance(notification.root, CancelledNotification): + cancelled_id = notification.root.params.requestId + if cancelled_id in self._in_flight: + await self._in_flight[cancelled_id].cancel() + else: + await self._received_notification(notification) + await self._incoming_message_stream_writer.send( + notification + ) + except Exception as e: + # For other validation errors, log and continue + logging.warning( + f"Failed to validate notification: {e}. " + f"Message was: {message.root}" ) - ) - # Handle cancellation notifications - if isinstance(notification.root, CancelledNotification): - cancelled_id = notification.root.params.requestId - if cancelled_id in self._in_flight: - await self._in_flight[cancelled_id].cancel() - else: - await self._received_notification(notification) - await self._incoming_message_stream_writer.send(notification) else: # Response or error stream = self._response_streams.pop(message.root.id, None) if stream: diff --git a/tests/issues/test_88_random_error.py b/tests/issues/test_88_random_error.py index 8b979ab..8609c20 100644 --- a/tests/issues/test_88_random_error.py +++ b/tests/issues/test_88_random_error.py @@ -30,19 +30,23 @@ async def test_notification_validation_error(tmp_path: Path): server = Server(name="test") request_count = 0 - slow_request_complete = False + slow_request_started = anyio.Event() + slow_request_complete = anyio.Event() @server.call_tool() async def slow_tool( name: str, arg ) -> Sequence[TextContent | ImageContent | EmbeddedResource]: - nonlocal request_count, slow_request_complete + nonlocal request_count request_count += 1 if name == "slow": + # Signal that slow request has started + slow_request_started.set() # Long enough to ensure timeout await anyio.sleep(0.2) - slow_request_complete = True + # Signal completion + slow_request_complete.set() return [TextContent(type="text", text=f"slow {request_count}")] elif name == "fast": # Fast enough to complete before timeout @@ -71,7 +75,7 @@ async def test_notification_validation_error(tmp_path: Path): # First call should work (fast operation) result = await session.call_tool("fast") assert result.content == [TextContent(type="text", text="fast 1")] - assert not slow_request_complete + assert not slow_request_complete.is_set() # Second call should timeout (slow operation) with pytest.raises(McpError) as exc_info: @@ -79,8 +83,8 @@ async def test_notification_validation_error(tmp_path: Path): assert "Timed out while waiting" in str(exc_info.value) # Wait for slow request to complete in the background - await anyio.sleep(0.3) - assert slow_request_complete + with anyio.fail_after(1): # Timeout after 1 second + await slow_request_complete.wait() # Third call should work (fast operation), # proving server is still responsive @@ -91,10 +95,17 @@ async def test_notification_validation_error(tmp_path: Path): server_writer, server_reader = anyio.create_memory_object_stream(1) client_writer, client_reader = anyio.create_memory_object_stream(1) + server_ready = anyio.Event() + + async def wrapped_server_handler(read_stream, write_stream): + server_ready.set() + await server_handler(read_stream, write_stream) + async with anyio.create_task_group() as tg: - tg.start_soon(server_handler, server_reader, client_writer) + tg.start_soon(wrapped_server_handler, server_reader, client_writer) # Wait for server to start and initialize - await anyio.sleep(0.1) + with anyio.fail_after(1): # Timeout after 1 second + await server_ready.wait() # Run client in a separate task to avoid cancellation async with anyio.create_task_group() as client_tg: client_tg.start_soon(client, client_reader, server_writer)