mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 23:04:25 +01:00
Use 120 characters instead of 88 (#856)
This commit is contained in:
committed by
GitHub
parent
f7265f7b91
commit
543961968c
@@ -58,14 +58,10 @@ class TestClientSessionGroup:
|
||||
return f"{(server_info.name)}-{name}"
|
||||
|
||||
mcp_session_group = ClientSessionGroup(component_name_hook=hook)
|
||||
mcp_session_group._tools = {
|
||||
"server1-my_tool": types.Tool(name="my_tool", inputSchema={})
|
||||
}
|
||||
mcp_session_group._tools = {"server1-my_tool": types.Tool(name="my_tool", inputSchema={})}
|
||||
mcp_session_group._tool_to_session = {"server1-my_tool": mock_session}
|
||||
text_content = types.TextContent(type="text", text="OK")
|
||||
mock_session.call_tool.return_value = types.CallToolResult(
|
||||
content=[text_content]
|
||||
)
|
||||
mock_session.call_tool.return_value = types.CallToolResult(content=[text_content])
|
||||
|
||||
# --- Test Execution ---
|
||||
result = await mcp_session_group.call_tool(
|
||||
@@ -96,16 +92,12 @@ class TestClientSessionGroup:
|
||||
mock_prompt1 = mock.Mock(spec=types.Prompt)
|
||||
mock_prompt1.name = "prompt_c"
|
||||
mock_session.list_tools.return_value = mock.AsyncMock(tools=[mock_tool1])
|
||||
mock_session.list_resources.return_value = mock.AsyncMock(
|
||||
resources=[mock_resource1]
|
||||
)
|
||||
mock_session.list_resources.return_value = mock.AsyncMock(resources=[mock_resource1])
|
||||
mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[mock_prompt1])
|
||||
|
||||
# --- Test Execution ---
|
||||
group = ClientSessionGroup(exit_stack=mock_exit_stack)
|
||||
with mock.patch.object(
|
||||
group, "_establish_session", return_value=(mock_server_info, mock_session)
|
||||
):
|
||||
with mock.patch.object(group, "_establish_session", return_value=(mock_server_info, mock_session)):
|
||||
await group.connect_to_server(StdioServerParameters(command="test"))
|
||||
|
||||
# --- Assertions ---
|
||||
@@ -141,12 +133,8 @@ class TestClientSessionGroup:
|
||||
return f"{server_info.name}.{name}"
|
||||
|
||||
# --- Test Execution ---
|
||||
group = ClientSessionGroup(
|
||||
exit_stack=mock_exit_stack, component_name_hook=name_hook
|
||||
)
|
||||
with mock.patch.object(
|
||||
group, "_establish_session", return_value=(mock_server_info, mock_session)
|
||||
):
|
||||
group = ClientSessionGroup(exit_stack=mock_exit_stack, component_name_hook=name_hook)
|
||||
with mock.patch.object(group, "_establish_session", return_value=(mock_server_info, mock_session)):
|
||||
await group.connect_to_server(StdioServerParameters(command="test"))
|
||||
|
||||
# --- Assertions ---
|
||||
@@ -231,9 +219,7 @@ class TestClientSessionGroup:
|
||||
# Need a dummy session associated with the existing tool
|
||||
mock_session = mock.MagicMock(spec=mcp.ClientSession)
|
||||
group._tool_to_session[existing_tool_name] = mock_session
|
||||
group._session_exit_stacks[mock_session] = mock.Mock(
|
||||
spec=contextlib.AsyncExitStack
|
||||
)
|
||||
group._session_exit_stacks[mock_session] = mock.Mock(spec=contextlib.AsyncExitStack)
|
||||
|
||||
# --- Mock New Connection Attempt ---
|
||||
mock_server_info_new = mock.Mock(spec=types.Implementation)
|
||||
@@ -243,9 +229,7 @@ class TestClientSessionGroup:
|
||||
# Configure the new session to return a tool with the *same name*
|
||||
duplicate_tool = mock.Mock(spec=types.Tool)
|
||||
duplicate_tool.name = existing_tool_name
|
||||
mock_session_new.list_tools.return_value = mock.AsyncMock(
|
||||
tools=[duplicate_tool]
|
||||
)
|
||||
mock_session_new.list_tools.return_value = mock.AsyncMock(tools=[duplicate_tool])
|
||||
# Keep other lists empty for simplicity
|
||||
mock_session_new.list_resources.return_value = mock.AsyncMock(resources=[])
|
||||
mock_session_new.list_prompts.return_value = mock.AsyncMock(prompts=[])
|
||||
@@ -266,9 +250,7 @@ class TestClientSessionGroup:
|
||||
|
||||
# Verify the duplicate tool was *not* added again (state should be unchanged)
|
||||
assert len(group._tools) == 1 # Should still only have the original
|
||||
assert (
|
||||
group._tools[existing_tool_name] is not duplicate_tool
|
||||
) # Ensure it's the original mock
|
||||
assert group._tools[existing_tool_name] is not duplicate_tool # Ensure it's the original mock
|
||||
|
||||
# No patching needed here
|
||||
async def test_disconnect_non_existent_server(self):
|
||||
@@ -292,9 +274,7 @@ class TestClientSessionGroup:
|
||||
"mcp.client.session_group.sse_client",
|
||||
), # url, headers, timeout, sse_read_timeout
|
||||
(
|
||||
StreamableHttpParameters(
|
||||
url="http://test.com/stream", terminate_on_close=False
|
||||
),
|
||||
StreamableHttpParameters(url="http://test.com/stream", terminate_on_close=False),
|
||||
"streamablehttp",
|
||||
"mcp.client.session_group.streamablehttp_client",
|
||||
), # url, headers, timeout, sse_read_timeout, terminate_on_close
|
||||
@@ -306,13 +286,9 @@ class TestClientSessionGroup:
|
||||
client_type_name, # Just for clarity or conditional logic if needed
|
||||
patch_target_for_client_func,
|
||||
):
|
||||
with mock.patch(
|
||||
"mcp.client.session_group.mcp.ClientSession"
|
||||
) as mock_ClientSession_class:
|
||||
with mock.patch("mcp.client.session_group.mcp.ClientSession") as mock_ClientSession_class:
|
||||
with mock.patch(patch_target_for_client_func) as mock_specific_client_func:
|
||||
mock_client_cm_instance = mock.AsyncMock(
|
||||
name=f"{client_type_name}ClientCM"
|
||||
)
|
||||
mock_client_cm_instance = mock.AsyncMock(name=f"{client_type_name}ClientCM")
|
||||
mock_read_stream = mock.AsyncMock(name=f"{client_type_name}Read")
|
||||
mock_write_stream = mock.AsyncMock(name=f"{client_type_name}Write")
|
||||
|
||||
@@ -344,9 +320,7 @@ class TestClientSessionGroup:
|
||||
|
||||
# Mock session.initialize()
|
||||
mock_initialize_result = mock.AsyncMock(name="InitializeResult")
|
||||
mock_initialize_result.serverInfo = types.Implementation(
|
||||
name="foo", version="1"
|
||||
)
|
||||
mock_initialize_result.serverInfo = types.Implementation(name="foo", version="1")
|
||||
mock_entered_session.initialize.return_value = mock_initialize_result
|
||||
|
||||
# --- Test Execution ---
|
||||
@@ -364,9 +338,7 @@ class TestClientSessionGroup:
|
||||
# --- Assertions ---
|
||||
# 1. Assert the correct specific client function was called
|
||||
if client_type_name == "stdio":
|
||||
mock_specific_client_func.assert_called_once_with(
|
||||
server_params_instance
|
||||
)
|
||||
mock_specific_client_func.assert_called_once_with(server_params_instance)
|
||||
elif client_type_name == "sse":
|
||||
mock_specific_client_func.assert_called_once_with(
|
||||
url=server_params_instance.url,
|
||||
@@ -386,9 +358,7 @@ class TestClientSessionGroup:
|
||||
mock_client_cm_instance.__aenter__.assert_awaited_once()
|
||||
|
||||
# 2. Assert ClientSession was called correctly
|
||||
mock_ClientSession_class.assert_called_once_with(
|
||||
mock_read_stream, mock_write_stream
|
||||
)
|
||||
mock_ClientSession_class.assert_called_once_with(mock_read_stream, mock_write_stream)
|
||||
mock_raw_session_cm.__aenter__.assert_awaited_once()
|
||||
mock_entered_session.initialize.assert_awaited_once()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user