send errors to pending requests if server closes (#333)

Co-authored-by: ihrpr <inna@anthropic.com>
This commit is contained in:
Marshall Roch
2025-05-27 17:55:27 -04:00
committed by GitHub
parent 9dad26620f
commit 7901552eba
3 changed files with 71 additions and 1 deletions

View File

@@ -7,7 +7,10 @@ import mcp.types as types
from mcp.client.session import ClientSession
from mcp.server.lowlevel.server import Server
from mcp.shared.exceptions import McpError
from mcp.shared.memory import create_connected_server_and_client_session
from mcp.shared.memory import (
create_client_server_memory_streams,
create_connected_server_and_client_session,
)
from mcp.types import (
CancelledNotification,
CancelledNotificationParams,
@@ -124,3 +127,57 @@ async def test_request_cancellation():
# Give cancellation time to process
with anyio.fail_after(1):
await ev_cancelled.wait()
@pytest.mark.anyio
async def test_connection_closed():
"""
Test that pending requests are cancelled when the connection is closed remotely.
"""
ev_closed = anyio.Event()
ev_response = anyio.Event()
async with create_client_server_memory_streams() as (
client_streams,
server_streams,
):
client_read, client_write = client_streams
server_read, server_write = server_streams
async def make_request(client_session):
"""Send a request in a separate task"""
nonlocal ev_response
try:
# any request will do
await client_session.initialize()
pytest.fail("Request should have errored")
except McpError as e:
# Expected - request errored
assert "Connection closed" in str(e)
ev_response.set()
async def mock_server():
"""Wait for a request, then close the connection"""
nonlocal ev_closed
# Wait for a request
await server_read.receive()
# Close the connection, as if the server exited
server_write.close()
server_read.close()
ev_closed.set()
async with (
anyio.create_task_group() as tg,
ClientSession(
read_stream=client_read,
write_stream=client_write,
) as client_session,
):
tg.start_soon(make_request, client_session)
tg.start_soon(mock_server)
with anyio.fail_after(1):
await ev_closed.wait()
with anyio.fail_after(1):
await ev_response.wait()