Use 120 characters instead of 88 (#856)

This commit is contained in:
Marcelo Trylesinski
2025-06-11 02:45:50 -07:00
committed by GitHub
parent f7265f7b91
commit 543961968c
90 changed files with 687 additions and 2142 deletions

View File

@@ -49,8 +49,7 @@ class StreamSpyCollection:
return [
req.message.root
for req in self.client.sent_messages
if isinstance(req.message.root, JSONRPCRequest)
and (method is None or req.message.root.method == method)
if isinstance(req.message.root, JSONRPCRequest) and (method is None or req.message.root.method == method)
]
def get_server_requests(self, method: str | None = None) -> list[JSONRPCRequest]:
@@ -58,13 +57,10 @@ class StreamSpyCollection:
return [
req.message.root
for req in self.server.sent_messages
if isinstance(req.message.root, JSONRPCRequest)
and (method is None or req.message.root.method == method)
if isinstance(req.message.root, JSONRPCRequest) and (method is None or req.message.root.method == method)
]
def get_client_notifications(
self, method: str | None = None
) -> list[JSONRPCNotification]:
def get_client_notifications(self, method: str | None = None) -> list[JSONRPCNotification]:
"""Get client-sent notifications, optionally filtered by method."""
return [
notif.message.root
@@ -73,9 +69,7 @@ class StreamSpyCollection:
and (method is None or notif.message.root.method == method)
]
def get_server_notifications(
self, method: str | None = None
) -> list[JSONRPCNotification]:
def get_server_notifications(self, method: str | None = None) -> list[JSONRPCNotification]:
"""Get server-sent notifications, optionally filtered by method."""
return [
notif.message.root
@@ -133,9 +127,7 @@ def stream_spy():
yield (client_read, spy_client_write), (server_read, spy_server_write)
# Apply the patch for the duration of the test
with patch(
"mcp.shared.memory.create_client_server_memory_streams", patched_create_streams
):
with patch("mcp.shared.memory.create_client_server_memory_streams", patched_create_streams):
# Return a collection with helper methods
def get_spy_collection() -> StreamSpyCollection:
assert client_spy is not None, "client_spy was not initialized"

View File

@@ -134,9 +134,7 @@ class TestOAuthClientProvider:
assert len(verifier) == 128
# Check charset (RFC 7636: A-Z, a-z, 0-9, "-", ".", "_", "~")
allowed_chars = set(
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
)
allowed_chars = set("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~")
assert set(verifier) <= allowed_chars
# Check uniqueness (generate multiple and ensure they're different)
@@ -151,9 +149,7 @@ class TestOAuthClientProvider:
# Manually calculate expected challenge
expected_digest = hashlib.sha256(verifier.encode()).digest()
expected_challenge = (
base64.urlsafe_b64encode(expected_digest).decode().rstrip("=")
)
expected_challenge = base64.urlsafe_b64encode(expected_digest).decode().rstrip("=")
assert challenge == expected_challenge
@@ -166,29 +162,19 @@ class TestOAuthClientProvider:
async def test_get_authorization_base_url(self, oauth_provider):
"""Test authorization base URL extraction."""
# Test with path
assert (
oauth_provider._get_authorization_base_url("https://api.example.com/v1/mcp")
== "https://api.example.com"
)
assert oauth_provider._get_authorization_base_url("https://api.example.com/v1/mcp") == "https://api.example.com"
# Test with no path
assert (
oauth_provider._get_authorization_base_url("https://api.example.com")
== "https://api.example.com"
)
assert oauth_provider._get_authorization_base_url("https://api.example.com") == "https://api.example.com"
# Test with port
assert (
oauth_provider._get_authorization_base_url(
"https://api.example.com:8080/path/to/mcp"
)
oauth_provider._get_authorization_base_url("https://api.example.com:8080/path/to/mcp")
== "https://api.example.com:8080"
)
@pytest.mark.anyio
async def test_discover_oauth_metadata_success(
self, oauth_provider, oauth_metadata
):
async def test_discover_oauth_metadata_success(self, oauth_provider, oauth_metadata):
"""Test successful OAuth metadata discovery."""
metadata_response = oauth_metadata.model_dump(by_alias=True, mode="json")
@@ -201,23 +187,16 @@ class TestOAuthClientProvider:
mock_response.json.return_value = metadata_response
mock_client.get.return_value = mock_response
result = await oauth_provider._discover_oauth_metadata(
"https://api.example.com/v1/mcp"
)
result = await oauth_provider._discover_oauth_metadata("https://api.example.com/v1/mcp")
assert result is not None
assert (
result.authorization_endpoint == oauth_metadata.authorization_endpoint
)
assert result.authorization_endpoint == oauth_metadata.authorization_endpoint
assert result.token_endpoint == oauth_metadata.token_endpoint
# Verify correct URL was called
mock_client.get.assert_called_once()
call_args = mock_client.get.call_args[0]
assert (
call_args[0]
== "https://api.example.com/.well-known/oauth-authorization-server"
)
assert call_args[0] == "https://api.example.com/.well-known/oauth-authorization-server"
@pytest.mark.anyio
async def test_discover_oauth_metadata_not_found(self, oauth_provider):
@@ -230,16 +209,12 @@ class TestOAuthClientProvider:
mock_response.status_code = 404
mock_client.get.return_value = mock_response
result = await oauth_provider._discover_oauth_metadata(
"https://api.example.com/v1/mcp"
)
result = await oauth_provider._discover_oauth_metadata("https://api.example.com/v1/mcp")
assert result is None
@pytest.mark.anyio
async def test_discover_oauth_metadata_cors_fallback(
self, oauth_provider, oauth_metadata
):
async def test_discover_oauth_metadata_cors_fallback(self, oauth_provider, oauth_metadata):
"""Test OAuth metadata discovery with CORS fallback."""
metadata_response = oauth_metadata.model_dump(by_alias=True, mode="json")
@@ -257,17 +232,13 @@ class TestOAuthClientProvider:
mock_response_success, # Second call succeeds
]
result = await oauth_provider._discover_oauth_metadata(
"https://api.example.com/v1/mcp"
)
result = await oauth_provider._discover_oauth_metadata("https://api.example.com/v1/mcp")
assert result is not None
assert mock_client.get.call_count == 2
@pytest.mark.anyio
async def test_register_oauth_client_success(
self, oauth_provider, oauth_metadata, oauth_client_info
):
async def test_register_oauth_client_success(self, oauth_provider, oauth_metadata, oauth_client_info):
"""Test successful OAuth client registration."""
registration_response = oauth_client_info.model_dump(by_alias=True, mode="json")
@@ -295,9 +266,7 @@ class TestOAuthClientProvider:
assert call_args[0][0] == str(oauth_metadata.registration_endpoint)
@pytest.mark.anyio
async def test_register_oauth_client_fallback_endpoint(
self, oauth_provider, oauth_client_info
):
async def test_register_oauth_client_fallback_endpoint(self, oauth_provider, oauth_client_info):
"""Test OAuth client registration with fallback endpoint."""
registration_response = oauth_client_info.model_dump(by_alias=True, mode="json")
@@ -311,9 +280,7 @@ class TestOAuthClientProvider:
mock_client.post.return_value = mock_response
# Mock metadata discovery to return None (fallback)
with patch.object(
oauth_provider, "_discover_oauth_metadata", return_value=None
):
with patch.object(oauth_provider, "_discover_oauth_metadata", return_value=None):
result = await oauth_provider._register_oauth_client(
"https://api.example.com/v1/mcp",
oauth_provider.client_metadata,
@@ -340,9 +307,7 @@ class TestOAuthClientProvider:
mock_client.post.return_value = mock_response
# Mock metadata discovery to return None (fallback)
with patch.object(
oauth_provider, "_discover_oauth_metadata", return_value=None
):
with patch.object(oauth_provider, "_discover_oauth_metadata", return_value=None):
with pytest.raises(httpx.HTTPStatusError):
await oauth_provider._register_oauth_client(
"https://api.example.com/v1/mcp",
@@ -406,9 +371,7 @@ class TestOAuthClientProvider:
await oauth_provider._validate_token_scopes(token)
@pytest.mark.anyio
async def test_validate_token_scopes_unauthorized(
self, oauth_provider, client_metadata
):
async def test_validate_token_scopes_unauthorized(self, oauth_provider, client_metadata):
"""Test scope validation with unauthorized scopes."""
oauth_provider.client_metadata = client_metadata
token = OAuthToken(
@@ -436,9 +399,7 @@ class TestOAuthClientProvider:
await oauth_provider._validate_token_scopes(token)
@pytest.mark.anyio
async def test_initialize(
self, oauth_provider, mock_storage, oauth_token, oauth_client_info
):
async def test_initialize(self, oauth_provider, mock_storage, oauth_token, oauth_client_info):
"""Test initialization loading from storage."""
mock_storage._tokens = oauth_token
mock_storage._client_info = oauth_client_info
@@ -449,9 +410,7 @@ class TestOAuthClientProvider:
assert oauth_provider._client_info == oauth_client_info
@pytest.mark.anyio
async def test_get_or_register_client_existing(
self, oauth_provider, oauth_client_info
):
async def test_get_or_register_client_existing(self, oauth_provider, oauth_client_info):
"""Test getting existing client info."""
oauth_provider._client_info = oauth_client_info
@@ -460,13 +419,9 @@ class TestOAuthClientProvider:
assert result == oauth_client_info
@pytest.mark.anyio
async def test_get_or_register_client_register_new(
self, oauth_provider, oauth_client_info
):
async def test_get_or_register_client_register_new(self, oauth_provider, oauth_client_info):
"""Test registering new client."""
with patch.object(
oauth_provider, "_register_oauth_client", return_value=oauth_client_info
) as mock_register:
with patch.object(oauth_provider, "_register_oauth_client", return_value=oauth_client_info) as mock_register:
result = await oauth_provider._get_or_register_client()
assert result == oauth_client_info
@@ -474,9 +429,7 @@ class TestOAuthClientProvider:
mock_register.assert_called_once()
@pytest.mark.anyio
async def test_exchange_code_for_token_success(
self, oauth_provider, oauth_client_info, oauth_token
):
async def test_exchange_code_for_token_success(self, oauth_provider, oauth_client_info, oauth_token):
"""Test successful code exchange for token."""
oauth_provider._code_verifier = "test_verifier"
token_response = oauth_token.model_dump(by_alias=True, mode="json")
@@ -490,23 +443,14 @@ class TestOAuthClientProvider:
mock_response.json.return_value = token_response
mock_client.post.return_value = mock_response
with patch.object(
oauth_provider, "_validate_token_scopes"
) as mock_validate:
await oauth_provider._exchange_code_for_token(
"test_auth_code", oauth_client_info
)
with patch.object(oauth_provider, "_validate_token_scopes") as mock_validate:
await oauth_provider._exchange_code_for_token("test_auth_code", oauth_client_info)
assert (
oauth_provider._current_tokens.access_token
== oauth_token.access_token
)
assert oauth_provider._current_tokens.access_token == oauth_token.access_token
mock_validate.assert_called_once()
@pytest.mark.anyio
async def test_exchange_code_for_token_failure(
self, oauth_provider, oauth_client_info
):
async def test_exchange_code_for_token_failure(self, oauth_provider, oauth_client_info):
"""Test failed code exchange for token."""
oauth_provider._code_verifier = "test_verifier"
@@ -520,14 +464,10 @@ class TestOAuthClientProvider:
mock_client.post.return_value = mock_response
with pytest.raises(Exception, match="Token exchange failed"):
await oauth_provider._exchange_code_for_token(
"invalid_auth_code", oauth_client_info
)
await oauth_provider._exchange_code_for_token("invalid_auth_code", oauth_client_info)
@pytest.mark.anyio
async def test_refresh_access_token_success(
self, oauth_provider, oauth_client_info, oauth_token
):
async def test_refresh_access_token_success(self, oauth_provider, oauth_client_info, oauth_token):
"""Test successful token refresh."""
oauth_provider._current_tokens = oauth_token
oauth_provider._client_info = oauth_client_info
@@ -550,16 +490,11 @@ class TestOAuthClientProvider:
mock_response.json.return_value = token_response
mock_client.post.return_value = mock_response
with patch.object(
oauth_provider, "_validate_token_scopes"
) as mock_validate:
with patch.object(oauth_provider, "_validate_token_scopes") as mock_validate:
result = await oauth_provider._refresh_access_token()
assert result is True
assert (
oauth_provider._current_tokens.access_token
== new_token.access_token
)
assert oauth_provider._current_tokens.access_token == new_token.access_token
mock_validate.assert_called_once()
@pytest.mark.anyio
@@ -575,9 +510,7 @@ class TestOAuthClientProvider:
assert result is False
@pytest.mark.anyio
async def test_refresh_access_token_failure(
self, oauth_provider, oauth_client_info, oauth_token
):
async def test_refresh_access_token_failure(self, oauth_provider, oauth_client_info, oauth_token):
"""Test failed token refresh."""
oauth_provider._current_tokens = oauth_token
oauth_provider._client_info = oauth_client_info
@@ -594,9 +527,7 @@ class TestOAuthClientProvider:
assert result is False
@pytest.mark.anyio
async def test_perform_oauth_flow_success(
self, oauth_provider, oauth_metadata, oauth_client_info
):
async def test_perform_oauth_flow_success(self, oauth_provider, oauth_metadata, oauth_client_info):
"""Test successful OAuth flow."""
oauth_provider._metadata = oauth_metadata
oauth_provider._client_info = oauth_client_info
@@ -640,9 +571,7 @@ class TestOAuthClientProvider:
mock_exchange.assert_called_once_with("test_auth_code", oauth_client_info)
@pytest.mark.anyio
async def test_perform_oauth_flow_state_mismatch(
self, oauth_provider, oauth_metadata, oauth_client_info
):
async def test_perform_oauth_flow_state_mismatch(self, oauth_provider, oauth_metadata, oauth_client_info):
"""Test OAuth flow with state parameter mismatch."""
oauth_provider._metadata = oauth_metadata
oauth_provider._client_info = oauth_client_info
@@ -678,9 +607,7 @@ class TestOAuthClientProvider:
oauth_provider._current_tokens = oauth_token
oauth_provider._token_expiry_time = time.time() - 3600 # Expired
with patch.object(
oauth_provider, "_refresh_access_token", return_value=True
) as mock_refresh:
with patch.object(oauth_provider, "_refresh_access_token", return_value=True) as mock_refresh:
await oauth_provider.ensure_token()
mock_refresh.assert_called_once()
@@ -707,10 +634,7 @@ class TestOAuthClientProvider:
auth_flow = oauth_provider.async_auth_flow(request)
updated_request = await auth_flow.__anext__()
assert (
updated_request.headers["Authorization"]
== f"Bearer {oauth_token.access_token}"
)
assert updated_request.headers["Authorization"] == f"Bearer {oauth_token.access_token}"
# Send mock response
try:
@@ -761,9 +685,7 @@ class TestOAuthClientProvider:
assert "Authorization" not in updated_request.headers
@pytest.mark.anyio
async def test_scope_priority_client_metadata_first(
self, oauth_provider, oauth_client_info
):
async def test_scope_priority_client_metadata_first(self, oauth_provider, oauth_client_info):
"""Test that client metadata scope takes priority."""
oauth_provider.client_metadata.scope = "read write"
oauth_provider._client_info = oauth_client_info
@@ -782,18 +704,13 @@ class TestOAuthClientProvider:
# Apply scope logic from _perform_oauth_flow
if oauth_provider.client_metadata.scope:
auth_params["scope"] = oauth_provider.client_metadata.scope
elif (
hasattr(oauth_provider._client_info, "scope")
and oauth_provider._client_info.scope
):
elif hasattr(oauth_provider._client_info, "scope") and oauth_provider._client_info.scope:
auth_params["scope"] = oauth_provider._client_info.scope
assert auth_params["scope"] == "read write"
@pytest.mark.anyio
async def test_scope_priority_no_client_metadata_scope(
self, oauth_provider, oauth_client_info
):
async def test_scope_priority_no_client_metadata_scope(self, oauth_provider, oauth_client_info):
"""Test that no scope parameter is set when client metadata has no scope."""
oauth_provider.client_metadata.scope = None
oauth_provider._client_info = oauth_client_info
@@ -837,10 +754,7 @@ class TestOAuthClientProvider:
# Apply scope logic from _perform_oauth_flow
if oauth_provider.client_metadata.scope:
auth_params["scope"] = oauth_provider.client_metadata.scope
elif (
hasattr(oauth_provider._client_info, "scope")
and oauth_provider._client_info.scope
):
elif hasattr(oauth_provider._client_info, "scope") and oauth_provider._client_info.scope:
auth_params["scope"] = oauth_provider._client_info.scope
# No scope should be set
@@ -866,9 +780,7 @@ class TestOAuthClientProvider:
oauth_provider.redirect_handler = mock_redirect_handler
# Patch secrets.compare_digest to verify it's being called
with patch(
"mcp.client.auth.secrets.compare_digest", return_value=False
) as mock_compare:
with patch("mcp.client.auth.secrets.compare_digest", return_value=False) as mock_compare:
with pytest.raises(Exception, match="State parameter mismatch"):
await oauth_provider._perform_oauth_flow()
@@ -876,9 +788,7 @@ class TestOAuthClientProvider:
mock_compare.assert_called_once()
@pytest.mark.anyio
async def test_state_parameter_validation_none_state(
self, oauth_provider, oauth_metadata, oauth_client_info
):
async def test_state_parameter_validation_none_state(self, oauth_provider, oauth_metadata, oauth_client_info):
"""Test that None state is handled correctly."""
oauth_provider._metadata = oauth_metadata
oauth_provider._client_info = oauth_client_info
@@ -913,9 +823,7 @@ class TestOAuthClientProvider:
mock_client.post.return_value = mock_response
with pytest.raises(Exception, match="Token exchange failed"):
await oauth_provider._exchange_code_for_token(
"invalid_auth_code", oauth_client_info
)
await oauth_provider._exchange_code_for_token("invalid_auth_code", oauth_client_info)
@pytest.mark.parametrize(
@@ -968,9 +876,7 @@ def test_build_metadata(
metadata = build_metadata(
issuer_url=AnyHttpUrl(issuer_url),
service_documentation_url=AnyHttpUrl(service_documentation_url),
client_registration_options=ClientRegistrationOptions(
enabled=True, valid_scopes=["read", "write", "admin"]
),
client_registration_options=ClientRegistrationOptions(enabled=True, valid_scopes=["read", "write", "admin"]),
revocation_options=RevocationOptions(enabled=True),
)

View File

@@ -44,9 +44,7 @@ def test_command_execution(mock_config_path: Path):
test_args = [command] + args + ["--help"]
result = subprocess.run(
test_args, capture_output=True, text=True, timeout=5, check=False
)
result = subprocess.run(test_args, capture_output=True, text=True, timeout=5, check=False)
assert result.returncode == 0
assert "usage" in result.stdout.lower()

View File

@@ -182,9 +182,7 @@ async def test_list_resource_templates_cursor_parameter(stream_spy):
# Test without cursor parameter (omitted)
_ = await client_session.list_resource_templates()
list_templates_requests = spies.get_client_requests(
method="resources/templates/list"
)
list_templates_requests = spies.get_client_requests(method="resources/templates/list")
assert len(list_templates_requests) == 1
assert list_templates_requests[0].params is None
@@ -192,9 +190,7 @@ async def test_list_resource_templates_cursor_parameter(stream_spy):
# Test with cursor=None
_ = await client_session.list_resource_templates(cursor=None)
list_templates_requests = spies.get_client_requests(
method="resources/templates/list"
)
list_templates_requests = spies.get_client_requests(method="resources/templates/list")
assert len(list_templates_requests) == 1
assert list_templates_requests[0].params is None
@@ -202,9 +198,7 @@ async def test_list_resource_templates_cursor_parameter(stream_spy):
# Test with cursor as string
_ = await client_session.list_resource_templates(cursor="some_cursor")
list_templates_requests = spies.get_client_requests(
method="resources/templates/list"
)
list_templates_requests = spies.get_client_requests(method="resources/templates/list")
assert len(list_templates_requests) == 1
assert list_templates_requests[0].params is not None
assert list_templates_requests[0].params["cursor"] == "some_cursor"
@@ -213,9 +207,7 @@ async def test_list_resource_templates_cursor_parameter(stream_spy):
# Test with empty string cursor
_ = await client_session.list_resource_templates(cursor="")
list_templates_requests = spies.get_client_requests(
method="resources/templates/list"
)
list_templates_requests = spies.get_client_requests(method="resources/templates/list")
assert len(list_templates_requests) == 1
assert list_templates_requests[0].params is not None
assert list_templates_requests[0].params["cursor"] == ""

View File

@@ -41,13 +41,9 @@ async def test_list_roots_callback():
return True
# Test with list_roots callback
async with create_session(
server._mcp_server, list_roots_callback=list_roots_callback
) as client_session:
async with create_session(server._mcp_server, list_roots_callback=list_roots_callback) as client_session:
# Make a request to trigger sampling callback
result = await client_session.call_tool(
"test_list_roots", {"message": "test message"}
)
result = await client_session.call_tool("test_list_roots", {"message": "test message"})
assert result.isError is False
assert isinstance(result.content[0], TextContent)
assert result.content[0].text == "true"
@@ -55,12 +51,7 @@ async def test_list_roots_callback():
# Test without list_roots callback
async with create_session(server._mcp_server) as client_session:
# Make a request to trigger sampling callback
result = await client_session.call_tool(
"test_list_roots", {"message": "test message"}
)
result = await client_session.call_tool("test_list_roots", {"message": "test message"})
assert result.isError is True
assert isinstance(result.content[0], TextContent)
assert (
result.content[0].text
== "Error executing tool test_list_roots: List roots not supported"
)
assert result.content[0].text == "Error executing tool test_list_roots: List roots not supported"

View File

@@ -49,9 +49,7 @@ async def test_logging_callback():
# Create a message handler to catch exceptions
async def message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult]
| types.ServerNotification
| Exception,
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
if isinstance(message, Exception):
raise message

View File

@@ -21,9 +21,7 @@ async def test_sampling_callback():
callback_return = CreateMessageResult(
role="assistant",
content=TextContent(
type="text", text="This is a response from the sampling callback"
),
content=TextContent(type="text", text="This is a response from the sampling callback"),
model="test-model",
stopReason="endTurn",
)
@@ -37,24 +35,16 @@ async def test_sampling_callback():
@server.tool("test_sampling")
async def test_sampling_tool(message: str):
value = await server.get_context().session.create_message(
messages=[
SamplingMessage(
role="user", content=TextContent(type="text", text=message)
)
],
messages=[SamplingMessage(role="user", content=TextContent(type="text", text=message))],
max_tokens=100,
)
assert value == callback_return
return True
# Test with sampling callback
async with create_session(
server._mcp_server, sampling_callback=sampling_callback
) as client_session:
async with create_session(server._mcp_server, sampling_callback=sampling_callback) as client_session:
# Make a request to trigger sampling callback
result = await client_session.call_tool(
"test_sampling", {"message": "Test message for sampling"}
)
result = await client_session.call_tool("test_sampling", {"message": "Test message for sampling"})
assert result.isError is False
assert isinstance(result.content[0], TextContent)
assert result.content[0].text == "true"
@@ -62,12 +52,7 @@ async def test_sampling_callback():
# Test without sampling callback
async with create_session(server._mcp_server) as client_session:
# Make a request to trigger sampling callback
result = await client_session.call_tool(
"test_sampling", {"message": "Test message for sampling"}
)
result = await client_session.call_tool("test_sampling", {"message": "Test message for sampling"})
assert result.isError is True
assert isinstance(result.content[0], TextContent)
assert (
result.content[0].text
== "Error executing tool test_sampling: Sampling not supported"
)
assert result.content[0].text == "Error executing tool test_sampling: Sampling not supported"

View File

@@ -28,12 +28,8 @@ from mcp.types import (
@pytest.mark.anyio
async def test_client_session_initialize():
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
SessionMessage
](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
SessionMessage
](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
initialized_notification = None
@@ -70,9 +66,7 @@ async def test_client_session_initialize():
JSONRPCResponse(
jsonrpc="2.0",
id=jsonrpc_request.root.id,
result=result.model_dump(
by_alias=True, mode="json", exclude_none=True
),
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
)
)
)
@@ -81,16 +75,12 @@ async def test_client_session_initialize():
jsonrpc_notification = session_notification.message
assert isinstance(jsonrpc_notification.root, JSONRPCNotification)
initialized_notification = ClientNotification.model_validate(
jsonrpc_notification.model_dump(
by_alias=True, mode="json", exclude_none=True
)
jsonrpc_notification.model_dump(by_alias=True, mode="json", exclude_none=True)
)
# Create a message handler to catch exceptions
async def message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult]
| types.ServerNotification
| Exception,
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
if isinstance(message, Exception):
raise message
@@ -124,12 +114,8 @@ async def test_client_session_initialize():
@pytest.mark.anyio
async def test_client_session_custom_client_info():
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
SessionMessage
](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
SessionMessage
](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
custom_client_info = Implementation(name="test-client", version="1.2.3")
received_client_info = None
@@ -161,9 +147,7 @@ async def test_client_session_custom_client_info():
JSONRPCResponse(
jsonrpc="2.0",
id=jsonrpc_request.root.id,
result=result.model_dump(
by_alias=True, mode="json", exclude_none=True
),
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
)
)
)
@@ -192,12 +176,8 @@ async def test_client_session_custom_client_info():
@pytest.mark.anyio
async def test_client_session_default_client_info():
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
SessionMessage
](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
SessionMessage
](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
received_client_info = None
@@ -228,9 +208,7 @@ async def test_client_session_default_client_info():
JSONRPCResponse(
jsonrpc="2.0",
id=jsonrpc_request.root.id,
result=result.model_dump(
by_alias=True, mode="json", exclude_none=True
),
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
)
)
)
@@ -259,12 +237,8 @@ async def test_client_session_default_client_info():
@pytest.mark.anyio
async def test_client_session_version_negotiation_success():
"""Test successful version negotiation with supported version"""
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
SessionMessage
](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
SessionMessage
](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
async def mock_server():
session_message = await client_to_server_receive.receive()
@@ -294,9 +268,7 @@ async def test_client_session_version_negotiation_success():
JSONRPCResponse(
jsonrpc="2.0",
id=jsonrpc_request.root.id,
result=result.model_dump(
by_alias=True, mode="json", exclude_none=True
),
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
)
)
)
@@ -327,12 +299,8 @@ async def test_client_session_version_negotiation_success():
@pytest.mark.anyio
async def test_client_session_version_negotiation_failure():
"""Test version negotiation failure with unsupported version"""
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
SessionMessage
](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
SessionMessage
](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
async def mock_server():
session_message = await client_to_server_receive.receive()
@@ -359,9 +327,7 @@ async def test_client_session_version_negotiation_failure():
JSONRPCResponse(
jsonrpc="2.0",
id=jsonrpc_request.root.id,
result=result.model_dump(
by_alias=True, mode="json", exclude_none=True
),
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
)
)
)
@@ -388,12 +354,8 @@ async def test_client_session_version_negotiation_failure():
@pytest.mark.anyio
async def test_client_capabilities_default():
"""Test that client capabilities are properly set with default callbacks"""
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
SessionMessage
](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
SessionMessage
](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
received_capabilities = None
@@ -424,9 +386,7 @@ async def test_client_capabilities_default():
JSONRPCResponse(
jsonrpc="2.0",
id=jsonrpc_request.root.id,
result=result.model_dump(
by_alias=True, mode="json", exclude_none=True
),
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
)
)
)
@@ -457,12 +417,8 @@ async def test_client_capabilities_default():
@pytest.mark.anyio
async def test_client_capabilities_with_custom_callbacks():
"""Test that client capabilities are properly set with custom callbacks"""
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
SessionMessage
](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
SessionMessage
](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
received_capabilities = None
@@ -508,9 +464,7 @@ async def test_client_capabilities_with_custom_callbacks():
JSONRPCResponse(
jsonrpc="2.0",
id=jsonrpc_request.root.id,
result=result.model_dump(
by_alias=True, mode="json", exclude_none=True
),
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
)
)
)
@@ -536,14 +490,8 @@ async def test_client_capabilities_with_custom_callbacks():
# Assert that capabilities are properly set with custom callbacks
assert received_capabilities is not None
assert (
received_capabilities.sampling is not None
) # Custom sampling callback provided
assert received_capabilities.sampling is not None # Custom sampling callback provided
assert isinstance(received_capabilities.sampling, types.SamplingCapability)
assert (
received_capabilities.roots is not None
) # Custom list_roots callback provided
assert received_capabilities.roots is not None # Custom list_roots callback provided
assert isinstance(received_capabilities.roots, types.RootsCapability)
assert (
received_capabilities.roots.listChanged is True
) # Should be True for custom callback
assert received_capabilities.roots.listChanged is True # Should be True for custom callback

