fix: improve error handling and request cancellation for issue #88

This commit is contained in:
David Soria Parra
2025-02-04 13:58:44 +00:00
parent 827e494df4
commit 08cfbe522a
2 changed files with 41 additions and 20 deletions

View File

@@ -1,3 +1,4 @@
import logging
from contextlib import AbstractAsyncContextManager from contextlib import AbstractAsyncContextManager
from datetime import timedelta from datetime import timedelta
from typing import Any, Callable, Generic, TypeVar from typing import Any, Callable, Generic, TypeVar
@@ -273,19 +274,28 @@ class BaseSession(
await self._incoming_message_stream_writer.send(responder) await self._incoming_message_stream_writer.send(responder)
elif isinstance(message.root, JSONRPCNotification): elif isinstance(message.root, JSONRPCNotification):
notification = self._receive_notification_type.model_validate( try:
message.root.model_dump( notification = self._receive_notification_type.model_validate(
by_alias=True, mode="json", exclude_none=True message.root.model_dump(
by_alias=True, mode="json", exclude_none=True
)
)
# Handle cancellation notifications
if isinstance(notification.root, CancelledNotification):
cancelled_id = notification.root.params.requestId
if cancelled_id in self._in_flight:
await self._in_flight[cancelled_id].cancel()
else:
await self._received_notification(notification)
await self._incoming_message_stream_writer.send(
notification
)
except Exception as e:
# For other validation errors, log and continue
logging.warning(
f"Failed to validate notification: {e}. "
f"Message was: {message.root}"
) )
)
# Handle cancellation notifications
if isinstance(notification.root, CancelledNotification):
cancelled_id = notification.root.params.requestId
if cancelled_id in self._in_flight:
await self._in_flight[cancelled_id].cancel()
else:
await self._received_notification(notification)
await self._incoming_message_stream_writer.send(notification)
else: # Response or error else: # Response or error
stream = self._response_streams.pop(message.root.id, None) stream = self._response_streams.pop(message.root.id, None)
if stream: if stream:

View File

@@ -30,19 +30,23 @@ async def test_notification_validation_error(tmp_path: Path):
server = Server(name="test") server = Server(name="test")
request_count = 0 request_count = 0
slow_request_complete = False slow_request_started = anyio.Event()
slow_request_complete = anyio.Event()
@server.call_tool() @server.call_tool()
async def slow_tool( async def slow_tool(
name: str, arg name: str, arg
) -> Sequence[TextContent | ImageContent | EmbeddedResource]: ) -> Sequence[TextContent | ImageContent | EmbeddedResource]:
nonlocal request_count, slow_request_complete nonlocal request_count
request_count += 1 request_count += 1
if name == "slow": if name == "slow":
# Signal that slow request has started
slow_request_started.set()
# Long enough to ensure timeout # Long enough to ensure timeout
await anyio.sleep(0.2) await anyio.sleep(0.2)
slow_request_complete = True # Signal completion
slow_request_complete.set()
return [TextContent(type="text", text=f"slow {request_count}")] return [TextContent(type="text", text=f"slow {request_count}")]
elif name == "fast": elif name == "fast":
# Fast enough to complete before timeout # Fast enough to complete before timeout
@@ -71,7 +75,7 @@ async def test_notification_validation_error(tmp_path: Path):
# First call should work (fast operation) # First call should work (fast operation)
result = await session.call_tool("fast") result = await session.call_tool("fast")
assert result.content == [TextContent(type="text", text="fast 1")] assert result.content == [TextContent(type="text", text="fast 1")]
assert not slow_request_complete assert not slow_request_complete.is_set()
# Second call should timeout (slow operation) # Second call should timeout (slow operation)
with pytest.raises(McpError) as exc_info: with pytest.raises(McpError) as exc_info:
@@ -79,8 +83,8 @@ async def test_notification_validation_error(tmp_path: Path):
assert "Timed out while waiting" in str(exc_info.value) assert "Timed out while waiting" in str(exc_info.value)
# Wait for slow request to complete in the background # Wait for slow request to complete in the background
await anyio.sleep(0.3) with anyio.fail_after(1): # Timeout after 1 second
assert slow_request_complete await slow_request_complete.wait()
# Third call should work (fast operation), # Third call should work (fast operation),
# proving server is still responsive # proving server is still responsive
@@ -91,10 +95,17 @@ async def test_notification_validation_error(tmp_path: Path):
server_writer, server_reader = anyio.create_memory_object_stream(1) server_writer, server_reader = anyio.create_memory_object_stream(1)
client_writer, client_reader = anyio.create_memory_object_stream(1) client_writer, client_reader = anyio.create_memory_object_stream(1)
server_ready = anyio.Event()
async def wrapped_server_handler(read_stream, write_stream):
server_ready.set()
await server_handler(read_stream, write_stream)
async with anyio.create_task_group() as tg: async with anyio.create_task_group() as tg:
tg.start_soon(server_handler, server_reader, client_writer) tg.start_soon(wrapped_server_handler, server_reader, client_writer)
# Wait for server to start and initialize # Wait for server to start and initialize
await anyio.sleep(0.1) with anyio.fail_after(1): # Timeout after 1 second
await server_ready.wait()
# Run client in a separate task to avoid cancellation # Run client in a separate task to avoid cancellation
async with anyio.create_task_group() as client_tg: async with anyio.create_task_group() as client_tg:
client_tg.start_soon(client, client_reader, server_writer) client_tg.start_soon(client, client_reader, server_writer)