send errors to pending requests if server closes (#333)

Co-authored-by: ihrpr <inna@anthropic.com>
This commit is contained in:
Marshall Roch
2025-05-27 17:55:27 -04:00
committed by GitHub
parent 9dad26620f
commit 7901552eba
3 changed files with 71 additions and 1 deletions

View File

@@ -14,6 +14,7 @@ from typing_extensions import Self
from mcp.shared.exceptions import McpError from mcp.shared.exceptions import McpError
from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage
from mcp.types import ( from mcp.types import (
CONNECTION_CLOSED,
CancelledNotification, CancelledNotification,
ClientNotification, ClientNotification,
ClientRequest, ClientRequest,
@@ -417,6 +418,14 @@ class BaseSession(
) )
) )
# after the read stream is closed, we need to send errors
# to any pending requests
for id, stream in self._response_streams.items():
error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed")
await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error))
await stream.aclose()
self._response_streams.clear()
async def _received_request( async def _received_request(
self, responder: RequestResponder[ReceiveRequestT, SendResultT] self, responder: RequestResponder[ReceiveRequestT, SendResultT]
) -> None: ) -> None:

View File

@@ -149,6 +149,10 @@ class JSONRPCResponse(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
# SDK error codes
CONNECTION_CLOSED = -32000
# REQUEST_TIMEOUT = -32001 # the typescript sdk uses this
# Standard JSON-RPC error codes # Standard JSON-RPC error codes
PARSE_ERROR = -32700 PARSE_ERROR = -32700
INVALID_REQUEST = -32600 INVALID_REQUEST = -32600

View File

@@ -7,7 +7,10 @@ import mcp.types as types
from mcp.client.session import ClientSession from mcp.client.session import ClientSession
from mcp.server.lowlevel.server import Server from mcp.server.lowlevel.server import Server
from mcp.shared.exceptions import McpError from mcp.shared.exceptions import McpError
from mcp.shared.memory import create_connected_server_and_client_session from mcp.shared.memory import (
create_client_server_memory_streams,
create_connected_server_and_client_session,
)
from mcp.types import ( from mcp.types import (
CancelledNotification, CancelledNotification,
CancelledNotificationParams, CancelledNotificationParams,
@@ -124,3 +127,57 @@ async def test_request_cancellation():
# Give cancellation time to process # Give cancellation time to process
with anyio.fail_after(1): with anyio.fail_after(1):
await ev_cancelled.wait() await ev_cancelled.wait()
@pytest.mark.anyio
async def test_connection_closed():
"""
Test that pending requests are cancelled when the connection is closed remotely.
"""
ev_closed = anyio.Event()
ev_response = anyio.Event()
async with create_client_server_memory_streams() as (
client_streams,
server_streams,
):
client_read, client_write = client_streams
server_read, server_write = server_streams
async def make_request(client_session):
"""Send a request in a separate task"""
nonlocal ev_response
try:
# any request will do
await client_session.initialize()
pytest.fail("Request should have errored")
except McpError as e:
# Expected - request errored
assert "Connection closed" in str(e)
ev_response.set()
async def mock_server():
"""Wait for a request, then close the connection"""
nonlocal ev_closed
# Wait for a request
await server_read.receive()
# Close the connection, as if the server exited
server_write.close()
server_read.close()
ev_closed.set()
async with (
anyio.create_task_group() as tg,
ClientSession(
read_stream=client_read,
write_stream=client_write,
) as client_session,
):
tg.start_soon(make_request, client_session)
tg.start_soon(mock_server)
with anyio.fail_after(1):
await ev_closed.wait()
with anyio.fail_after(1):
await ev_response.wait()