View File

@@ -58,14 +58,10 @@ class TestClientSessionGroup:
return f"{(server_info.name)}-{name}"
mcp_session_group = ClientSessionGroup(component_name_hook=hook)
mcp_session_group._tools = {
"server1-my_tool": types.Tool(name="my_tool", inputSchema={})
}
mcp_session_group._tools = {"server1-my_tool": types.Tool(name="my_tool", inputSchema={})}
mcp_session_group._tool_to_session = {"server1-my_tool": mock_session}
text_content = types.TextContent(type="text", text="OK")
mock_session.call_tool.return_value = types.CallToolResult(
content=[text_content]
)
mock_session.call_tool.return_value = types.CallToolResult(content=[text_content])
# --- Test Execution ---
result = await mcp_session_group.call_tool(
@@ -96,16 +92,12 @@ class TestClientSessionGroup:
mock_prompt1 = mock.Mock(spec=types.Prompt)
mock_prompt1.name = "prompt_c"
mock_session.list_tools.return_value = mock.AsyncMock(tools=[mock_tool1])
mock_session.list_resources.return_value = mock.AsyncMock(
resources=[mock_resource1]
)
mock_session.list_resources.return_value = mock.AsyncMock(resources=[mock_resource1])
mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[mock_prompt1])
# --- Test Execution ---
group = ClientSessionGroup(exit_stack=mock_exit_stack)
with mock.patch.object(
group, "_establish_session", return_value=(mock_server_info, mock_session)
):
with mock.patch.object(group, "_establish_session", return_value=(mock_server_info, mock_session)):
await group.connect_to_server(StdioServerParameters(command="test"))
# --- Assertions ---
@@ -141,12 +133,8 @@ class TestClientSessionGroup:
return f"{server_info.name}.{name}"
# --- Test Execution ---
group = ClientSessionGroup(
exit_stack=mock_exit_stack, component_name_hook=name_hook
)
with mock.patch.object(
group, "_establish_session", return_value=(mock_server_info, mock_session)
):
group = ClientSessionGroup(exit_stack=mock_exit_stack, component_name_hook=name_hook)
with mock.patch.object(group, "_establish_session", return_value=(mock_server_info, mock_session)):
await group.connect_to_server(StdioServerParameters(command="test"))
# --- Assertions ---
@@ -231,9 +219,7 @@ class TestClientSessionGroup:
# Need a dummy session associated with the existing tool
mock_session = mock.MagicMock(spec=mcp.ClientSession)
group._tool_to_session[existing_tool_name] = mock_session
group._session_exit_stacks[mock_session] = mock.Mock(
spec=contextlib.AsyncExitStack
)
group._session_exit_stacks[mock_session] = mock.Mock(spec=contextlib.AsyncExitStack)
# --- Mock New Connection Attempt ---
mock_server_info_new = mock.Mock(spec=types.Implementation)
@@ -243,9 +229,7 @@ class TestClientSessionGroup:
# Configure the new session to return a tool with the *same name*
duplicate_tool = mock.Mock(spec=types.Tool)
duplicate_tool.name = existing_tool_name
mock_session_new.list_tools.return_value = mock.AsyncMock(
tools=[duplicate_tool]
)
mock_session_new.list_tools.return_value = mock.AsyncMock(tools=[duplicate_tool])
# Keep other lists empty for simplicity
mock_session_new.list_resources.return_value = mock.AsyncMock(resources=[])
mock_session_new.list_prompts.return_value = mock.AsyncMock(prompts=[])
@@ -266,9 +250,7 @@ class TestClientSessionGroup:
# Verify the duplicate tool was *not* added again (state should be unchanged)
assert len(group._tools) == 1 # Should still only have the original
assert (
group._tools[existing_tool_name] is not duplicate_tool
) # Ensure it's the original mock
assert group._tools[existing_tool_name] is not duplicate_tool # Ensure it's the original mock
# No patching needed here
async def test_disconnect_non_existent_server(self):
@@ -292,9 +274,7 @@ class TestClientSessionGroup:
"mcp.client.session_group.sse_client",
), # url, headers, timeout, sse_read_timeout
(
StreamableHttpParameters(
url="http://test.com/stream", terminate_on_close=False
),
StreamableHttpParameters(url="http://test.com/stream", terminate_on_close=False),
"streamablehttp",
"mcp.client.session_group.streamablehttp_client",
), # url, headers, timeout, sse_read_timeout, terminate_on_close
@@ -306,13 +286,9 @@ class TestClientSessionGroup:
client_type_name, # Just for clarity or conditional logic if needed
patch_target_for_client_func,
):
with mock.patch(
"mcp.client.session_group.mcp.ClientSession"
) as mock_ClientSession_class:
with mock.patch("mcp.client.session_group.mcp.ClientSession") as mock_ClientSession_class:
with mock.patch(patch_target_for_client_func) as mock_specific_client_func:
mock_client_cm_instance = mock.AsyncMock(
name=f"{client_type_name}ClientCM"
)
mock_client_cm_instance = mock.AsyncMock(name=f"{client_type_name}ClientCM")
mock_read_stream = mock.AsyncMock(name=f"{client_type_name}Read")
mock_write_stream = mock.AsyncMock(name=f"{client_type_name}Write")
@@ -344,9 +320,7 @@ class TestClientSessionGroup:
# Mock session.initialize()
mock_initialize_result = mock.AsyncMock(name="InitializeResult")
mock_initialize_result.serverInfo = types.Implementation(
name="foo", version="1"
)
mock_initialize_result.serverInfo = types.Implementation(name="foo", version="1")
mock_entered_session.initialize.return_value = mock_initialize_result
# --- Test Execution ---
@@ -364,9 +338,7 @@ class TestClientSessionGroup:
# --- Assertions ---
# 1. Assert the correct specific client function was called
if client_type_name == "stdio":
mock_specific_client_func.assert_called_once_with(
server_params_instance
)
mock_specific_client_func.assert_called_once_with(server_params_instance)
elif client_type_name == "sse":
mock_specific_client_func.assert_called_once_with(
url=server_params_instance.url,
@@ -386,9 +358,7 @@ class TestClientSessionGroup:
mock_client_cm_instance.__aenter__.assert_awaited_once()
# 2. Assert ClientSession was called correctly
mock_ClientSession_class.assert_called_once_with(
mock_read_stream, mock_write_stream
)
mock_ClientSession_class.assert_called_once_with(mock_read_stream, mock_write_stream)
mock_raw_session_cm.__aenter__.assert_awaited_once()
mock_entered_session.initialize.assert_awaited_once()

View File

@@ -50,20 +50,14 @@ async def test_stdio_client():
break
assert len(read_messages) == 2
assert read_messages[0] == JSONRPCMessage(
root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")
)
assert read_messages[1] == JSONRPCMessage(
root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})
)
assert read_messages[0] == JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping"))
assert read_messages[1] == JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={}))
@pytest.mark.anyio
async def test_stdio_client_bad_path():
"""Check that the connection doesn't hang if process errors."""
server_params = StdioServerParameters(
command="python", args=["-c", "non-existent-file.py"]
)
server_params = StdioServerParameters(command="python", args=["-c", "non-existent-file.py"])
async with stdio_client(server_params) as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as session:
# The session should raise an error when the connection closes