mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
fix: improve error handling and request cancellation for issue #88
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
from contextlib import AbstractAsyncContextManager
|
||||
from datetime import timedelta
|
||||
from typing import Any, Callable, Generic, TypeVar
|
||||
@@ -273,6 +274,7 @@ class BaseSession(
|
||||
await self._incoming_message_stream_writer.send(responder)
|
||||
|
||||
elif isinstance(message.root, JSONRPCNotification):
|
||||
try:
|
||||
notification = self._receive_notification_type.model_validate(
|
||||
message.root.model_dump(
|
||||
by_alias=True, mode="json", exclude_none=True
|
||||
@@ -285,7 +287,15 @@ class BaseSession(
|
||||
await self._in_flight[cancelled_id].cancel()
|
||||
else:
|
||||
await self._received_notification(notification)
|
||||
await self._incoming_message_stream_writer.send(notification)
|
||||
await self._incoming_message_stream_writer.send(
|
||||
notification
|
||||
)
|
||||
except Exception as e:
|
||||
# For other validation errors, log and continue
|
||||
logging.warning(
|
||||
f"Failed to validate notification: {e}. "
|
||||
f"Message was: {message.root}"
|
||||
)
|
||||
else: # Response or error
|
||||
stream = self._response_streams.pop(message.root.id, None)
|
||||
if stream:
|
||||
|
||||
@@ -30,19 +30,23 @@ async def test_notification_validation_error(tmp_path: Path):
|
||||
|
||||
server = Server(name="test")
|
||||
request_count = 0
|
||||
slow_request_complete = False
|
||||
slow_request_started = anyio.Event()
|
||||
slow_request_complete = anyio.Event()
|
||||
|
||||
@server.call_tool()
|
||||
async def slow_tool(
|
||||
name: str, arg
|
||||
) -> Sequence[TextContent | ImageContent | EmbeddedResource]:
|
||||
nonlocal request_count, slow_request_complete
|
||||
nonlocal request_count
|
||||
request_count += 1
|
||||
|
||||
if name == "slow":
|
||||
# Signal that slow request has started
|
||||
slow_request_started.set()
|
||||
# Long enough to ensure timeout
|
||||
await anyio.sleep(0.2)
|
||||
slow_request_complete = True
|
||||
# Signal completion
|
||||
slow_request_complete.set()
|
||||
return [TextContent(type="text", text=f"slow {request_count}")]
|
||||
elif name == "fast":
|
||||
# Fast enough to complete before timeout
|
||||
@@ -71,7 +75,7 @@ async def test_notification_validation_error(tmp_path: Path):
|
||||
# First call should work (fast operation)
|
||||
result = await session.call_tool("fast")
|
||||
assert result.content == [TextContent(type="text", text="fast 1")]
|
||||
assert not slow_request_complete
|
||||
assert not slow_request_complete.is_set()
|
||||
|
||||
# Second call should timeout (slow operation)
|
||||
with pytest.raises(McpError) as exc_info:
|
||||
@@ -79,8 +83,8 @@ async def test_notification_validation_error(tmp_path: Path):
|
||||
assert "Timed out while waiting" in str(exc_info.value)
|
||||
|
||||
# Wait for slow request to complete in the background
|
||||
await anyio.sleep(0.3)
|
||||
assert slow_request_complete
|
||||
with anyio.fail_after(1): # Timeout after 1 second
|
||||
await slow_request_complete.wait()
|
||||
|
||||
# Third call should work (fast operation),
|
||||
# proving server is still responsive
|
||||
@@ -91,10 +95,17 @@ async def test_notification_validation_error(tmp_path: Path):
|
||||
server_writer, server_reader = anyio.create_memory_object_stream(1)
|
||||
client_writer, client_reader = anyio.create_memory_object_stream(1)
|
||||
|
||||
server_ready = anyio.Event()
|
||||
|
||||
async def wrapped_server_handler(read_stream, write_stream):
|
||||
server_ready.set()
|
||||
await server_handler(read_stream, write_stream)
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(server_handler, server_reader, client_writer)
|
||||
tg.start_soon(wrapped_server_handler, server_reader, client_writer)
|
||||
# Wait for server to start and initialize
|
||||
await anyio.sleep(0.1)
|
||||
with anyio.fail_after(1): # Timeout after 1 second
|
||||
await server_ready.wait()
|
||||
# Run client in a separate task to avoid cancellation
|
||||
async with anyio.create_task_group() as client_tg:
|
||||
client_tg.start_soon(client, client_reader, server_writer)
|
||||
|
||||
Reference in New Issue
Block a user