mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-18 22:44:20 +01:00
Create ClientSessionGroup for managing multiple session connections. (#639)
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from .client.session import ClientSession
|
||||
from .client.session_group import ClientSessionGroup
|
||||
from .client.stdio import StdioServerParameters, stdio_client
|
||||
from .server.session import ServerSession
|
||||
from .server.stdio import stdio_server
|
||||
@@ -63,6 +64,7 @@ __all__ = [
|
||||
"ClientRequest",
|
||||
"ClientResult",
|
||||
"ClientSession",
|
||||
"ClientSessionGroup",
|
||||
"CreateMessageRequest",
|
||||
"CreateMessageResult",
|
||||
"ErrorData",
|
||||
|
||||
372
src/mcp/client/session_group.py
Normal file
372
src/mcp/client/session_group.py
Normal file
@@ -0,0 +1,372 @@
|
||||
"""
|
||||
SessionGroup concurrently manages multiple MCP session connections.
|
||||
|
||||
Tools, resources, and prompts are aggregated across servers. Servers may
|
||||
be connected to or disconnected from at any point after initialization.
|
||||
|
||||
This abstractions can handle naming collisions using a custom user-provided
|
||||
hook.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from datetime import timedelta
|
||||
from types import TracebackType
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
import anyio
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
import mcp
|
||||
from mcp import types
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import StdioServerParameters
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from mcp.shared.exceptions import McpError
|
||||
|
||||
|
||||
class SseServerParameters(BaseModel):
|
||||
"""Parameters for intializing a sse_client."""
|
||||
|
||||
# The endpoint URL.
|
||||
url: str
|
||||
|
||||
# Optional headers to include in requests.
|
||||
headers: dict[str, Any] | None = None
|
||||
|
||||
# HTTP timeout for regular operations.
|
||||
timeout: float = 5
|
||||
|
||||
# Timeout for SSE read operations.
|
||||
sse_read_timeout: float = 60 * 5
|
||||
|
||||
|
||||
class StreamableHttpParameters(BaseModel):
|
||||
"""Parameters for intializing a streamablehttp_client."""
|
||||
|
||||
# The endpoint URL.
|
||||
url: str
|
||||
|
||||
# Optional headers to include in requests.
|
||||
headers: dict[str, Any] | None = None
|
||||
|
||||
# HTTP timeout for regular operations.
|
||||
timeout: timedelta = timedelta(seconds=30)
|
||||
|
||||
# Timeout for SSE read operations.
|
||||
sse_read_timeout: timedelta = timedelta(seconds=60 * 5)
|
||||
|
||||
# Close the client session when the transport closes.
|
||||
terminate_on_close: bool = True
|
||||
|
||||
|
||||
ServerParameters: TypeAlias = (
|
||||
StdioServerParameters | SseServerParameters | StreamableHttpParameters
|
||||
)
|
||||
|
||||
|
||||
class ClientSessionGroup:
|
||||
"""Client for managing connections to multiple MCP servers.
|
||||
|
||||
This class is responsible for encapsulating management of server connections.
|
||||
It aggregates tools, resources, and prompts from all connected servers.
|
||||
|
||||
For auxiliary handlers, such as resource subscription, this is delegated to
|
||||
the client and can be accessed via the session.
|
||||
|
||||
Example Usage:
|
||||
name_fn = lambda name, server_info: f"{(server_info.name)}-{name}"
|
||||
async with ClientSessionGroup(component_name_hook=name_fn) as group:
|
||||
for server_params in server_params:
|
||||
group.connect_to_server(server_param)
|
||||
...
|
||||
|
||||
"""
|
||||
|
||||
class _ComponentNames(BaseModel):
|
||||
"""Used for reverse index to find components."""
|
||||
|
||||
prompts: set[str] = set()
|
||||
resources: set[str] = set()
|
||||
tools: set[str] = set()
|
||||
|
||||
# Standard MCP components.
|
||||
_prompts: dict[str, types.Prompt]
|
||||
_resources: dict[str, types.Resource]
|
||||
_tools: dict[str, types.Tool]
|
||||
|
||||
# Client-server connection management.
|
||||
_sessions: dict[mcp.ClientSession, _ComponentNames]
|
||||
_tool_to_session: dict[str, mcp.ClientSession]
|
||||
_exit_stack: contextlib.AsyncExitStack
|
||||
_session_exit_stacks: dict[mcp.ClientSession, contextlib.AsyncExitStack]
|
||||
|
||||
# Optional fn consuming (component_name, serverInfo) for custom names.
|
||||
# This is provide a means to mitigate naming conflicts across servers.
|
||||
# Example: (tool_name, serverInfo) => "{result.serverInfo.name}.{tool_name}"
|
||||
_ComponentNameHook: TypeAlias = Callable[[str, types.Implementation], str]
|
||||
_component_name_hook: _ComponentNameHook | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
exit_stack: contextlib.AsyncExitStack | None = None,
|
||||
component_name_hook: _ComponentNameHook | None = None,
|
||||
) -> None:
|
||||
"""Initializes the MCP client."""
|
||||
|
||||
self._tools = {}
|
||||
self._resources = {}
|
||||
self._prompts = {}
|
||||
|
||||
self._sessions = {}
|
||||
self._tool_to_session = {}
|
||||
if exit_stack is None:
|
||||
self._exit_stack = contextlib.AsyncExitStack()
|
||||
self._owns_exit_stack = True
|
||||
else:
|
||||
self._exit_stack = exit_stack
|
||||
self._owns_exit_stack = False
|
||||
self._session_exit_stacks = {}
|
||||
self._component_name_hook = component_name_hook
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
# Enter the exit stack only if we created it ourselves
|
||||
if self._owns_exit_stack:
|
||||
await self._exit_stack.__aenter__()
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
_exc_type: type[BaseException] | None,
|
||||
_exc_val: BaseException | None,
|
||||
_exc_tb: TracebackType | None,
|
||||
) -> bool | None:
|
||||
"""Closes session exit stacks and main exit stack upon completion."""
|
||||
|
||||
# Concurrently close session stacks.
|
||||
async with anyio.create_task_group() as tg:
|
||||
for exit_stack in self._session_exit_stacks.values():
|
||||
tg.start_soon(exit_stack.aclose)
|
||||
|
||||
# Only close the main exit stack if we created it
|
||||
if self._owns_exit_stack:
|
||||
await self._exit_stack.aclose()
|
||||
|
||||
@property
|
||||
def sessions(self) -> list[mcp.ClientSession]:
|
||||
"""Returns the list of sessions being managed."""
|
||||
return list(self._sessions.keys())
|
||||
|
||||
@property
|
||||
def prompts(self) -> dict[str, types.Prompt]:
|
||||
"""Returns the prompts as a dictionary of names to prompts."""
|
||||
return self._prompts
|
||||
|
||||
@property
|
||||
def resources(self) -> dict[str, types.Resource]:
|
||||
"""Returns the resources as a dictionary of names to resources."""
|
||||
return self._resources
|
||||
|
||||
@property
|
||||
def tools(self) -> dict[str, types.Tool]:
|
||||
"""Returns the tools as a dictionary of names to tools."""
|
||||
return self._tools
|
||||
|
||||
async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResult:
|
||||
"""Executes a tool given its name and arguments."""
|
||||
session = self._tool_to_session[name]
|
||||
session_tool_name = self.tools[name].name
|
||||
return await session.call_tool(session_tool_name, args)
|
||||
|
||||
async def disconnect_from_server(self, session: mcp.ClientSession) -> None:
|
||||
"""Disconnects from a single MCP server."""
|
||||
|
||||
session_known_for_components = session in self._sessions
|
||||
session_known_for_stack = session in self._session_exit_stacks
|
||||
|
||||
if not session_known_for_components and not session_known_for_stack:
|
||||
raise McpError(
|
||||
types.ErrorData(
|
||||
code=types.INVALID_PARAMS,
|
||||
message="Provided session is not managed or already disconnected.",
|
||||
)
|
||||
)
|
||||
|
||||
if session_known_for_components:
|
||||
component_names = self._sessions.pop(session) # Pop from _sessions tracking
|
||||
|
||||
# Remove prompts associated with the session.
|
||||
for name in component_names.prompts:
|
||||
if name in self._prompts:
|
||||
del self._prompts[name]
|
||||
# Remove resources associated with the session.
|
||||
for name in component_names.resources:
|
||||
if name in self._resources:
|
||||
del self._resources[name]
|
||||
# Remove tools associated with the session.
|
||||
for name in component_names.tools:
|
||||
if name in self._tools:
|
||||
del self._tools[name]
|
||||
if name in self._tool_to_session:
|
||||
del self._tool_to_session[name]
|
||||
|
||||
# Clean up the session's resources via its dedicated exit stack
|
||||
if session_known_for_stack:
|
||||
session_stack_to_close = self._session_exit_stacks.pop(session)
|
||||
await session_stack_to_close.aclose()
|
||||
|
||||
async def connect_with_session(
|
||||
self, server_info: types.Implementation, session: mcp.ClientSession
|
||||
) -> mcp.ClientSession:
|
||||
"""Connects to a single MCP server."""
|
||||
await self._aggregate_components(server_info, session)
|
||||
return session
|
||||
|
||||
async def connect_to_server(
|
||||
self,
|
||||
server_params: ServerParameters,
|
||||
) -> mcp.ClientSession:
|
||||
"""Connects to a single MCP server."""
|
||||
server_info, session = await self._establish_session(server_params)
|
||||
return await self.connect_with_session(server_info, session)
|
||||
|
||||
async def _establish_session(
|
||||
self, server_params: ServerParameters
|
||||
) -> tuple[types.Implementation, mcp.ClientSession]:
|
||||
"""Establish a client session to an MCP server."""
|
||||
|
||||
session_stack = contextlib.AsyncExitStack()
|
||||
try:
|
||||
# Create read and write streams that facilitate io with the server.
|
||||
if isinstance(server_params, StdioServerParameters):
|
||||
client = mcp.stdio_client(server_params)
|
||||
read, write = await session_stack.enter_async_context(client)
|
||||
elif isinstance(server_params, SseServerParameters):
|
||||
client = sse_client(
|
||||
url=server_params.url,
|
||||
headers=server_params.headers,
|
||||
timeout=server_params.timeout,
|
||||
sse_read_timeout=server_params.sse_read_timeout,
|
||||
)
|
||||
read, write = await session_stack.enter_async_context(client)
|
||||
else:
|
||||
client = streamablehttp_client(
|
||||
url=server_params.url,
|
||||
headers=server_params.headers,
|
||||
timeout=server_params.timeout,
|
||||
sse_read_timeout=server_params.sse_read_timeout,
|
||||
terminate_on_close=server_params.terminate_on_close,
|
||||
)
|
||||
read, write, _ = await session_stack.enter_async_context(client)
|
||||
|
||||
session = await session_stack.enter_async_context(
|
||||
mcp.ClientSession(read, write)
|
||||
)
|
||||
result = await session.initialize()
|
||||
|
||||
# Session successfully initialized.
|
||||
# Store its stack and register the stack with the main group stack.
|
||||
self._session_exit_stacks[session] = session_stack
|
||||
# session_stack itself becomes a resource managed by the
|
||||
# main _exit_stack.
|
||||
await self._exit_stack.enter_async_context(session_stack)
|
||||
|
||||
return result.serverInfo, session
|
||||
except Exception:
|
||||
# If anything during this setup fails, ensure the session-specific
|
||||
# stack is closed.
|
||||
await session_stack.aclose()
|
||||
raise
|
||||
|
||||
async def _aggregate_components(
|
||||
self, server_info: types.Implementation, session: mcp.ClientSession
|
||||
) -> None:
|
||||
"""Aggregates prompts, resources, and tools from a given session."""
|
||||
|
||||
# Create a reverse index so we can find all prompts, resources, and
|
||||
# tools belonging to this session. Used for removing components from
|
||||
# the session group via self.disconnect_from_server.
|
||||
component_names = self._ComponentNames()
|
||||
|
||||
# Temporary components dicts. We do not want to modify the aggregate
|
||||
# lists in case of an intermediate failure.
|
||||
prompts_temp: dict[str, types.Prompt] = {}
|
||||
resources_temp: dict[str, types.Resource] = {}
|
||||
tools_temp: dict[str, types.Tool] = {}
|
||||
tool_to_session_temp: dict[str, mcp.ClientSession] = {}
|
||||
|
||||
# Query the server for its prompts and aggregate to list.
|
||||
try:
|
||||
prompts = (await session.list_prompts()).prompts
|
||||
for prompt in prompts:
|
||||
name = self._component_name(prompt.name, server_info)
|
||||
prompts_temp[name] = prompt
|
||||
component_names.prompts.add(name)
|
||||
except McpError as err:
|
||||
logging.warning(f"Could not fetch prompts: {err}")
|
||||
|
||||
# Query the server for its resources and aggregate to list.
|
||||
try:
|
||||
resources = (await session.list_resources()).resources
|
||||
for resource in resources:
|
||||
name = self._component_name(resource.name, server_info)
|
||||
resources_temp[name] = resource
|
||||
component_names.resources.add(name)
|
||||
except McpError as err:
|
||||
logging.warning(f"Could not fetch resources: {err}")
|
||||
|
||||
# Query the server for its tools and aggregate to list.
|
||||
try:
|
||||
tools = (await session.list_tools()).tools
|
||||
for tool in tools:
|
||||
name = self._component_name(tool.name, server_info)
|
||||
tools_temp[name] = tool
|
||||
tool_to_session_temp[name] = session
|
||||
component_names.tools.add(name)
|
||||
except McpError as err:
|
||||
logging.warning(f"Could not fetch tools: {err}")
|
||||
|
||||
# Clean up exit stack for session if we couldn't retrieve anything
|
||||
# from the server.
|
||||
if not any((prompts_temp, resources_temp, tools_temp)):
|
||||
del self._session_exit_stacks[session]
|
||||
|
||||
# Check for duplicates.
|
||||
matching_prompts = prompts_temp.keys() & self._prompts.keys()
|
||||
if matching_prompts:
|
||||
raise McpError(
|
||||
types.ErrorData(
|
||||
code=types.INVALID_PARAMS,
|
||||
message=f"{matching_prompts} already exist in group prompts.",
|
||||
)
|
||||
)
|
||||
matching_resources = resources_temp.keys() & self._resources.keys()
|
||||
if matching_resources:
|
||||
raise McpError(
|
||||
types.ErrorData(
|
||||
code=types.INVALID_PARAMS,
|
||||
message=f"{matching_resources} already exist in group resources.",
|
||||
)
|
||||
)
|
||||
matching_tools = tools_temp.keys() & self._tools.keys()
|
||||
if matching_tools:
|
||||
raise McpError(
|
||||
types.ErrorData(
|
||||
code=types.INVALID_PARAMS,
|
||||
message=f"{matching_tools} already exist in group tools.",
|
||||
)
|
||||
)
|
||||
|
||||
# Aggregate components.
|
||||
self._sessions[session] = component_names
|
||||
self._prompts.update(prompts_temp)
|
||||
self._resources.update(resources_temp)
|
||||
self._tools.update(tools_temp)
|
||||
self._tool_to_session.update(tool_to_session_temp)
|
||||
|
||||
def _component_name(self, name: str, server_info: types.Implementation) -> str:
|
||||
if self._component_name_hook:
|
||||
return self._component_name_hook(name, server_info)
|
||||
return name
|
||||
397
tests/client/test_session_group.py
Normal file
397
tests/client/test_session_group.py
Normal file
@@ -0,0 +1,397 @@
|
||||
import contextlib
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
import mcp
|
||||
from mcp import types
|
||||
from mcp.client.session_group import (
|
||||
ClientSessionGroup,
|
||||
SseServerParameters,
|
||||
StreamableHttpParameters,
|
||||
)
|
||||
from mcp.client.stdio import StdioServerParameters
|
||||
from mcp.shared.exceptions import McpError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_exit_stack():
|
||||
"""Fixture for a mocked AsyncExitStack."""
|
||||
# Use unittest.mock.Mock directly if needed, or just a plain object
|
||||
# if only attribute access/existence is needed.
|
||||
# For AsyncExitStack, Mock or MagicMock is usually fine.
|
||||
return mock.MagicMock(spec=contextlib.AsyncExitStack)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
class TestClientSessionGroup:
|
||||
def test_init(self):
|
||||
mcp_session_group = ClientSessionGroup()
|
||||
assert not mcp_session_group._tools
|
||||
assert not mcp_session_group._resources
|
||||
assert not mcp_session_group._prompts
|
||||
assert not mcp_session_group._tool_to_session
|
||||
|
||||
def test_component_properties(self):
|
||||
# --- Mock Dependencies ---
|
||||
mock_prompt = mock.Mock()
|
||||
mock_resource = mock.Mock()
|
||||
mock_tool = mock.Mock()
|
||||
|
||||
# --- Prepare Session Group ---
|
||||
mcp_session_group = ClientSessionGroup()
|
||||
mcp_session_group._prompts = {"my_prompt": mock_prompt}
|
||||
mcp_session_group._resources = {"my_resource": mock_resource}
|
||||
mcp_session_group._tools = {"my_tool": mock_tool}
|
||||
|
||||
# --- Assertions ---
|
||||
assert mcp_session_group.prompts == {"my_prompt": mock_prompt}
|
||||
assert mcp_session_group.resources == {"my_resource": mock_resource}
|
||||
assert mcp_session_group.tools == {"my_tool": mock_tool}
|
||||
|
||||
async def test_call_tool(self):
|
||||
# --- Mock Dependencies ---
|
||||
mock_session = mock.AsyncMock()
|
||||
|
||||
# --- Prepare Session Group ---
|
||||
def hook(name, server_info):
|
||||
return f"{(server_info.name)}-{name}"
|
||||
|
||||
mcp_session_group = ClientSessionGroup(component_name_hook=hook)
|
||||
mcp_session_group._tools = {
|
||||
"server1-my_tool": types.Tool(name="my_tool", inputSchema={})
|
||||
}
|
||||
mcp_session_group._tool_to_session = {"server1-my_tool": mock_session}
|
||||
text_content = types.TextContent(type="text", text="OK")
|
||||
mock_session.call_tool.return_value = types.CallToolResult(
|
||||
content=[text_content]
|
||||
)
|
||||
|
||||
# --- Test Execution ---
|
||||
result = await mcp_session_group.call_tool(
|
||||
name="server1-my_tool",
|
||||
args={
|
||||
"name": "value1",
|
||||
"args": {},
|
||||
},
|
||||
)
|
||||
|
||||
# --- Assertions ---
|
||||
assert result.content == [text_content]
|
||||
mock_session.call_tool.assert_called_once_with(
|
||||
"my_tool",
|
||||
{"name": "value1", "args": {}},
|
||||
)
|
||||
|
||||
async def test_connect_to_server(self, mock_exit_stack):
|
||||
"""Test connecting to a server and aggregating components."""
|
||||
# --- Mock Dependencies ---
|
||||
mock_server_info = mock.Mock(spec=types.Implementation)
|
||||
mock_server_info.name = "TestServer1"
|
||||
mock_session = mock.AsyncMock(spec=mcp.ClientSession)
|
||||
mock_tool1 = mock.Mock(spec=types.Tool)
|
||||
mock_tool1.name = "tool_a"
|
||||
mock_resource1 = mock.Mock(spec=types.Resource)
|
||||
mock_resource1.name = "resource_b"
|
||||
mock_prompt1 = mock.Mock(spec=types.Prompt)
|
||||
mock_prompt1.name = "prompt_c"
|
||||
mock_session.list_tools.return_value = mock.AsyncMock(tools=[mock_tool1])
|
||||
mock_session.list_resources.return_value = mock.AsyncMock(
|
||||
resources=[mock_resource1]
|
||||
)
|
||||
mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[mock_prompt1])
|
||||
|
||||
# --- Test Execution ---
|
||||
group = ClientSessionGroup(exit_stack=mock_exit_stack)
|
||||
with mock.patch.object(
|
||||
group, "_establish_session", return_value=(mock_server_info, mock_session)
|
||||
):
|
||||
await group.connect_to_server(StdioServerParameters(command="test"))
|
||||
|
||||
# --- Assertions ---
|
||||
assert mock_session in group._sessions
|
||||
assert len(group.tools) == 1
|
||||
assert "tool_a" in group.tools
|
||||
assert group.tools["tool_a"] == mock_tool1
|
||||
assert group._tool_to_session["tool_a"] == mock_session
|
||||
assert len(group.resources) == 1
|
||||
assert "resource_b" in group.resources
|
||||
assert group.resources["resource_b"] == mock_resource1
|
||||
assert len(group.prompts) == 1
|
||||
assert "prompt_c" in group.prompts
|
||||
assert group.prompts["prompt_c"] == mock_prompt1
|
||||
mock_session.list_tools.assert_awaited_once()
|
||||
mock_session.list_resources.assert_awaited_once()
|
||||
mock_session.list_prompts.assert_awaited_once()
|
||||
|
||||
async def test_connect_to_server_with_name_hook(self, mock_exit_stack):
|
||||
"""Test connecting with a component name hook."""
|
||||
# --- Mock Dependencies ---
|
||||
mock_server_info = mock.Mock(spec=types.Implementation)
|
||||
mock_server_info.name = "HookServer"
|
||||
mock_session = mock.AsyncMock(spec=mcp.ClientSession)
|
||||
mock_tool = mock.Mock(spec=types.Tool)
|
||||
mock_tool.name = "base_tool"
|
||||
mock_session.list_tools.return_value = mock.AsyncMock(tools=[mock_tool])
|
||||
mock_session.list_resources.return_value = mock.AsyncMock(resources=[])
|
||||
mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[])
|
||||
|
||||
# --- Test Setup ---
|
||||
def name_hook(name: str, server_info: types.Implementation) -> str:
|
||||
return f"{server_info.name}.{name}"
|
||||
|
||||
# --- Test Execution ---
|
||||
group = ClientSessionGroup(
|
||||
exit_stack=mock_exit_stack, component_name_hook=name_hook
|
||||
)
|
||||
with mock.patch.object(
|
||||
group, "_establish_session", return_value=(mock_server_info, mock_session)
|
||||
):
|
||||
await group.connect_to_server(StdioServerParameters(command="test"))
|
||||
|
||||
# --- Assertions ---
|
||||
assert mock_session in group._sessions
|
||||
assert len(group.tools) == 1
|
||||
expected_tool_name = "HookServer.base_tool"
|
||||
assert expected_tool_name in group.tools
|
||||
assert group.tools[expected_tool_name] == mock_tool
|
||||
assert group._tool_to_session[expected_tool_name] == mock_session
|
||||
|
||||
async def test_disconnect_from_server(self): # No mock arguments needed
|
||||
"""Test disconnecting from a server."""
|
||||
# --- Test Setup ---
|
||||
group = ClientSessionGroup()
|
||||
server_name = "ServerToDisconnect"
|
||||
|
||||
# Manually populate state using standard mocks
|
||||
mock_session1 = mock.MagicMock(spec=mcp.ClientSession)
|
||||
mock_session2 = mock.MagicMock(spec=mcp.ClientSession)
|
||||
mock_tool1 = mock.Mock(spec=types.Tool)
|
||||
mock_tool1.name = "tool1"
|
||||
mock_resource1 = mock.Mock(spec=types.Resource)
|
||||
mock_resource1.name = "res1"
|
||||
mock_prompt1 = mock.Mock(spec=types.Prompt)
|
||||
mock_prompt1.name = "prm1"
|
||||
mock_tool2 = mock.Mock(spec=types.Tool)
|
||||
mock_tool2.name = "tool2"
|
||||
mock_component_named_like_server = mock.Mock()
|
||||
mock_session = mock.Mock(spec=mcp.ClientSession)
|
||||
|
||||
group._tools = {
|
||||
"tool1": mock_tool1,
|
||||
"tool2": mock_tool2,
|
||||
server_name: mock_component_named_like_server,
|
||||
}
|
||||
group._tool_to_session = {
|
||||
"tool1": mock_session1,
|
||||
"tool2": mock_session2,
|
||||
server_name: mock_session1,
|
||||
}
|
||||
group._resources = {
|
||||
"res1": mock_resource1,
|
||||
server_name: mock_component_named_like_server,
|
||||
}
|
||||
group._prompts = {
|
||||
"prm1": mock_prompt1,
|
||||
server_name: mock_component_named_like_server,
|
||||
}
|
||||
group._sessions = {
|
||||
mock_session: ClientSessionGroup._ComponentNames(
|
||||
prompts=set({"prm1"}),
|
||||
resources=set({"res1"}),
|
||||
tools=set({"tool1", "tool2"}),
|
||||
)
|
||||
}
|
||||
|
||||
# --- Assertions ---
|
||||
assert mock_session in group._sessions
|
||||
assert "tool1" in group._tools
|
||||
assert "tool2" in group._tools
|
||||
assert "res1" in group._resources
|
||||
assert "prm1" in group._prompts
|
||||
|
||||
# --- Test Execution ---
|
||||
await group.disconnect_from_server(mock_session)
|
||||
|
||||
# --- Assertions ---
|
||||
assert mock_session not in group._sessions
|
||||
assert "tool1" not in group._tools
|
||||
assert "tool2" not in group._tools
|
||||
assert "res1" not in group._resources
|
||||
assert "prm1" not in group._prompts
|
||||
|
||||
async def test_connect_to_server_duplicate_tool_raises_error(self, mock_exit_stack):
|
||||
"""Test McpError raised when connecting a server with a dup name."""
|
||||
# --- Setup Pre-existing State ---
|
||||
group = ClientSessionGroup(exit_stack=mock_exit_stack)
|
||||
existing_tool_name = "shared_tool"
|
||||
# Manually add a tool to simulate a previous connection
|
||||
group._tools[existing_tool_name] = mock.Mock(spec=types.Tool)
|
||||
group._tools[existing_tool_name].name = existing_tool_name
|
||||
# Need a dummy session associated with the existing tool
|
||||
mock_session = mock.MagicMock(spec=mcp.ClientSession)
|
||||
group._tool_to_session[existing_tool_name] = mock_session
|
||||
group._session_exit_stacks[mock_session] = mock.Mock(
|
||||
spec=contextlib.AsyncExitStack
|
||||
)
|
||||
|
||||
# --- Mock New Connection Attempt ---
|
||||
mock_server_info_new = mock.Mock(spec=types.Implementation)
|
||||
mock_server_info_new.name = "ServerWithDuplicate"
|
||||
mock_session_new = mock.AsyncMock(spec=mcp.ClientSession)
|
||||
|
||||
# Configure the new session to return a tool with the *same name*
|
||||
duplicate_tool = mock.Mock(spec=types.Tool)
|
||||
duplicate_tool.name = existing_tool_name
|
||||
mock_session_new.list_tools.return_value = mock.AsyncMock(
|
||||
tools=[duplicate_tool]
|
||||
)
|
||||
# Keep other lists empty for simplicity
|
||||
mock_session_new.list_resources.return_value = mock.AsyncMock(resources=[])
|
||||
mock_session_new.list_prompts.return_value = mock.AsyncMock(prompts=[])
|
||||
|
||||
# --- Test Execution and Assertion ---
|
||||
with pytest.raises(McpError) as excinfo:
|
||||
with mock.patch.object(
|
||||
group,
|
||||
"_establish_session",
|
||||
return_value=(mock_server_info_new, mock_session_new),
|
||||
):
|
||||
await group.connect_to_server(StdioServerParameters(command="test"))
|
||||
|
||||
# Assert details about the raised error
|
||||
assert excinfo.value.error.code == types.INVALID_PARAMS
|
||||
assert existing_tool_name in excinfo.value.error.message
|
||||
assert "already exist " in excinfo.value.error.message
|
||||
|
||||
# Verify the duplicate tool was *not* added again (state should be unchanged)
|
||||
assert len(group._tools) == 1 # Should still only have the original
|
||||
assert (
|
||||
group._tools[existing_tool_name] is not duplicate_tool
|
||||
) # Ensure it's the original mock
|
||||
|
||||
# No patching needed here
|
||||
async def test_disconnect_non_existent_server(self):
|
||||
"""Test disconnecting a server that isn't connected."""
|
||||
session = mock.Mock(spec=mcp.ClientSession)
|
||||
group = ClientSessionGroup()
|
||||
with pytest.raises(McpError):
|
||||
await group.disconnect_from_server(session)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"server_params_instance, client_type_name, patch_target_for_client_func",
|
||||
[
|
||||
(
|
||||
StdioServerParameters(command="test_stdio_cmd"),
|
||||
"stdio",
|
||||
"mcp.client.session_group.mcp.stdio_client",
|
||||
),
|
||||
(
|
||||
SseServerParameters(url="http://test.com/sse", timeout=10),
|
||||
"sse",
|
||||
"mcp.client.session_group.sse_client",
|
||||
), # url, headers, timeout, sse_read_timeout
|
||||
(
|
||||
StreamableHttpParameters(
|
||||
url="http://test.com/stream", terminate_on_close=False
|
||||
),
|
||||
"streamablehttp",
|
||||
"mcp.client.session_group.streamablehttp_client",
|
||||
), # url, headers, timeout, sse_read_timeout, terminate_on_close
|
||||
],
|
||||
)
|
||||
async def test_establish_session_parameterized(
|
||||
self,
|
||||
server_params_instance,
|
||||
client_type_name, # Just for clarity or conditional logic if needed
|
||||
patch_target_for_client_func,
|
||||
):
|
||||
with mock.patch(
|
||||
"mcp.client.session_group.mcp.ClientSession"
|
||||
) as mock_ClientSession_class:
|
||||
with mock.patch(patch_target_for_client_func) as mock_specific_client_func:
|
||||
mock_client_cm_instance = mock.AsyncMock(
|
||||
name=f"{client_type_name}ClientCM"
|
||||
)
|
||||
mock_read_stream = mock.AsyncMock(name=f"{client_type_name}Read")
|
||||
mock_write_stream = mock.AsyncMock(name=f"{client_type_name}Write")
|
||||
|
||||
# streamablehttp_client's __aenter__ returns three values
|
||||
if client_type_name == "streamablehttp":
|
||||
mock_extra_stream_val = mock.AsyncMock(name="StreamableExtra")
|
||||
mock_client_cm_instance.__aenter__.return_value = (
|
||||
mock_read_stream,
|
||||
mock_write_stream,
|
||||
mock_extra_stream_val,
|
||||
)
|
||||
else:
|
||||
mock_client_cm_instance.__aenter__.return_value = (
|
||||
mock_read_stream,
|
||||
mock_write_stream,
|
||||
)
|
||||
|
||||
mock_client_cm_instance.__aexit__ = mock.AsyncMock(return_value=None)
|
||||
mock_specific_client_func.return_value = mock_client_cm_instance
|
||||
|
||||
# --- Mock mcp.ClientSession (class) ---
|
||||
# mock_ClientSession_class is already provided by the outer patch
|
||||
mock_raw_session_cm = mock.AsyncMock(name="RawSessionCM")
|
||||
mock_ClientSession_class.return_value = mock_raw_session_cm
|
||||
|
||||
mock_entered_session = mock.AsyncMock(name="EnteredSessionInstance")
|
||||
mock_raw_session_cm.__aenter__.return_value = mock_entered_session
|
||||
mock_raw_session_cm.__aexit__ = mock.AsyncMock(return_value=None)
|
||||
|
||||
# Mock session.initialize()
|
||||
mock_initialize_result = mock.AsyncMock(name="InitializeResult")
|
||||
mock_initialize_result.serverInfo = types.Implementation(
|
||||
name="foo", version="1"
|
||||
)
|
||||
mock_entered_session.initialize.return_value = mock_initialize_result
|
||||
|
||||
# --- Test Execution ---
|
||||
group = ClientSessionGroup()
|
||||
returned_server_info = None
|
||||
returned_session = None
|
||||
|
||||
async with contextlib.AsyncExitStack() as stack:
|
||||
group._exit_stack = stack
|
||||
(
|
||||
returned_server_info,
|
||||
returned_session,
|
||||
) = await group._establish_session(server_params_instance)
|
||||
|
||||
# --- Assertions ---
|
||||
# 1. Assert the correct specific client function was called
|
||||
if client_type_name == "stdio":
|
||||
mock_specific_client_func.assert_called_once_with(
|
||||
server_params_instance
|
||||
)
|
||||
elif client_type_name == "sse":
|
||||
mock_specific_client_func.assert_called_once_with(
|
||||
url=server_params_instance.url,
|
||||
headers=server_params_instance.headers,
|
||||
timeout=server_params_instance.timeout,
|
||||
sse_read_timeout=server_params_instance.sse_read_timeout,
|
||||
)
|
||||
elif client_type_name == "streamablehttp":
|
||||
mock_specific_client_func.assert_called_once_with(
|
||||
url=server_params_instance.url,
|
||||
headers=server_params_instance.headers,
|
||||
timeout=server_params_instance.timeout,
|
||||
sse_read_timeout=server_params_instance.sse_read_timeout,
|
||||
terminate_on_close=server_params_instance.terminate_on_close,
|
||||
)
|
||||
|
||||
mock_client_cm_instance.__aenter__.assert_awaited_once()
|
||||
|
||||
# 2. Assert ClientSession was called correctly
|
||||
mock_ClientSession_class.assert_called_once_with(
|
||||
mock_read_stream, mock_write_stream
|
||||
)
|
||||
mock_raw_session_cm.__aenter__.assert_awaited_once()
|
||||
mock_entered_session.initialize.assert_awaited_once()
|
||||
|
||||
# 3. Assert returned values
|
||||
assert returned_server_info is mock_initialize_result.serverInfo
|
||||
assert returned_session is mock_entered_session
|
||||
Reference in New Issue
Block a user