Properly clean up response streams in BaseSession (#515)

This commit is contained in:
bhosmer-ant
2025-05-01 09:45:47 -04:00
committed by GitHub
parent 1a330ac672
commit 82bd8bc1d9
2 changed files with 107 additions and 37 deletions

View File

@@ -187,7 +187,6 @@ class BaseSession(
self._receive_notification_type = receive_notification_type
self._session_read_timeout_seconds = read_timeout_seconds
self._in_flight = {}
self._exit_stack = AsyncExitStack()
async def __aenter__(self) -> Self:
@@ -232,45 +231,48 @@ class BaseSession(
](1)
self._response_streams[request_id] = response_stream
self._exit_stack.push_async_callback(lambda: response_stream.aclose())
self._exit_stack.push_async_callback(lambda: response_stream_reader.aclose())
jsonrpc_request = JSONRPCRequest(
jsonrpc="2.0",
id=request_id,
**request.model_dump(by_alias=True, mode="json", exclude_none=True),
)
# TODO: Support progress callbacks
await self._write_stream.send(JSONRPCMessage(jsonrpc_request))
# request read timeout takes precedence over session read timeout
timeout = None
if request_read_timeout_seconds is not None:
timeout = request_read_timeout_seconds.total_seconds()
elif self._session_read_timeout_seconds is not None:
timeout = self._session_read_timeout_seconds.total_seconds()
try:
with anyio.fail_after(timeout):
response_or_error = await response_stream_reader.receive()
except TimeoutError:
raise McpError(
ErrorData(
code=httpx.codes.REQUEST_TIMEOUT,
message=(
f"Timed out while waiting for response to "
f"{request.__class__.__name__}. Waited "
f"{timeout} seconds."
),
)
jsonrpc_request = JSONRPCRequest(
jsonrpc="2.0",
id=request_id,
**request.model_dump(by_alias=True, mode="json", exclude_none=True),
)
if isinstance(response_or_error, JSONRPCError):
raise McpError(response_or_error.error)
else:
return result_type.model_validate(response_or_error.result)
# TODO: Support progress callbacks
await self._write_stream.send(JSONRPCMessage(jsonrpc_request))
# request read timeout takes precedence over session read timeout
timeout = None
if request_read_timeout_seconds is not None:
timeout = request_read_timeout_seconds.total_seconds()
elif self._session_read_timeout_seconds is not None:
timeout = self._session_read_timeout_seconds.total_seconds()
try:
with anyio.fail_after(timeout):
response_or_error = await response_stream_reader.receive()
except TimeoutError:
raise McpError(
ErrorData(
code=httpx.codes.REQUEST_TIMEOUT,
message=(
f"Timed out while waiting for response to "
f"{request.__class__.__name__}. Waited "
f"{timeout} seconds."
),
)
)
if isinstance(response_or_error, JSONRPCError):
raise McpError(response_or_error.error)
else:
return result_type.model_validate(response_or_error.result)
finally:
self._response_streams.pop(request_id, None)
await response_stream.aclose()
await response_stream_reader.aclose()
async def send_notification(self, notification: SendNotificationT) -> None:
"""

View File

@@ -0,0 +1,68 @@
from unittest.mock import patch
import anyio
import pytest
from mcp.shared.session import BaseSession
from mcp.types import (
ClientRequest,
EmptyResult,
PingRequest,
)
@pytest.mark.anyio
async def test_send_request_stream_cleanup():
"""
Test that send_request properly cleans up streams when an exception occurs.
This test mocks out most of the session functionality to focus on stream cleanup.
"""
# Create a mock session with the minimal required functionality
class TestSession(BaseSession):
async def _send_response(self, request_id, response):
pass
# Create streams
write_stream_send, write_stream_receive = anyio.create_memory_object_stream(1)
read_stream_send, read_stream_receive = anyio.create_memory_object_stream(1)
# Create the session
session = TestSession(
read_stream_receive,
write_stream_send,
object, # Request type doesn't matter for this test
object, # Notification type doesn't matter for this test
)
# Create a test request
request = ClientRequest(
PingRequest(
method="ping",
)
)
# Patch the _write_stream.send method to raise an exception
async def mock_send(*args, **kwargs):
raise RuntimeError("Simulated network error")
# Record the response streams before the test
initial_stream_count = len(session._response_streams)
# Run the test with the patched method
with patch.object(session._write_stream, "send", mock_send):
with pytest.raises(RuntimeError):
await session.send_request(request, EmptyResult)
# Verify that no response streams were leaked
assert len(session._response_streams) == initial_stream_count, (
f"Expected {initial_stream_count} response streams after request, "
f"but found {len(session._response_streams)}"
)
# Clean up
await write_stream_send.aclose()
await write_stream_receive.aclose()
await read_stream_send.aclose()
await read_stream_receive.aclose()