mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
fix: improve error handling and request cancellation for issue #88
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
|
import logging
|
||||||
from contextlib import AbstractAsyncContextManager
|
from contextlib import AbstractAsyncContextManager
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from typing import Any, Callable, Generic, TypeVar
|
from typing import Any, Callable, Generic, TypeVar
|
||||||
@@ -273,19 +274,28 @@ class BaseSession(
|
|||||||
await self._incoming_message_stream_writer.send(responder)
|
await self._incoming_message_stream_writer.send(responder)
|
||||||
|
|
||||||
elif isinstance(message.root, JSONRPCNotification):
|
elif isinstance(message.root, JSONRPCNotification):
|
||||||
notification = self._receive_notification_type.model_validate(
|
try:
|
||||||
message.root.model_dump(
|
notification = self._receive_notification_type.model_validate(
|
||||||
by_alias=True, mode="json", exclude_none=True
|
message.root.model_dump(
|
||||||
|
by_alias=True, mode="json", exclude_none=True
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Handle cancellation notifications
|
||||||
|
if isinstance(notification.root, CancelledNotification):
|
||||||
|
cancelled_id = notification.root.params.requestId
|
||||||
|
if cancelled_id in self._in_flight:
|
||||||
|
await self._in_flight[cancelled_id].cancel()
|
||||||
|
else:
|
||||||
|
await self._received_notification(notification)
|
||||||
|
await self._incoming_message_stream_writer.send(
|
||||||
|
notification
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
# For other validation errors, log and continue
|
||||||
|
logging.warning(
|
||||||
|
f"Failed to validate notification: {e}. "
|
||||||
|
f"Message was: {message.root}"
|
||||||
)
|
)
|
||||||
)
|
|
||||||
# Handle cancellation notifications
|
|
||||||
if isinstance(notification.root, CancelledNotification):
|
|
||||||
cancelled_id = notification.root.params.requestId
|
|
||||||
if cancelled_id in self._in_flight:
|
|
||||||
await self._in_flight[cancelled_id].cancel()
|
|
||||||
else:
|
|
||||||
await self._received_notification(notification)
|
|
||||||
await self._incoming_message_stream_writer.send(notification)
|
|
||||||
else: # Response or error
|
else: # Response or error
|
||||||
stream = self._response_streams.pop(message.root.id, None)
|
stream = self._response_streams.pop(message.root.id, None)
|
||||||
if stream:
|
if stream:
|
||||||
|
|||||||
@@ -30,19 +30,23 @@ async def test_notification_validation_error(tmp_path: Path):
|
|||||||
|
|
||||||
server = Server(name="test")
|
server = Server(name="test")
|
||||||
request_count = 0
|
request_count = 0
|
||||||
slow_request_complete = False
|
slow_request_started = anyio.Event()
|
||||||
|
slow_request_complete = anyio.Event()
|
||||||
|
|
||||||
@server.call_tool()
|
@server.call_tool()
|
||||||
async def slow_tool(
|
async def slow_tool(
|
||||||
name: str, arg
|
name: str, arg
|
||||||
) -> Sequence[TextContent | ImageContent | EmbeddedResource]:
|
) -> Sequence[TextContent | ImageContent | EmbeddedResource]:
|
||||||
nonlocal request_count, slow_request_complete
|
nonlocal request_count
|
||||||
request_count += 1
|
request_count += 1
|
||||||
|
|
||||||
if name == "slow":
|
if name == "slow":
|
||||||
|
# Signal that slow request has started
|
||||||
|
slow_request_started.set()
|
||||||
# Long enough to ensure timeout
|
# Long enough to ensure timeout
|
||||||
await anyio.sleep(0.2)
|
await anyio.sleep(0.2)
|
||||||
slow_request_complete = True
|
# Signal completion
|
||||||
|
slow_request_complete.set()
|
||||||
return [TextContent(type="text", text=f"slow {request_count}")]
|
return [TextContent(type="text", text=f"slow {request_count}")]
|
||||||
elif name == "fast":
|
elif name == "fast":
|
||||||
# Fast enough to complete before timeout
|
# Fast enough to complete before timeout
|
||||||
@@ -71,7 +75,7 @@ async def test_notification_validation_error(tmp_path: Path):
|
|||||||
# First call should work (fast operation)
|
# First call should work (fast operation)
|
||||||
result = await session.call_tool("fast")
|
result = await session.call_tool("fast")
|
||||||
assert result.content == [TextContent(type="text", text="fast 1")]
|
assert result.content == [TextContent(type="text", text="fast 1")]
|
||||||
assert not slow_request_complete
|
assert not slow_request_complete.is_set()
|
||||||
|
|
||||||
# Second call should timeout (slow operation)
|
# Second call should timeout (slow operation)
|
||||||
with pytest.raises(McpError) as exc_info:
|
with pytest.raises(McpError) as exc_info:
|
||||||
@@ -79,8 +83,8 @@ async def test_notification_validation_error(tmp_path: Path):
|
|||||||
assert "Timed out while waiting" in str(exc_info.value)
|
assert "Timed out while waiting" in str(exc_info.value)
|
||||||
|
|
||||||
# Wait for slow request to complete in the background
|
# Wait for slow request to complete in the background
|
||||||
await anyio.sleep(0.3)
|
with anyio.fail_after(1): # Timeout after 1 second
|
||||||
assert slow_request_complete
|
await slow_request_complete.wait()
|
||||||
|
|
||||||
# Third call should work (fast operation),
|
# Third call should work (fast operation),
|
||||||
# proving server is still responsive
|
# proving server is still responsive
|
||||||
@@ -91,10 +95,17 @@ async def test_notification_validation_error(tmp_path: Path):
|
|||||||
server_writer, server_reader = anyio.create_memory_object_stream(1)
|
server_writer, server_reader = anyio.create_memory_object_stream(1)
|
||||||
client_writer, client_reader = anyio.create_memory_object_stream(1)
|
client_writer, client_reader = anyio.create_memory_object_stream(1)
|
||||||
|
|
||||||
|
server_ready = anyio.Event()
|
||||||
|
|
||||||
|
async def wrapped_server_handler(read_stream, write_stream):
|
||||||
|
server_ready.set()
|
||||||
|
await server_handler(read_stream, write_stream)
|
||||||
|
|
||||||
async with anyio.create_task_group() as tg:
|
async with anyio.create_task_group() as tg:
|
||||||
tg.start_soon(server_handler, server_reader, client_writer)
|
tg.start_soon(wrapped_server_handler, server_reader, client_writer)
|
||||||
# Wait for server to start and initialize
|
# Wait for server to start and initialize
|
||||||
await anyio.sleep(0.1)
|
with anyio.fail_after(1): # Timeout after 1 second
|
||||||
|
await server_ready.wait()
|
||||||
# Run client in a separate task to avoid cancellation
|
# Run client in a separate task to avoid cancellation
|
||||||
async with anyio.create_task_group() as client_tg:
|
async with anyio.create_task_group() as client_tg:
|
||||||
client_tg.start_soon(client, client_reader, server_writer)
|
client_tg.start_soon(client, client_reader, server_writer)
|
||||||
|
|||||||
Reference in New Issue
Block a user