mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-18 22:44:20 +01:00
Properly clean up response streams in BaseSession (#515)
This commit is contained in:
@@ -187,7 +187,6 @@ class BaseSession(
|
||||
self._receive_notification_type = receive_notification_type
|
||||
self._session_read_timeout_seconds = read_timeout_seconds
|
||||
self._in_flight = {}
|
||||
|
||||
self._exit_stack = AsyncExitStack()
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
@@ -232,45 +231,48 @@ class BaseSession(
|
||||
](1)
|
||||
self._response_streams[request_id] = response_stream
|
||||
|
||||
self._exit_stack.push_async_callback(lambda: response_stream.aclose())
|
||||
self._exit_stack.push_async_callback(lambda: response_stream_reader.aclose())
|
||||
|
||||
jsonrpc_request = JSONRPCRequest(
|
||||
jsonrpc="2.0",
|
||||
id=request_id,
|
||||
**request.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
|
||||
# TODO: Support progress callbacks
|
||||
|
||||
await self._write_stream.send(JSONRPCMessage(jsonrpc_request))
|
||||
|
||||
# request read timeout takes precedence over session read timeout
|
||||
timeout = None
|
||||
if request_read_timeout_seconds is not None:
|
||||
timeout = request_read_timeout_seconds.total_seconds()
|
||||
elif self._session_read_timeout_seconds is not None:
|
||||
timeout = self._session_read_timeout_seconds.total_seconds()
|
||||
|
||||
try:
|
||||
with anyio.fail_after(timeout):
|
||||
response_or_error = await response_stream_reader.receive()
|
||||
except TimeoutError:
|
||||
raise McpError(
|
||||
ErrorData(
|
||||
code=httpx.codes.REQUEST_TIMEOUT,
|
||||
message=(
|
||||
f"Timed out while waiting for response to "
|
||||
f"{request.__class__.__name__}. Waited "
|
||||
f"{timeout} seconds."
|
||||
),
|
||||
)
|
||||
jsonrpc_request = JSONRPCRequest(
|
||||
jsonrpc="2.0",
|
||||
id=request_id,
|
||||
**request.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
|
||||
if isinstance(response_or_error, JSONRPCError):
|
||||
raise McpError(response_or_error.error)
|
||||
else:
|
||||
return result_type.model_validate(response_or_error.result)
|
||||
# TODO: Support progress callbacks
|
||||
|
||||
await self._write_stream.send(JSONRPCMessage(jsonrpc_request))
|
||||
|
||||
# request read timeout takes precedence over session read timeout
|
||||
timeout = None
|
||||
if request_read_timeout_seconds is not None:
|
||||
timeout = request_read_timeout_seconds.total_seconds()
|
||||
elif self._session_read_timeout_seconds is not None:
|
||||
timeout = self._session_read_timeout_seconds.total_seconds()
|
||||
|
||||
try:
|
||||
with anyio.fail_after(timeout):
|
||||
response_or_error = await response_stream_reader.receive()
|
||||
except TimeoutError:
|
||||
raise McpError(
|
||||
ErrorData(
|
||||
code=httpx.codes.REQUEST_TIMEOUT,
|
||||
message=(
|
||||
f"Timed out while waiting for response to "
|
||||
f"{request.__class__.__name__}. Waited "
|
||||
f"{timeout} seconds."
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(response_or_error, JSONRPCError):
|
||||
raise McpError(response_or_error.error)
|
||||
else:
|
||||
return result_type.model_validate(response_or_error.result)
|
||||
|
||||
finally:
|
||||
self._response_streams.pop(request_id, None)
|
||||
await response_stream.aclose()
|
||||
await response_stream_reader.aclose()
|
||||
|
||||
async def send_notification(self, notification: SendNotificationT) -> None:
|
||||
"""
|
||||
|
||||
68
tests/client/test_resource_cleanup.py
Normal file
68
tests/client/test_resource_cleanup.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import anyio
|
||||
import pytest
|
||||
|
||||
from mcp.shared.session import BaseSession
|
||||
from mcp.types import (
|
||||
ClientRequest,
|
||||
EmptyResult,
|
||||
PingRequest,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_send_request_stream_cleanup():
|
||||
"""
|
||||
Test that send_request properly cleans up streams when an exception occurs.
|
||||
|
||||
This test mocks out most of the session functionality to focus on stream cleanup.
|
||||
"""
|
||||
|
||||
# Create a mock session with the minimal required functionality
|
||||
class TestSession(BaseSession):
|
||||
async def _send_response(self, request_id, response):
|
||||
pass
|
||||
|
||||
# Create streams
|
||||
write_stream_send, write_stream_receive = anyio.create_memory_object_stream(1)
|
||||
read_stream_send, read_stream_receive = anyio.create_memory_object_stream(1)
|
||||
|
||||
# Create the session
|
||||
session = TestSession(
|
||||
read_stream_receive,
|
||||
write_stream_send,
|
||||
object, # Request type doesn't matter for this test
|
||||
object, # Notification type doesn't matter for this test
|
||||
)
|
||||
|
||||
# Create a test request
|
||||
request = ClientRequest(
|
||||
PingRequest(
|
||||
method="ping",
|
||||
)
|
||||
)
|
||||
|
||||
# Patch the _write_stream.send method to raise an exception
|
||||
async def mock_send(*args, **kwargs):
|
||||
raise RuntimeError("Simulated network error")
|
||||
|
||||
# Record the response streams before the test
|
||||
initial_stream_count = len(session._response_streams)
|
||||
|
||||
# Run the test with the patched method
|
||||
with patch.object(session._write_stream, "send", mock_send):
|
||||
with pytest.raises(RuntimeError):
|
||||
await session.send_request(request, EmptyResult)
|
||||
|
||||
# Verify that no response streams were leaked
|
||||
assert len(session._response_streams) == initial_stream_count, (
|
||||
f"Expected {initial_stream_count} response streams after request, "
|
||||
f"but found {len(session._response_streams)}"
|
||||
)
|
||||
|
||||
# Clean up
|
||||
await write_stream_send.aclose()
|
||||
await write_stream_receive.aclose()
|
||||
await read_stream_send.aclose()
|
||||
await read_stream_receive.aclose()
|
||||
Reference in New Issue
Block a user