Use 120 characters instead of 88 (#856)

This commit is contained in:
Marcelo Trylesinski
2025-06-11 02:45:50 -07:00
committed by GitHub
parent f7265f7b91
commit 543961968c
90 changed files with 687 additions and 2142 deletions

View File

@@ -58,14 +58,10 @@ class TestClientSessionGroup:
return f"{(server_info.name)}-{name}"
mcp_session_group = ClientSessionGroup(component_name_hook=hook)
mcp_session_group._tools = {
"server1-my_tool": types.Tool(name="my_tool", inputSchema={})
}
mcp_session_group._tools = {"server1-my_tool": types.Tool(name="my_tool", inputSchema={})}
mcp_session_group._tool_to_session = {"server1-my_tool": mock_session}
text_content = types.TextContent(type="text", text="OK")
mock_session.call_tool.return_value = types.CallToolResult(
content=[text_content]
)
mock_session.call_tool.return_value = types.CallToolResult(content=[text_content])
# --- Test Execution ---
result = await mcp_session_group.call_tool(
@@ -96,16 +92,12 @@ class TestClientSessionGroup:
mock_prompt1 = mock.Mock(spec=types.Prompt)
mock_prompt1.name = "prompt_c"
mock_session.list_tools.return_value = mock.AsyncMock(tools=[mock_tool1])
mock_session.list_resources.return_value = mock.AsyncMock(
resources=[mock_resource1]
)
mock_session.list_resources.return_value = mock.AsyncMock(resources=[mock_resource1])
mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[mock_prompt1])
# --- Test Execution ---
group = ClientSessionGroup(exit_stack=mock_exit_stack)
with mock.patch.object(
group, "_establish_session", return_value=(mock_server_info, mock_session)
):
with mock.patch.object(group, "_establish_session", return_value=(mock_server_info, mock_session)):
await group.connect_to_server(StdioServerParameters(command="test"))
# --- Assertions ---
@@ -141,12 +133,8 @@ class TestClientSessionGroup:
return f"{server_info.name}.{name}"
# --- Test Execution ---
group = ClientSessionGroup(
exit_stack=mock_exit_stack, component_name_hook=name_hook
)
with mock.patch.object(
group, "_establish_session", return_value=(mock_server_info, mock_session)
):
group = ClientSessionGroup(exit_stack=mock_exit_stack, component_name_hook=name_hook)
with mock.patch.object(group, "_establish_session", return_value=(mock_server_info, mock_session)):
await group.connect_to_server(StdioServerParameters(command="test"))
# --- Assertions ---
@@ -231,9 +219,7 @@ class TestClientSessionGroup:
# Need a dummy session associated with the existing tool
mock_session = mock.MagicMock(spec=mcp.ClientSession)
group._tool_to_session[existing_tool_name] = mock_session
group._session_exit_stacks[mock_session] = mock.Mock(
spec=contextlib.AsyncExitStack
)
group._session_exit_stacks[mock_session] = mock.Mock(spec=contextlib.AsyncExitStack)
# --- Mock New Connection Attempt ---
mock_server_info_new = mock.Mock(spec=types.Implementation)
@@ -243,9 +229,7 @@ class TestClientSessionGroup:
# Configure the new session to return a tool with the *same name*
duplicate_tool = mock.Mock(spec=types.Tool)
duplicate_tool.name = existing_tool_name
mock_session_new.list_tools.return_value = mock.AsyncMock(
tools=[duplicate_tool]
)
mock_session_new.list_tools.return_value = mock.AsyncMock(tools=[duplicate_tool])
# Keep other lists empty for simplicity
mock_session_new.list_resources.return_value = mock.AsyncMock(resources=[])
mock_session_new.list_prompts.return_value = mock.AsyncMock(prompts=[])
@@ -266,9 +250,7 @@ class TestClientSessionGroup:
# Verify the duplicate tool was *not* added again (state should be unchanged)
assert len(group._tools) == 1 # Should still only have the original
assert (
group._tools[existing_tool_name] is not duplicate_tool
) # Ensure it's the original mock
assert group._tools[existing_tool_name] is not duplicate_tool # Ensure it's the original mock
# No patching needed here
async def test_disconnect_non_existent_server(self):
@@ -292,9 +274,7 @@ class TestClientSessionGroup:
"mcp.client.session_group.sse_client",
), # url, headers, timeout, sse_read_timeout
(
StreamableHttpParameters(
url="http://test.com/stream", terminate_on_close=False
),
StreamableHttpParameters(url="http://test.com/stream", terminate_on_close=False),
"streamablehttp",
"mcp.client.session_group.streamablehttp_client",
), # url, headers, timeout, sse_read_timeout, terminate_on_close
@@ -306,13 +286,9 @@ class TestClientSessionGroup:
client_type_name, # Just for clarity or conditional logic if needed
patch_target_for_client_func,
):
with mock.patch(
"mcp.client.session_group.mcp.ClientSession"
) as mock_ClientSession_class:
with mock.patch("mcp.client.session_group.mcp.ClientSession") as mock_ClientSession_class:
with mock.patch(patch_target_for_client_func) as mock_specific_client_func:
mock_client_cm_instance = mock.AsyncMock(
name=f"{client_type_name}ClientCM"
)
mock_client_cm_instance = mock.AsyncMock(name=f"{client_type_name}ClientCM")
mock_read_stream = mock.AsyncMock(name=f"{client_type_name}Read")
mock_write_stream = mock.AsyncMock(name=f"{client_type_name}Write")
@@ -344,9 +320,7 @@ class TestClientSessionGroup:
# Mock session.initialize()
mock_initialize_result = mock.AsyncMock(name="InitializeResult")
mock_initialize_result.serverInfo = types.Implementation(
name="foo", version="1"
)
mock_initialize_result.serverInfo = types.Implementation(name="foo", version="1")
mock_entered_session.initialize.return_value = mock_initialize_result
# --- Test Execution ---
@@ -364,9 +338,7 @@ class TestClientSessionGroup:
# --- Assertions ---
# 1. Assert the correct specific client function was called
if client_type_name == "stdio":
mock_specific_client_func.assert_called_once_with(
server_params_instance
)
mock_specific_client_func.assert_called_once_with(server_params_instance)
elif client_type_name == "sse":
mock_specific_client_func.assert_called_once_with(
url=server_params_instance.url,
@@ -386,9 +358,7 @@ class TestClientSessionGroup:
mock_client_cm_instance.__aenter__.assert_awaited_once()
# 2. Assert ClientSession was called correctly
mock_ClientSession_class.assert_called_once_with(
mock_read_stream, mock_write_stream
)
mock_ClientSession_class.assert_called_once_with(mock_read_stream, mock_write_stream)
mock_raw_session_cm.__aenter__.assert_awaited_once()
mock_entered_session.initialize.assert_awaited_once()