mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
send errors to pending requests if server closes (#333)
Co-authored-by: ihrpr <inna@anthropic.com>
This commit is contained in:
@@ -14,6 +14,7 @@ from typing_extensions import Self
|
|||||||
from mcp.shared.exceptions import McpError
|
from mcp.shared.exceptions import McpError
|
||||||
from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage
|
from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage
|
||||||
from mcp.types import (
|
from mcp.types import (
|
||||||
|
CONNECTION_CLOSED,
|
||||||
CancelledNotification,
|
CancelledNotification,
|
||||||
ClientNotification,
|
ClientNotification,
|
||||||
ClientRequest,
|
ClientRequest,
|
||||||
@@ -417,6 +418,14 @@ class BaseSession(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# after the read stream is closed, we need to send errors
|
||||||
|
# to any pending requests
|
||||||
|
for id, stream in self._response_streams.items():
|
||||||
|
error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed")
|
||||||
|
await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error))
|
||||||
|
await stream.aclose()
|
||||||
|
self._response_streams.clear()
|
||||||
|
|
||||||
async def _received_request(
|
async def _received_request(
|
||||||
self, responder: RequestResponder[ReceiveRequestT, SendResultT]
|
self, responder: RequestResponder[ReceiveRequestT, SendResultT]
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|||||||
@@ -149,6 +149,10 @@ class JSONRPCResponse(BaseModel):
|
|||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
|
# SDK error codes
|
||||||
|
CONNECTION_CLOSED = -32000
|
||||||
|
# REQUEST_TIMEOUT = -32001 # the typescript sdk uses this
|
||||||
|
|
||||||
# Standard JSON-RPC error codes
|
# Standard JSON-RPC error codes
|
||||||
PARSE_ERROR = -32700
|
PARSE_ERROR = -32700
|
||||||
INVALID_REQUEST = -32600
|
INVALID_REQUEST = -32600
|
||||||
|
|||||||
@@ -7,7 +7,10 @@ import mcp.types as types
|
|||||||
from mcp.client.session import ClientSession
|
from mcp.client.session import ClientSession
|
||||||
from mcp.server.lowlevel.server import Server
|
from mcp.server.lowlevel.server import Server
|
||||||
from mcp.shared.exceptions import McpError
|
from mcp.shared.exceptions import McpError
|
||||||
from mcp.shared.memory import create_connected_server_and_client_session
|
from mcp.shared.memory import (
|
||||||
|
create_client_server_memory_streams,
|
||||||
|
create_connected_server_and_client_session,
|
||||||
|
)
|
||||||
from mcp.types import (
|
from mcp.types import (
|
||||||
CancelledNotification,
|
CancelledNotification,
|
||||||
CancelledNotificationParams,
|
CancelledNotificationParams,
|
||||||
@@ -124,3 +127,57 @@ async def test_request_cancellation():
|
|||||||
# Give cancellation time to process
|
# Give cancellation time to process
|
||||||
with anyio.fail_after(1):
|
with anyio.fail_after(1):
|
||||||
await ev_cancelled.wait()
|
await ev_cancelled.wait()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_connection_closed():
|
||||||
|
"""
|
||||||
|
Test that pending requests are cancelled when the connection is closed remotely.
|
||||||
|
"""
|
||||||
|
|
||||||
|
ev_closed = anyio.Event()
|
||||||
|
ev_response = anyio.Event()
|
||||||
|
|
||||||
|
async with create_client_server_memory_streams() as (
|
||||||
|
client_streams,
|
||||||
|
server_streams,
|
||||||
|
):
|
||||||
|
client_read, client_write = client_streams
|
||||||
|
server_read, server_write = server_streams
|
||||||
|
|
||||||
|
async def make_request(client_session):
|
||||||
|
"""Send a request in a separate task"""
|
||||||
|
nonlocal ev_response
|
||||||
|
try:
|
||||||
|
# any request will do
|
||||||
|
await client_session.initialize()
|
||||||
|
pytest.fail("Request should have errored")
|
||||||
|
except McpError as e:
|
||||||
|
# Expected - request errored
|
||||||
|
assert "Connection closed" in str(e)
|
||||||
|
ev_response.set()
|
||||||
|
|
||||||
|
async def mock_server():
|
||||||
|
"""Wait for a request, then close the connection"""
|
||||||
|
nonlocal ev_closed
|
||||||
|
# Wait for a request
|
||||||
|
await server_read.receive()
|
||||||
|
# Close the connection, as if the server exited
|
||||||
|
server_write.close()
|
||||||
|
server_read.close()
|
||||||
|
ev_closed.set()
|
||||||
|
|
||||||
|
async with (
|
||||||
|
anyio.create_task_group() as tg,
|
||||||
|
ClientSession(
|
||||||
|
read_stream=client_read,
|
||||||
|
write_stream=client_write,
|
||||||
|
) as client_session,
|
||||||
|
):
|
||||||
|
tg.start_soon(make_request, client_session)
|
||||||
|
tg.start_soon(mock_server)
|
||||||
|
|
||||||
|
with anyio.fail_after(1):
|
||||||
|
await ev_closed.wait()
|
||||||
|
with anyio.fail_after(1):
|
||||||
|
await ev_response.wait()
|
||||||
|
|||||||
Reference in New Issue
Block a user