Create ClientSessionGroup for managing multiple session connections. (#639)

This commit is contained in:
Mo
2025-05-13 04:58:24 -04:00
committed by GitHub
parent fdb538bc28
commit 7b6a903eb9
3 changed files with 771 additions and 0 deletions

View File

@@ -1,4 +1,5 @@
from .client.session import ClientSession
from .client.session_group import ClientSessionGroup
from .client.stdio import StdioServerParameters, stdio_client
from .server.session import ServerSession
from .server.stdio import stdio_server
@@ -63,6 +64,7 @@ __all__ = [
"ClientRequest",
"ClientResult",
"ClientSession",
"ClientSessionGroup",
"CreateMessageRequest",
"CreateMessageResult",
"ErrorData",

View File

@@ -0,0 +1,372 @@
"""
SessionGroup concurrently manages multiple MCP session connections.
Tools, resources, and prompts are aggregated across servers. Servers may
be connected to or disconnected from at any point after initialization.
This abstractions can handle naming collisions using a custom user-provided
hook.
"""
import contextlib
import logging
from collections.abc import Callable
from datetime import timedelta
from types import TracebackType
from typing import Any, TypeAlias
import anyio
from pydantic import BaseModel
from typing_extensions import Self
import mcp
from mcp import types
from mcp.client.sse import sse_client
from mcp.client.stdio import StdioServerParameters
from mcp.client.streamable_http import streamablehttp_client
from mcp.shared.exceptions import McpError
class SseServerParameters(BaseModel):
"""Parameters for intializing a sse_client."""
# The endpoint URL.
url: str
# Optional headers to include in requests.
headers: dict[str, Any] | None = None
# HTTP timeout for regular operations.
timeout: float = 5
# Timeout for SSE read operations.
sse_read_timeout: float = 60 * 5
class StreamableHttpParameters(BaseModel):
"""Parameters for intializing a streamablehttp_client."""
# The endpoint URL.
url: str
# Optional headers to include in requests.
headers: dict[str, Any] | None = None
# HTTP timeout for regular operations.
timeout: timedelta = timedelta(seconds=30)
# Timeout for SSE read operations.
sse_read_timeout: timedelta = timedelta(seconds=60 * 5)
# Close the client session when the transport closes.
terminate_on_close: bool = True
ServerParameters: TypeAlias = (
StdioServerParameters | SseServerParameters | StreamableHttpParameters
)
class ClientSessionGroup:
"""Client for managing connections to multiple MCP servers.
This class is responsible for encapsulating management of server connections.
It aggregates tools, resources, and prompts from all connected servers.
For auxiliary handlers, such as resource subscription, this is delegated to
the client and can be accessed via the session.
Example Usage:
name_fn = lambda name, server_info: f"{(server_info.name)}-{name}"
async with ClientSessionGroup(component_name_hook=name_fn) as group:
for server_params in server_params:
group.connect_to_server(server_param)
...
"""
class _ComponentNames(BaseModel):
"""Used for reverse index to find components."""
prompts: set[str] = set()
resources: set[str] = set()
tools: set[str] = set()
# Standard MCP components.
_prompts: dict[str, types.Prompt]
_resources: dict[str, types.Resource]
_tools: dict[str, types.Tool]
# Client-server connection management.
_sessions: dict[mcp.ClientSession, _ComponentNames]
_tool_to_session: dict[str, mcp.ClientSession]
_exit_stack: contextlib.AsyncExitStack
_session_exit_stacks: dict[mcp.ClientSession, contextlib.AsyncExitStack]
# Optional fn consuming (component_name, serverInfo) for custom names.
# This is provide a means to mitigate naming conflicts across servers.
# Example: (tool_name, serverInfo) => "{result.serverInfo.name}.{tool_name}"
_ComponentNameHook: TypeAlias = Callable[[str, types.Implementation], str]
_component_name_hook: _ComponentNameHook | None
def __init__(
self,
exit_stack: contextlib.AsyncExitStack | None = None,
component_name_hook: _ComponentNameHook | None = None,
) -> None:
"""Initializes the MCP client."""
self._tools = {}
self._resources = {}
self._prompts = {}
self._sessions = {}
self._tool_to_session = {}
if exit_stack is None:
self._exit_stack = contextlib.AsyncExitStack()
self._owns_exit_stack = True
else:
self._exit_stack = exit_stack
self._owns_exit_stack = False
self._session_exit_stacks = {}
self._component_name_hook = component_name_hook
async def __aenter__(self) -> Self:
# Enter the exit stack only if we created it ourselves
if self._owns_exit_stack:
await self._exit_stack.__aenter__()
return self
async def __aexit__(
self,
_exc_type: type[BaseException] | None,
_exc_val: BaseException | None,
_exc_tb: TracebackType | None,
) -> bool | None:
"""Closes session exit stacks and main exit stack upon completion."""
# Concurrently close session stacks.
async with anyio.create_task_group() as tg:
for exit_stack in self._session_exit_stacks.values():
tg.start_soon(exit_stack.aclose)
# Only close the main exit stack if we created it
if self._owns_exit_stack:
await self._exit_stack.aclose()
@property
def sessions(self) -> list[mcp.ClientSession]:
"""Returns the list of sessions being managed."""
return list(self._sessions.keys())
@property
def prompts(self) -> dict[str, types.Prompt]:
"""Returns the prompts as a dictionary of names to prompts."""
return self._prompts
@property
def resources(self) -> dict[str, types.Resource]:
"""Returns the resources as a dictionary of names to resources."""
return self._resources
@property
def tools(self) -> dict[str, types.Tool]:
"""Returns the tools as a dictionary of names to tools."""
return self._tools
async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResult:
"""Executes a tool given its name and arguments."""
session = self._tool_to_session[name]
session_tool_name = self.tools[name].name
return await session.call_tool(session_tool_name, args)
async def disconnect_from_server(self, session: mcp.ClientSession) -> None:
"""Disconnects from a single MCP server."""
session_known_for_components = session in self._sessions
session_known_for_stack = session in self._session_exit_stacks
if not session_known_for_components and not session_known_for_stack:
raise McpError(
types.ErrorData(
code=types.INVALID_PARAMS,
message="Provided session is not managed or already disconnected.",
)
)
if session_known_for_components:
component_names = self._sessions.pop(session) # Pop from _sessions tracking
# Remove prompts associated with the session.
for name in component_names.prompts:
if name in self._prompts:
del self._prompts[name]
# Remove resources associated with the session.
for name in component_names.resources:
if name in self._resources:
del self._resources[name]
# Remove tools associated with the session.
for name in component_names.tools:
if name in self._tools:
del self._tools[name]
if name in self._tool_to_session:
del self._tool_to_session[name]
# Clean up the session's resources via its dedicated exit stack
if session_known_for_stack:
session_stack_to_close = self._session_exit_stacks.pop(session)
await session_stack_to_close.aclose()
async def connect_with_session(
self, server_info: types.Implementation, session: mcp.ClientSession
) -> mcp.ClientSession:
"""Connects to a single MCP server."""
await self._aggregate_components(server_info, session)
return session
async def connect_to_server(
self,
server_params: ServerParameters,
) -> mcp.ClientSession:
"""Connects to a single MCP server."""
server_info, session = await self._establish_session(server_params)
return await self.connect_with_session(server_info, session)
async def _establish_session(
self, server_params: ServerParameters
) -> tuple[types.Implementation, mcp.ClientSession]:
"""Establish a client session to an MCP server."""
session_stack = contextlib.AsyncExitStack()
try:
# Create read and write streams that facilitate io with the server.
if isinstance(server_params, StdioServerParameters):
client = mcp.stdio_client(server_params)
read, write = await session_stack.enter_async_context(client)
elif isinstance(server_params, SseServerParameters):
client = sse_client(
url=server_params.url,
headers=server_params.headers,
timeout=server_params.timeout,
sse_read_timeout=server_params.sse_read_timeout,
)
read, write = await session_stack.enter_async_context(client)
else:
client = streamablehttp_client(
url=server_params.url,
headers=server_params.headers,
timeout=server_params.timeout,
sse_read_timeout=server_params.sse_read_timeout,
terminate_on_close=server_params.terminate_on_close,
)
read, write, _ = await session_stack.enter_async_context(client)
session = await session_stack.enter_async_context(
mcp.ClientSession(read, write)
)
result = await session.initialize()
# Session successfully initialized.
# Store its stack and register the stack with the main group stack.
self._session_exit_stacks[session] = session_stack
# session_stack itself becomes a resource managed by the
# main _exit_stack.
await self._exit_stack.enter_async_context(session_stack)
return result.serverInfo, session
except Exception:
# If anything during this setup fails, ensure the session-specific
# stack is closed.
await session_stack.aclose()
raise
async def _aggregate_components(
self, server_info: types.Implementation, session: mcp.ClientSession
) -> None:
"""Aggregates prompts, resources, and tools from a given session."""
# Create a reverse index so we can find all prompts, resources, and
# tools belonging to this session. Used for removing components from
# the session group via self.disconnect_from_server.
component_names = self._ComponentNames()
# Temporary components dicts. We do not want to modify the aggregate
# lists in case of an intermediate failure.
prompts_temp: dict[str, types.Prompt] = {}
resources_temp: dict[str, types.Resource] = {}
tools_temp: dict[str, types.Tool] = {}
tool_to_session_temp: dict[str, mcp.ClientSession] = {}
# Query the server for its prompts and aggregate to list.
try:
prompts = (await session.list_prompts()).prompts
for prompt in prompts:
name = self._component_name(prompt.name, server_info)
prompts_temp[name] = prompt
component_names.prompts.add(name)
except McpError as err:
logging.warning(f"Could not fetch prompts: {err}")
# Query the server for its resources and aggregate to list.
try:
resources = (await session.list_resources()).resources
for resource in resources:
name = self._component_name(resource.name, server_info)
resources_temp[name] = resource
component_names.resources.add(name)
except McpError as err:
logging.warning(f"Could not fetch resources: {err}")
# Query the server for its tools and aggregate to list.
try:
tools = (await session.list_tools()).tools
for tool in tools:
name = self._component_name(tool.name, server_info)
tools_temp[name] = tool
tool_to_session_temp[name] = session
component_names.tools.add(name)
except McpError as err:
logging.warning(f"Could not fetch tools: {err}")
# Clean up exit stack for session if we couldn't retrieve anything
# from the server.
if not any((prompts_temp, resources_temp, tools_temp)):
del self._session_exit_stacks[session]
# Check for duplicates.
matching_prompts = prompts_temp.keys() & self._prompts.keys()
if matching_prompts:
raise McpError(
types.ErrorData(
code=types.INVALID_PARAMS,
message=f"{matching_prompts} already exist in group prompts.",
)
)
matching_resources = resources_temp.keys() & self._resources.keys()
if matching_resources:
raise McpError(
types.ErrorData(
code=types.INVALID_PARAMS,
message=f"{matching_resources} already exist in group resources.",
)
)
matching_tools = tools_temp.keys() & self._tools.keys()
if matching_tools:
raise McpError(
types.ErrorData(
code=types.INVALID_PARAMS,
message=f"{matching_tools} already exist in group tools.",
)
)
# Aggregate components.
self._sessions[session] = component_names
self._prompts.update(prompts_temp)
self._resources.update(resources_temp)
self._tools.update(tools_temp)
self._tool_to_session.update(tool_to_session_temp)
def _component_name(self, name: str, server_info: types.Implementation) -> str:
if self._component_name_hook:
return self._component_name_hook(name, server_info)
return name

View File

@@ -0,0 +1,397 @@
import contextlib
from unittest import mock
import pytest
import mcp
from mcp import types
from mcp.client.session_group import (
ClientSessionGroup,
SseServerParameters,
StreamableHttpParameters,
)
from mcp.client.stdio import StdioServerParameters
from mcp.shared.exceptions import McpError
@pytest.fixture
def mock_exit_stack():
"""Fixture for a mocked AsyncExitStack."""
# Use unittest.mock.Mock directly if needed, or just a plain object
# if only attribute access/existence is needed.
# For AsyncExitStack, Mock or MagicMock is usually fine.
return mock.MagicMock(spec=contextlib.AsyncExitStack)
@pytest.mark.anyio
class TestClientSessionGroup:
def test_init(self):
mcp_session_group = ClientSessionGroup()
assert not mcp_session_group._tools
assert not mcp_session_group._resources
assert not mcp_session_group._prompts
assert not mcp_session_group._tool_to_session
def test_component_properties(self):
# --- Mock Dependencies ---
mock_prompt = mock.Mock()
mock_resource = mock.Mock()
mock_tool = mock.Mock()
# --- Prepare Session Group ---
mcp_session_group = ClientSessionGroup()
mcp_session_group._prompts = {"my_prompt": mock_prompt}
mcp_session_group._resources = {"my_resource": mock_resource}
mcp_session_group._tools = {"my_tool": mock_tool}
# --- Assertions ---
assert mcp_session_group.prompts == {"my_prompt": mock_prompt}
assert mcp_session_group.resources == {"my_resource": mock_resource}
assert mcp_session_group.tools == {"my_tool": mock_tool}
async def test_call_tool(self):
# --- Mock Dependencies ---
mock_session = mock.AsyncMock()
# --- Prepare Session Group ---
def hook(name, server_info):
return f"{(server_info.name)}-{name}"
mcp_session_group = ClientSessionGroup(component_name_hook=hook)
mcp_session_group._tools = {
"server1-my_tool": types.Tool(name="my_tool", inputSchema={})
}
mcp_session_group._tool_to_session = {"server1-my_tool": mock_session}
text_content = types.TextContent(type="text", text="OK")
mock_session.call_tool.return_value = types.CallToolResult(
content=[text_content]
)
# --- Test Execution ---
result = await mcp_session_group.call_tool(
name="server1-my_tool",
args={
"name": "value1",
"args": {},
},
)
# --- Assertions ---
assert result.content == [text_content]
mock_session.call_tool.assert_called_once_with(
"my_tool",
{"name": "value1", "args": {}},
)
async def test_connect_to_server(self, mock_exit_stack):
"""Test connecting to a server and aggregating components."""
# --- Mock Dependencies ---
mock_server_info = mock.Mock(spec=types.Implementation)
mock_server_info.name = "TestServer1"
mock_session = mock.AsyncMock(spec=mcp.ClientSession)
mock_tool1 = mock.Mock(spec=types.Tool)
mock_tool1.name = "tool_a"
mock_resource1 = mock.Mock(spec=types.Resource)
mock_resource1.name = "resource_b"
mock_prompt1 = mock.Mock(spec=types.Prompt)
mock_prompt1.name = "prompt_c"
mock_session.list_tools.return_value = mock.AsyncMock(tools=[mock_tool1])
mock_session.list_resources.return_value = mock.AsyncMock(
resources=[mock_resource1]
)
mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[mock_prompt1])
# --- Test Execution ---
group = ClientSessionGroup(exit_stack=mock_exit_stack)
with mock.patch.object(
group, "_establish_session", return_value=(mock_server_info, mock_session)
):
await group.connect_to_server(StdioServerParameters(command="test"))
# --- Assertions ---
assert mock_session in group._sessions
assert len(group.tools) == 1
assert "tool_a" in group.tools
assert group.tools["tool_a"] == mock_tool1
assert group._tool_to_session["tool_a"] == mock_session
assert len(group.resources) == 1
assert "resource_b" in group.resources
assert group.resources["resource_b"] == mock_resource1
assert len(group.prompts) == 1
assert "prompt_c" in group.prompts
assert group.prompts["prompt_c"] == mock_prompt1
mock_session.list_tools.assert_awaited_once()
mock_session.list_resources.assert_awaited_once()
mock_session.list_prompts.assert_awaited_once()
async def test_connect_to_server_with_name_hook(self, mock_exit_stack):
"""Test connecting with a component name hook."""
# --- Mock Dependencies ---
mock_server_info = mock.Mock(spec=types.Implementation)
mock_server_info.name = "HookServer"
mock_session = mock.AsyncMock(spec=mcp.ClientSession)
mock_tool = mock.Mock(spec=types.Tool)
mock_tool.name = "base_tool"
mock_session.list_tools.return_value = mock.AsyncMock(tools=[mock_tool])
mock_session.list_resources.return_value = mock.AsyncMock(resources=[])
mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[])
# --- Test Setup ---
def name_hook(name: str, server_info: types.Implementation) -> str:
return f"{server_info.name}.{name}"
# --- Test Execution ---
group = ClientSessionGroup(
exit_stack=mock_exit_stack, component_name_hook=name_hook
)
with mock.patch.object(
group, "_establish_session", return_value=(mock_server_info, mock_session)
):
await group.connect_to_server(StdioServerParameters(command="test"))
# --- Assertions ---
assert mock_session in group._sessions
assert len(group.tools) == 1
expected_tool_name = "HookServer.base_tool"
assert expected_tool_name in group.tools
assert group.tools[expected_tool_name] == mock_tool
assert group._tool_to_session[expected_tool_name] == mock_session
async def test_disconnect_from_server(self): # No mock arguments needed
"""Test disconnecting from a server."""
# --- Test Setup ---
group = ClientSessionGroup()
server_name = "ServerToDisconnect"
# Manually populate state using standard mocks
mock_session1 = mock.MagicMock(spec=mcp.ClientSession)
mock_session2 = mock.MagicMock(spec=mcp.ClientSession)
mock_tool1 = mock.Mock(spec=types.Tool)
mock_tool1.name = "tool1"
mock_resource1 = mock.Mock(spec=types.Resource)
mock_resource1.name = "res1"
mock_prompt1 = mock.Mock(spec=types.Prompt)
mock_prompt1.name = "prm1"
mock_tool2 = mock.Mock(spec=types.Tool)
mock_tool2.name = "tool2"
mock_component_named_like_server = mock.Mock()
mock_session = mock.Mock(spec=mcp.ClientSession)
group._tools = {
"tool1": mock_tool1,
"tool2": mock_tool2,
server_name: mock_component_named_like_server,
}
group._tool_to_session = {
"tool1": mock_session1,
"tool2": mock_session2,
server_name: mock_session1,
}
group._resources = {
"res1": mock_resource1,
server_name: mock_component_named_like_server,
}
group._prompts = {
"prm1": mock_prompt1,
server_name: mock_component_named_like_server,
}
group._sessions = {
mock_session: ClientSessionGroup._ComponentNames(
prompts=set({"prm1"}),
resources=set({"res1"}),
tools=set({"tool1", "tool2"}),
)
}
# --- Assertions ---
assert mock_session in group._sessions
assert "tool1" in group._tools
assert "tool2" in group._tools
assert "res1" in group._resources
assert "prm1" in group._prompts
# --- Test Execution ---
await group.disconnect_from_server(mock_session)
# --- Assertions ---
assert mock_session not in group._sessions
assert "tool1" not in group._tools
assert "tool2" not in group._tools
assert "res1" not in group._resources
assert "prm1" not in group._prompts
async def test_connect_to_server_duplicate_tool_raises_error(self, mock_exit_stack):
"""Test McpError raised when connecting a server with a dup name."""
# --- Setup Pre-existing State ---
group = ClientSessionGroup(exit_stack=mock_exit_stack)
existing_tool_name = "shared_tool"
# Manually add a tool to simulate a previous connection
group._tools[existing_tool_name] = mock.Mock(spec=types.Tool)
group._tools[existing_tool_name].name = existing_tool_name
# Need a dummy session associated with the existing tool
mock_session = mock.MagicMock(spec=mcp.ClientSession)
group._tool_to_session[existing_tool_name] = mock_session
group._session_exit_stacks[mock_session] = mock.Mock(
spec=contextlib.AsyncExitStack
)
# --- Mock New Connection Attempt ---
mock_server_info_new = mock.Mock(spec=types.Implementation)
mock_server_info_new.name = "ServerWithDuplicate"
mock_session_new = mock.AsyncMock(spec=mcp.ClientSession)
# Configure the new session to return a tool with the *same name*
duplicate_tool = mock.Mock(spec=types.Tool)
duplicate_tool.name = existing_tool_name
mock_session_new.list_tools.return_value = mock.AsyncMock(
tools=[duplicate_tool]
)
# Keep other lists empty for simplicity
mock_session_new.list_resources.return_value = mock.AsyncMock(resources=[])
mock_session_new.list_prompts.return_value = mock.AsyncMock(prompts=[])
# --- Test Execution and Assertion ---
with pytest.raises(McpError) as excinfo:
with mock.patch.object(
group,
"_establish_session",
return_value=(mock_server_info_new, mock_session_new),
):
await group.connect_to_server(StdioServerParameters(command="test"))
# Assert details about the raised error
assert excinfo.value.error.code == types.INVALID_PARAMS
assert existing_tool_name in excinfo.value.error.message
assert "already exist " in excinfo.value.error.message
# Verify the duplicate tool was *not* added again (state should be unchanged)
assert len(group._tools) == 1 # Should still only have the original
assert (
group._tools[existing_tool_name] is not duplicate_tool
) # Ensure it's the original mock
# No patching needed here
async def test_disconnect_non_existent_server(self):
"""Test disconnecting a server that isn't connected."""
session = mock.Mock(spec=mcp.ClientSession)
group = ClientSessionGroup()
with pytest.raises(McpError):
await group.disconnect_from_server(session)
@pytest.mark.parametrize(
"server_params_instance, client_type_name, patch_target_for_client_func",
[
(
StdioServerParameters(command="test_stdio_cmd"),
"stdio",
"mcp.client.session_group.mcp.stdio_client",
),
(
SseServerParameters(url="http://test.com/sse", timeout=10),
"sse",
"mcp.client.session_group.sse_client",
), # url, headers, timeout, sse_read_timeout
(
StreamableHttpParameters(
url="http://test.com/stream", terminate_on_close=False
),
"streamablehttp",
"mcp.client.session_group.streamablehttp_client",
), # url, headers, timeout, sse_read_timeout, terminate_on_close
],
)
async def test_establish_session_parameterized(
self,
server_params_instance,
client_type_name, # Just for clarity or conditional logic if needed
patch_target_for_client_func,
):
with mock.patch(
"mcp.client.session_group.mcp.ClientSession"
) as mock_ClientSession_class:
with mock.patch(patch_target_for_client_func) as mock_specific_client_func:
mock_client_cm_instance = mock.AsyncMock(
name=f"{client_type_name}ClientCM"
)
mock_read_stream = mock.AsyncMock(name=f"{client_type_name}Read")
mock_write_stream = mock.AsyncMock(name=f"{client_type_name}Write")
# streamablehttp_client's __aenter__ returns three values
if client_type_name == "streamablehttp":
mock_extra_stream_val = mock.AsyncMock(name="StreamableExtra")
mock_client_cm_instance.__aenter__.return_value = (
mock_read_stream,
mock_write_stream,
mock_extra_stream_val,
)
else:
mock_client_cm_instance.__aenter__.return_value = (
mock_read_stream,
mock_write_stream,
)
mock_client_cm_instance.__aexit__ = mock.AsyncMock(return_value=None)
mock_specific_client_func.return_value = mock_client_cm_instance
# --- Mock mcp.ClientSession (class) ---
# mock_ClientSession_class is already provided by the outer patch
mock_raw_session_cm = mock.AsyncMock(name="RawSessionCM")
mock_ClientSession_class.return_value = mock_raw_session_cm
mock_entered_session = mock.AsyncMock(name="EnteredSessionInstance")
mock_raw_session_cm.__aenter__.return_value = mock_entered_session
mock_raw_session_cm.__aexit__ = mock.AsyncMock(return_value=None)
# Mock session.initialize()
mock_initialize_result = mock.AsyncMock(name="InitializeResult")
mock_initialize_result.serverInfo = types.Implementation(
name="foo", version="1"
)
mock_entered_session.initialize.return_value = mock_initialize_result
# --- Test Execution ---
group = ClientSessionGroup()
returned_server_info = None
returned_session = None
async with contextlib.AsyncExitStack() as stack:
group._exit_stack = stack
(
returned_server_info,
returned_session,
) = await group._establish_session(server_params_instance)
# --- Assertions ---
# 1. Assert the correct specific client function was called
if client_type_name == "stdio":
mock_specific_client_func.assert_called_once_with(
server_params_instance
)
elif client_type_name == "sse":
mock_specific_client_func.assert_called_once_with(
url=server_params_instance.url,
headers=server_params_instance.headers,
timeout=server_params_instance.timeout,
sse_read_timeout=server_params_instance.sse_read_timeout,
)
elif client_type_name == "streamablehttp":
mock_specific_client_func.assert_called_once_with(
url=server_params_instance.url,
headers=server_params_instance.headers,
timeout=server_params_instance.timeout,
sse_read_timeout=server_params_instance.sse_read_timeout,
terminate_on_close=server_params_instance.terminate_on_close,
)
mock_client_cm_instance.__aenter__.assert_awaited_once()
# 2. Assert ClientSession was called correctly
mock_ClientSession_class.assert_called_once_with(
mock_read_stream, mock_write_stream
)
mock_raw_session_cm.__aenter__.assert_awaited_once()
mock_entered_session.initialize.assert_awaited_once()
# 3. Assert returned values
assert returned_server_info is mock_initialize_result.serverInfo
assert returned_session is mock_entered_session