Mirror of https://github.com/aljazceru/mcp-python-sdk.git (synced 2025-12-18 14:34:27 +01:00) — Python, 368 lines, 16 KiB.
import contextlib
|
|
from unittest import mock
|
|
|
|
import pytest
|
|
|
|
import mcp
|
|
from mcp import types
|
|
from mcp.client.session_group import (
|
|
ClientSessionGroup,
|
|
SseServerParameters,
|
|
StreamableHttpParameters,
|
|
)
|
|
from mcp.client.stdio import StdioServerParameters
|
|
from mcp.shared.exceptions import McpError
|
|
|
|
|
|
@pytest.fixture
def mock_exit_stack():
    """Supply a ``MagicMock`` restricted to the ``contextlib.AsyncExitStack`` API.

    Passing ``spec`` means accessing an attribute that ``AsyncExitStack`` does
    not define raises ``AttributeError``, so tests cannot silently depend on
    methods the real exit stack lacks.
    """
    stack_double = mock.MagicMock(spec=contextlib.AsyncExitStack)
    return stack_double
|
|
|
|
|
|
@pytest.mark.anyio
class TestClientSessionGroup:
    """Unit tests for ``ClientSessionGroup``.

    No real transport is exercised: tests either seed the group's private
    state directly or patch ``_establish_session`` / the transport client
    functions with mocks.
    """

    def test_init(self):
        """A freshly constructed group holds no aggregated components."""
        mcp_session_group = ClientSessionGroup()
        assert not mcp_session_group._tools
        assert not mcp_session_group._resources
        assert not mcp_session_group._prompts
        assert not mcp_session_group._tool_to_session

    def test_component_properties(self):
        """The public prompts/resources/tools properties expose the private dicts."""
        # --- Mock Dependencies ---
        mock_prompt = mock.Mock()
        mock_resource = mock.Mock()
        mock_tool = mock.Mock()

        # --- Prepare Session Group ---
        mcp_session_group = ClientSessionGroup()
        mcp_session_group._prompts = {"my_prompt": mock_prompt}
        mcp_session_group._resources = {"my_resource": mock_resource}
        mcp_session_group._tools = {"my_tool": mock_tool}

        # --- Assertions ---
        assert mcp_session_group.prompts == {"my_prompt": mock_prompt}
        assert mcp_session_group.resources == {"my_resource": mock_resource}
        assert mcp_session_group.tools == {"my_tool": mock_tool}

    async def test_call_tool(self):
        """call_tool routes an aggregated tool name to its owning session."""
        # --- Mock Dependencies ---
        mock_session = mock.AsyncMock()

        # --- Prepare Session Group ---
        def hook(name, server_info):
            return f"{server_info.name}-{name}"

        mcp_session_group = ClientSessionGroup(component_name_hook=hook)
        mcp_session_group._tools = {"server1-my_tool": types.Tool(name="my_tool", inputSchema={})}
        mcp_session_group._tool_to_session = {"server1-my_tool": mock_session}
        text_content = types.TextContent(type="text", text="OK")
        mock_session.call_tool.return_value = types.CallToolResult(content=[text_content])

        # --- Test Execution ---
        result = await mcp_session_group.call_tool(
            name="server1-my_tool",
            args={
                "name": "value1",
                "args": {},
            },
        )

        # --- Assertions ---
        assert result.content == [text_content]
        # The owning session is called with the tool's own name ("my_tool"),
        # not the aggregated key ("server1-my_tool").
        mock_session.call_tool.assert_called_once_with(
            "my_tool",
            {"name": "value1", "args": {}},
        )

    async def test_connect_to_server(self, mock_exit_stack):
        """Test connecting to a server and aggregating components."""
        # --- Mock Dependencies ---
        mock_server_info = mock.Mock(spec=types.Implementation)
        mock_server_info.name = "TestServer1"
        mock_session = mock.AsyncMock(spec=mcp.ClientSession)
        mock_tool1 = mock.Mock(spec=types.Tool)
        mock_tool1.name = "tool_a"
        mock_resource1 = mock.Mock(spec=types.Resource)
        mock_resource1.name = "resource_b"
        mock_prompt1 = mock.Mock(spec=types.Prompt)
        mock_prompt1.name = "prompt_c"
        # The awaited list_* results are plain data holders, so a regular Mock
        # suffices (an AsyncMock would make every child attribute awaitable).
        mock_session.list_tools.return_value = mock.Mock(tools=[mock_tool1])
        mock_session.list_resources.return_value = mock.Mock(resources=[mock_resource1])
        mock_session.list_prompts.return_value = mock.Mock(prompts=[mock_prompt1])

        # --- Test Execution ---
        group = ClientSessionGroup(exit_stack=mock_exit_stack)
        with mock.patch.object(group, "_establish_session", return_value=(mock_server_info, mock_session)):
            await group.connect_to_server(StdioServerParameters(command="test"))

        # --- Assertions ---
        assert mock_session in group._sessions
        assert len(group.tools) == 1
        assert "tool_a" in group.tools
        assert group.tools["tool_a"] == mock_tool1
        assert group._tool_to_session["tool_a"] == mock_session
        assert len(group.resources) == 1
        assert "resource_b" in group.resources
        assert group.resources["resource_b"] == mock_resource1
        assert len(group.prompts) == 1
        assert "prompt_c" in group.prompts
        assert group.prompts["prompt_c"] == mock_prompt1
        mock_session.list_tools.assert_awaited_once()
        mock_session.list_resources.assert_awaited_once()
        mock_session.list_prompts.assert_awaited_once()

    async def test_connect_to_server_with_name_hook(self, mock_exit_stack):
        """Test connecting with a component name hook."""
        # --- Mock Dependencies ---
        mock_server_info = mock.Mock(spec=types.Implementation)
        mock_server_info.name = "HookServer"
        mock_session = mock.AsyncMock(spec=mcp.ClientSession)
        mock_tool = mock.Mock(spec=types.Tool)
        mock_tool.name = "base_tool"
        mock_session.list_tools.return_value = mock.Mock(tools=[mock_tool])
        mock_session.list_resources.return_value = mock.Mock(resources=[])
        mock_session.list_prompts.return_value = mock.Mock(prompts=[])

        # --- Test Setup ---
        def name_hook(name: str, server_info: types.Implementation) -> str:
            return f"{server_info.name}.{name}"

        # --- Test Execution ---
        group = ClientSessionGroup(exit_stack=mock_exit_stack, component_name_hook=name_hook)
        with mock.patch.object(group, "_establish_session", return_value=(mock_server_info, mock_session)):
            await group.connect_to_server(StdioServerParameters(command="test"))

        # --- Assertions ---
        assert mock_session in group._sessions
        assert len(group.tools) == 1
        expected_tool_name = "HookServer.base_tool"
        assert expected_tool_name in group.tools
        assert group.tools[expected_tool_name] == mock_tool
        assert group._tool_to_session[expected_tool_name] == mock_session

    async def test_disconnect_from_server(self):
        """Disconnecting removes exactly the components recorded for that session."""
        # --- Test Setup ---
        group = ClientSessionGroup()
        server_name = "ServerToDisconnect"

        # Manually populate state using standard mocks
        mock_session1 = mock.MagicMock(spec=mcp.ClientSession)
        mock_session2 = mock.MagicMock(spec=mcp.ClientSession)
        mock_tool1 = mock.Mock(spec=types.Tool)
        mock_tool1.name = "tool1"
        mock_resource1 = mock.Mock(spec=types.Resource)
        mock_resource1.name = "res1"
        mock_prompt1 = mock.Mock(spec=types.Prompt)
        mock_prompt1.name = "prm1"
        mock_tool2 = mock.Mock(spec=types.Tool)
        mock_tool2.name = "tool2"
        # A component keyed by the server's name but NOT owned by the session
        # being disconnected; it must survive the disconnect.
        mock_component_named_like_server = mock.Mock()
        mock_session = mock.Mock(spec=mcp.ClientSession)

        group._tools = {
            "tool1": mock_tool1,
            "tool2": mock_tool2,
            server_name: mock_component_named_like_server,
        }
        group._tool_to_session = {
            "tool1": mock_session1,
            "tool2": mock_session2,
            server_name: mock_session1,
        }
        group._resources = {
            "res1": mock_resource1,
            server_name: mock_component_named_like_server,
        }
        group._prompts = {
            "prm1": mock_prompt1,
            server_name: mock_component_named_like_server,
        }
        group._sessions = {
            mock_session: ClientSessionGroup._ComponentNames(
                prompts={"prm1"},
                resources={"res1"},
                tools={"tool1", "tool2"},
            )
        }

        # --- Assertions ---
        assert mock_session in group._sessions
        assert "tool1" in group._tools
        assert "tool2" in group._tools
        assert "res1" in group._resources
        assert "prm1" in group._prompts

        # --- Test Execution ---
        await group.disconnect_from_server(mock_session)

        # --- Assertions ---
        assert mock_session not in group._sessions
        assert "tool1" not in group._tools
        assert "tool2" not in group._tools
        assert "res1" not in group._resources
        assert "prm1" not in group._prompts

        # Components not recorded for the disconnected session are untouched.
        assert group._tools[server_name] is mock_component_named_like_server
        assert group._resources[server_name] is mock_component_named_like_server
        assert group._prompts[server_name] is mock_component_named_like_server
        assert group._tool_to_session[server_name] is mock_session1

    async def test_connect_to_server_duplicate_tool_raises_error(self, mock_exit_stack):
        """Test McpError raised when connecting a server with a dup name."""
        # --- Setup Pre-existing State ---
        group = ClientSessionGroup(exit_stack=mock_exit_stack)
        existing_tool_name = "shared_tool"
        # Manually add a tool to simulate a previous connection; keep a
        # reference so we can later prove it was not replaced.
        original_tool = mock.Mock(spec=types.Tool)
        original_tool.name = existing_tool_name
        group._tools[existing_tool_name] = original_tool
        # Need a dummy session associated with the existing tool
        mock_session = mock.MagicMock(spec=mcp.ClientSession)
        group._tool_to_session[existing_tool_name] = mock_session
        group._session_exit_stacks[mock_session] = mock.Mock(spec=contextlib.AsyncExitStack)

        # --- Mock New Connection Attempt ---
        mock_server_info_new = mock.Mock(spec=types.Implementation)
        mock_server_info_new.name = "ServerWithDuplicate"
        mock_session_new = mock.AsyncMock(spec=mcp.ClientSession)

        # Configure the new session to return a tool with the *same name*
        duplicate_tool = mock.Mock(spec=types.Tool)
        duplicate_tool.name = existing_tool_name
        mock_session_new.list_tools.return_value = mock.Mock(tools=[duplicate_tool])
        # Keep other lists empty for simplicity
        mock_session_new.list_resources.return_value = mock.Mock(resources=[])
        mock_session_new.list_prompts.return_value = mock.Mock(prompts=[])

        # --- Test Execution and Assertion ---
        with pytest.raises(McpError) as excinfo:
            with mock.patch.object(
                group,
                "_establish_session",
                return_value=(mock_server_info_new, mock_session_new),
            ):
                await group.connect_to_server(StdioServerParameters(command="test"))

        # Assert details about the raised error
        assert excinfo.value.error.code == types.INVALID_PARAMS
        assert existing_tool_name in excinfo.value.error.message
        assert "already exist " in excinfo.value.error.message

        # The failed connection must leave the group's state untouched.
        assert len(group._tools) == 1
        assert group._tools[existing_tool_name] is original_tool
        assert group._tool_to_session[existing_tool_name] is mock_session
        assert mock_session_new not in group._sessions

    async def test_disconnect_non_existent_server(self):
        """Test disconnecting a server that isn't connected."""
        session = mock.Mock(spec=mcp.ClientSession)
        group = ClientSessionGroup()
        with pytest.raises(McpError):
            await group.disconnect_from_server(session)

    @pytest.mark.parametrize(
        "server_params_instance, client_type_name, patch_target_for_client_func",
        [
            (
                StdioServerParameters(command="test_stdio_cmd"),
                "stdio",
                "mcp.client.session_group.mcp.stdio_client",
            ),
            (
                SseServerParameters(url="http://test.com/sse", timeout=10),
                "sse",
                "mcp.client.session_group.sse_client",
            ),  # url, headers, timeout, sse_read_timeout
            (
                StreamableHttpParameters(url="http://test.com/stream", terminate_on_close=False),
                "streamablehttp",
                "mcp.client.session_group.streamablehttp_client",
            ),  # url, headers, timeout, sse_read_timeout, terminate_on_close
        ],
    )
    async def test_establish_session_parameterized(
        self,
        server_params_instance,
        client_type_name,  # selects the per-transport setup and assertions below
        patch_target_for_client_func,
    ):
        """_establish_session wires the transport client into an initialized ClientSession."""
        with mock.patch("mcp.client.session_group.mcp.ClientSession") as mock_ClientSession_class:
            with mock.patch(patch_target_for_client_func) as mock_specific_client_func:
                mock_client_cm_instance = mock.AsyncMock(name=f"{client_type_name}ClientCM")
                mock_read_stream = mock.AsyncMock(name=f"{client_type_name}Read")
                mock_write_stream = mock.AsyncMock(name=f"{client_type_name}Write")

                # streamablehttp_client's __aenter__ returns three values
                if client_type_name == "streamablehttp":
                    mock_extra_stream_val = mock.AsyncMock(name="StreamableExtra")
                    mock_client_cm_instance.__aenter__.return_value = (
                        mock_read_stream,
                        mock_write_stream,
                        mock_extra_stream_val,
                    )
                else:
                    mock_client_cm_instance.__aenter__.return_value = (
                        mock_read_stream,
                        mock_write_stream,
                    )

                mock_client_cm_instance.__aexit__ = mock.AsyncMock(return_value=None)
                mock_specific_client_func.return_value = mock_client_cm_instance

                # --- Mock mcp.ClientSession (class) ---
                # mock_ClientSession_class is already provided by the outer patch
                mock_raw_session_cm = mock.AsyncMock(name="RawSessionCM")
                mock_ClientSession_class.return_value = mock_raw_session_cm

                mock_entered_session = mock.AsyncMock(name="EnteredSessionInstance")
                mock_raw_session_cm.__aenter__.return_value = mock_entered_session
                mock_raw_session_cm.__aexit__ = mock.AsyncMock(return_value=None)

                # Mock session.initialize(); its result is a plain data holder.
                mock_initialize_result = mock.Mock(name="InitializeResult")
                mock_initialize_result.serverInfo = types.Implementation(name="foo", version="1")
                mock_entered_session.initialize.return_value = mock_initialize_result

                # --- Test Execution ---
                group = ClientSessionGroup()

                async with contextlib.AsyncExitStack() as stack:
                    group._exit_stack = stack
                    (
                        returned_server_info,
                        returned_session,
                    ) = await group._establish_session(server_params_instance)

                # --- Assertions ---
                # 1. Assert the correct specific client function was called
                if client_type_name == "stdio":
                    mock_specific_client_func.assert_called_once_with(server_params_instance)
                elif client_type_name == "sse":
                    mock_specific_client_func.assert_called_once_with(
                        url=server_params_instance.url,
                        headers=server_params_instance.headers,
                        timeout=server_params_instance.timeout,
                        sse_read_timeout=server_params_instance.sse_read_timeout,
                    )
                elif client_type_name == "streamablehttp":
                    mock_specific_client_func.assert_called_once_with(
                        url=server_params_instance.url,
                        headers=server_params_instance.headers,
                        timeout=server_params_instance.timeout,
                        sse_read_timeout=server_params_instance.sse_read_timeout,
                        terminate_on_close=server_params_instance.terminate_on_close,
                    )

                mock_client_cm_instance.__aenter__.assert_awaited_once()

                # 2. Assert ClientSession was called correctly
                mock_ClientSession_class.assert_called_once_with(mock_read_stream, mock_write_stream)
                mock_raw_session_cm.__aenter__.assert_awaited_once()
                mock_entered_session.initialize.assert_awaited_once()

                # 3. Assert returned values
                assert returned_server_info is mock_initialize_result.serverInfo
                assert returned_session is mock_entered_session
|