mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-24 01:04:20 +01:00
Use 120 characters instead of 88 (#856)
This commit is contained in:
committed by
GitHub
parent
f7265f7b91
commit
543961968c
@@ -49,8 +49,7 @@ class StreamSpyCollection:
|
||||
return [
|
||||
req.message.root
|
||||
for req in self.client.sent_messages
|
||||
if isinstance(req.message.root, JSONRPCRequest)
|
||||
and (method is None or req.message.root.method == method)
|
||||
if isinstance(req.message.root, JSONRPCRequest) and (method is None or req.message.root.method == method)
|
||||
]
|
||||
|
||||
def get_server_requests(self, method: str | None = None) -> list[JSONRPCRequest]:
|
||||
@@ -58,13 +57,10 @@ class StreamSpyCollection:
|
||||
return [
|
||||
req.message.root
|
||||
for req in self.server.sent_messages
|
||||
if isinstance(req.message.root, JSONRPCRequest)
|
||||
and (method is None or req.message.root.method == method)
|
||||
if isinstance(req.message.root, JSONRPCRequest) and (method is None or req.message.root.method == method)
|
||||
]
|
||||
|
||||
def get_client_notifications(
|
||||
self, method: str | None = None
|
||||
) -> list[JSONRPCNotification]:
|
||||
def get_client_notifications(self, method: str | None = None) -> list[JSONRPCNotification]:
|
||||
"""Get client-sent notifications, optionally filtered by method."""
|
||||
return [
|
||||
notif.message.root
|
||||
@@ -73,9 +69,7 @@ class StreamSpyCollection:
|
||||
and (method is None or notif.message.root.method == method)
|
||||
]
|
||||
|
||||
def get_server_notifications(
|
||||
self, method: str | None = None
|
||||
) -> list[JSONRPCNotification]:
|
||||
def get_server_notifications(self, method: str | None = None) -> list[JSONRPCNotification]:
|
||||
"""Get server-sent notifications, optionally filtered by method."""
|
||||
return [
|
||||
notif.message.root
|
||||
@@ -133,9 +127,7 @@ def stream_spy():
|
||||
yield (client_read, spy_client_write), (server_read, spy_server_write)
|
||||
|
||||
# Apply the patch for the duration of the test
|
||||
with patch(
|
||||
"mcp.shared.memory.create_client_server_memory_streams", patched_create_streams
|
||||
):
|
||||
with patch("mcp.shared.memory.create_client_server_memory_streams", patched_create_streams):
|
||||
# Return a collection with helper methods
|
||||
def get_spy_collection() -> StreamSpyCollection:
|
||||
assert client_spy is not None, "client_spy was not initialized"
|
||||
|
||||
@@ -134,9 +134,7 @@ class TestOAuthClientProvider:
|
||||
assert len(verifier) == 128
|
||||
|
||||
# Check charset (RFC 7636: A-Z, a-z, 0-9, "-", ".", "_", "~")
|
||||
allowed_chars = set(
|
||||
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
|
||||
)
|
||||
allowed_chars = set("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~")
|
||||
assert set(verifier) <= allowed_chars
|
||||
|
||||
# Check uniqueness (generate multiple and ensure they're different)
|
||||
@@ -151,9 +149,7 @@ class TestOAuthClientProvider:
|
||||
|
||||
# Manually calculate expected challenge
|
||||
expected_digest = hashlib.sha256(verifier.encode()).digest()
|
||||
expected_challenge = (
|
||||
base64.urlsafe_b64encode(expected_digest).decode().rstrip("=")
|
||||
)
|
||||
expected_challenge = base64.urlsafe_b64encode(expected_digest).decode().rstrip("=")
|
||||
|
||||
assert challenge == expected_challenge
|
||||
|
||||
@@ -166,29 +162,19 @@ class TestOAuthClientProvider:
|
||||
async def test_get_authorization_base_url(self, oauth_provider):
|
||||
"""Test authorization base URL extraction."""
|
||||
# Test with path
|
||||
assert (
|
||||
oauth_provider._get_authorization_base_url("https://api.example.com/v1/mcp")
|
||||
== "https://api.example.com"
|
||||
)
|
||||
assert oauth_provider._get_authorization_base_url("https://api.example.com/v1/mcp") == "https://api.example.com"
|
||||
|
||||
# Test with no path
|
||||
assert (
|
||||
oauth_provider._get_authorization_base_url("https://api.example.com")
|
||||
== "https://api.example.com"
|
||||
)
|
||||
assert oauth_provider._get_authorization_base_url("https://api.example.com") == "https://api.example.com"
|
||||
|
||||
# Test with port
|
||||
assert (
|
||||
oauth_provider._get_authorization_base_url(
|
||||
"https://api.example.com:8080/path/to/mcp"
|
||||
)
|
||||
oauth_provider._get_authorization_base_url("https://api.example.com:8080/path/to/mcp")
|
||||
== "https://api.example.com:8080"
|
||||
)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_discover_oauth_metadata_success(
|
||||
self, oauth_provider, oauth_metadata
|
||||
):
|
||||
async def test_discover_oauth_metadata_success(self, oauth_provider, oauth_metadata):
|
||||
"""Test successful OAuth metadata discovery."""
|
||||
metadata_response = oauth_metadata.model_dump(by_alias=True, mode="json")
|
||||
|
||||
@@ -201,23 +187,16 @@ class TestOAuthClientProvider:
|
||||
mock_response.json.return_value = metadata_response
|
||||
mock_client.get.return_value = mock_response
|
||||
|
||||
result = await oauth_provider._discover_oauth_metadata(
|
||||
"https://api.example.com/v1/mcp"
|
||||
)
|
||||
result = await oauth_provider._discover_oauth_metadata("https://api.example.com/v1/mcp")
|
||||
|
||||
assert result is not None
|
||||
assert (
|
||||
result.authorization_endpoint == oauth_metadata.authorization_endpoint
|
||||
)
|
||||
assert result.authorization_endpoint == oauth_metadata.authorization_endpoint
|
||||
assert result.token_endpoint == oauth_metadata.token_endpoint
|
||||
|
||||
# Verify correct URL was called
|
||||
mock_client.get.assert_called_once()
|
||||
call_args = mock_client.get.call_args[0]
|
||||
assert (
|
||||
call_args[0]
|
||||
== "https://api.example.com/.well-known/oauth-authorization-server"
|
||||
)
|
||||
assert call_args[0] == "https://api.example.com/.well-known/oauth-authorization-server"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_discover_oauth_metadata_not_found(self, oauth_provider):
|
||||
@@ -230,16 +209,12 @@ class TestOAuthClientProvider:
|
||||
mock_response.status_code = 404
|
||||
mock_client.get.return_value = mock_response
|
||||
|
||||
result = await oauth_provider._discover_oauth_metadata(
|
||||
"https://api.example.com/v1/mcp"
|
||||
)
|
||||
result = await oauth_provider._discover_oauth_metadata("https://api.example.com/v1/mcp")
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_discover_oauth_metadata_cors_fallback(
|
||||
self, oauth_provider, oauth_metadata
|
||||
):
|
||||
async def test_discover_oauth_metadata_cors_fallback(self, oauth_provider, oauth_metadata):
|
||||
"""Test OAuth metadata discovery with CORS fallback."""
|
||||
metadata_response = oauth_metadata.model_dump(by_alias=True, mode="json")
|
||||
|
||||
@@ -257,17 +232,13 @@ class TestOAuthClientProvider:
|
||||
mock_response_success, # Second call succeeds
|
||||
]
|
||||
|
||||
result = await oauth_provider._discover_oauth_metadata(
|
||||
"https://api.example.com/v1/mcp"
|
||||
)
|
||||
result = await oauth_provider._discover_oauth_metadata("https://api.example.com/v1/mcp")
|
||||
|
||||
assert result is not None
|
||||
assert mock_client.get.call_count == 2
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_register_oauth_client_success(
|
||||
self, oauth_provider, oauth_metadata, oauth_client_info
|
||||
):
|
||||
async def test_register_oauth_client_success(self, oauth_provider, oauth_metadata, oauth_client_info):
|
||||
"""Test successful OAuth client registration."""
|
||||
registration_response = oauth_client_info.model_dump(by_alias=True, mode="json")
|
||||
|
||||
@@ -295,9 +266,7 @@ class TestOAuthClientProvider:
|
||||
assert call_args[0][0] == str(oauth_metadata.registration_endpoint)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_register_oauth_client_fallback_endpoint(
|
||||
self, oauth_provider, oauth_client_info
|
||||
):
|
||||
async def test_register_oauth_client_fallback_endpoint(self, oauth_provider, oauth_client_info):
|
||||
"""Test OAuth client registration with fallback endpoint."""
|
||||
registration_response = oauth_client_info.model_dump(by_alias=True, mode="json")
|
||||
|
||||
@@ -311,9 +280,7 @@ class TestOAuthClientProvider:
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
# Mock metadata discovery to return None (fallback)
|
||||
with patch.object(
|
||||
oauth_provider, "_discover_oauth_metadata", return_value=None
|
||||
):
|
||||
with patch.object(oauth_provider, "_discover_oauth_metadata", return_value=None):
|
||||
result = await oauth_provider._register_oauth_client(
|
||||
"https://api.example.com/v1/mcp",
|
||||
oauth_provider.client_metadata,
|
||||
@@ -340,9 +307,7 @@ class TestOAuthClientProvider:
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
# Mock metadata discovery to return None (fallback)
|
||||
with patch.object(
|
||||
oauth_provider, "_discover_oauth_metadata", return_value=None
|
||||
):
|
||||
with patch.object(oauth_provider, "_discover_oauth_metadata", return_value=None):
|
||||
with pytest.raises(httpx.HTTPStatusError):
|
||||
await oauth_provider._register_oauth_client(
|
||||
"https://api.example.com/v1/mcp",
|
||||
@@ -406,9 +371,7 @@ class TestOAuthClientProvider:
|
||||
await oauth_provider._validate_token_scopes(token)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_validate_token_scopes_unauthorized(
|
||||
self, oauth_provider, client_metadata
|
||||
):
|
||||
async def test_validate_token_scopes_unauthorized(self, oauth_provider, client_metadata):
|
||||
"""Test scope validation with unauthorized scopes."""
|
||||
oauth_provider.client_metadata = client_metadata
|
||||
token = OAuthToken(
|
||||
@@ -436,9 +399,7 @@ class TestOAuthClientProvider:
|
||||
await oauth_provider._validate_token_scopes(token)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_initialize(
|
||||
self, oauth_provider, mock_storage, oauth_token, oauth_client_info
|
||||
):
|
||||
async def test_initialize(self, oauth_provider, mock_storage, oauth_token, oauth_client_info):
|
||||
"""Test initialization loading from storage."""
|
||||
mock_storage._tokens = oauth_token
|
||||
mock_storage._client_info = oauth_client_info
|
||||
@@ -449,9 +410,7 @@ class TestOAuthClientProvider:
|
||||
assert oauth_provider._client_info == oauth_client_info
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_or_register_client_existing(
|
||||
self, oauth_provider, oauth_client_info
|
||||
):
|
||||
async def test_get_or_register_client_existing(self, oauth_provider, oauth_client_info):
|
||||
"""Test getting existing client info."""
|
||||
oauth_provider._client_info = oauth_client_info
|
||||
|
||||
@@ -460,13 +419,9 @@ class TestOAuthClientProvider:
|
||||
assert result == oauth_client_info
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_or_register_client_register_new(
|
||||
self, oauth_provider, oauth_client_info
|
||||
):
|
||||
async def test_get_or_register_client_register_new(self, oauth_provider, oauth_client_info):
|
||||
"""Test registering new client."""
|
||||
with patch.object(
|
||||
oauth_provider, "_register_oauth_client", return_value=oauth_client_info
|
||||
) as mock_register:
|
||||
with patch.object(oauth_provider, "_register_oauth_client", return_value=oauth_client_info) as mock_register:
|
||||
result = await oauth_provider._get_or_register_client()
|
||||
|
||||
assert result == oauth_client_info
|
||||
@@ -474,9 +429,7 @@ class TestOAuthClientProvider:
|
||||
mock_register.assert_called_once()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_exchange_code_for_token_success(
|
||||
self, oauth_provider, oauth_client_info, oauth_token
|
||||
):
|
||||
async def test_exchange_code_for_token_success(self, oauth_provider, oauth_client_info, oauth_token):
|
||||
"""Test successful code exchange for token."""
|
||||
oauth_provider._code_verifier = "test_verifier"
|
||||
token_response = oauth_token.model_dump(by_alias=True, mode="json")
|
||||
@@ -490,23 +443,14 @@ class TestOAuthClientProvider:
|
||||
mock_response.json.return_value = token_response
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(
|
||||
oauth_provider, "_validate_token_scopes"
|
||||
) as mock_validate:
|
||||
await oauth_provider._exchange_code_for_token(
|
||||
"test_auth_code", oauth_client_info
|
||||
)
|
||||
with patch.object(oauth_provider, "_validate_token_scopes") as mock_validate:
|
||||
await oauth_provider._exchange_code_for_token("test_auth_code", oauth_client_info)
|
||||
|
||||
assert (
|
||||
oauth_provider._current_tokens.access_token
|
||||
== oauth_token.access_token
|
||||
)
|
||||
assert oauth_provider._current_tokens.access_token == oauth_token.access_token
|
||||
mock_validate.assert_called_once()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_exchange_code_for_token_failure(
|
||||
self, oauth_provider, oauth_client_info
|
||||
):
|
||||
async def test_exchange_code_for_token_failure(self, oauth_provider, oauth_client_info):
|
||||
"""Test failed code exchange for token."""
|
||||
oauth_provider._code_verifier = "test_verifier"
|
||||
|
||||
@@ -520,14 +464,10 @@ class TestOAuthClientProvider:
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with pytest.raises(Exception, match="Token exchange failed"):
|
||||
await oauth_provider._exchange_code_for_token(
|
||||
"invalid_auth_code", oauth_client_info
|
||||
)
|
||||
await oauth_provider._exchange_code_for_token("invalid_auth_code", oauth_client_info)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_refresh_access_token_success(
|
||||
self, oauth_provider, oauth_client_info, oauth_token
|
||||
):
|
||||
async def test_refresh_access_token_success(self, oauth_provider, oauth_client_info, oauth_token):
|
||||
"""Test successful token refresh."""
|
||||
oauth_provider._current_tokens = oauth_token
|
||||
oauth_provider._client_info = oauth_client_info
|
||||
@@ -550,16 +490,11 @@ class TestOAuthClientProvider:
|
||||
mock_response.json.return_value = token_response
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(
|
||||
oauth_provider, "_validate_token_scopes"
|
||||
) as mock_validate:
|
||||
with patch.object(oauth_provider, "_validate_token_scopes") as mock_validate:
|
||||
result = await oauth_provider._refresh_access_token()
|
||||
|
||||
assert result is True
|
||||
assert (
|
||||
oauth_provider._current_tokens.access_token
|
||||
== new_token.access_token
|
||||
)
|
||||
assert oauth_provider._current_tokens.access_token == new_token.access_token
|
||||
mock_validate.assert_called_once()
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -575,9 +510,7 @@ class TestOAuthClientProvider:
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_refresh_access_token_failure(
|
||||
self, oauth_provider, oauth_client_info, oauth_token
|
||||
):
|
||||
async def test_refresh_access_token_failure(self, oauth_provider, oauth_client_info, oauth_token):
|
||||
"""Test failed token refresh."""
|
||||
oauth_provider._current_tokens = oauth_token
|
||||
oauth_provider._client_info = oauth_client_info
|
||||
@@ -594,9 +527,7 @@ class TestOAuthClientProvider:
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_perform_oauth_flow_success(
|
||||
self, oauth_provider, oauth_metadata, oauth_client_info
|
||||
):
|
||||
async def test_perform_oauth_flow_success(self, oauth_provider, oauth_metadata, oauth_client_info):
|
||||
"""Test successful OAuth flow."""
|
||||
oauth_provider._metadata = oauth_metadata
|
||||
oauth_provider._client_info = oauth_client_info
|
||||
@@ -640,9 +571,7 @@ class TestOAuthClientProvider:
|
||||
mock_exchange.assert_called_once_with("test_auth_code", oauth_client_info)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_perform_oauth_flow_state_mismatch(
|
||||
self, oauth_provider, oauth_metadata, oauth_client_info
|
||||
):
|
||||
async def test_perform_oauth_flow_state_mismatch(self, oauth_provider, oauth_metadata, oauth_client_info):
|
||||
"""Test OAuth flow with state parameter mismatch."""
|
||||
oauth_provider._metadata = oauth_metadata
|
||||
oauth_provider._client_info = oauth_client_info
|
||||
@@ -678,9 +607,7 @@ class TestOAuthClientProvider:
|
||||
oauth_provider._current_tokens = oauth_token
|
||||
oauth_provider._token_expiry_time = time.time() - 3600 # Expired
|
||||
|
||||
with patch.object(
|
||||
oauth_provider, "_refresh_access_token", return_value=True
|
||||
) as mock_refresh:
|
||||
with patch.object(oauth_provider, "_refresh_access_token", return_value=True) as mock_refresh:
|
||||
await oauth_provider.ensure_token()
|
||||
mock_refresh.assert_called_once()
|
||||
|
||||
@@ -707,10 +634,7 @@ class TestOAuthClientProvider:
|
||||
auth_flow = oauth_provider.async_auth_flow(request)
|
||||
updated_request = await auth_flow.__anext__()
|
||||
|
||||
assert (
|
||||
updated_request.headers["Authorization"]
|
||||
== f"Bearer {oauth_token.access_token}"
|
||||
)
|
||||
assert updated_request.headers["Authorization"] == f"Bearer {oauth_token.access_token}"
|
||||
|
||||
# Send mock response
|
||||
try:
|
||||
@@ -761,9 +685,7 @@ class TestOAuthClientProvider:
|
||||
assert "Authorization" not in updated_request.headers
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_scope_priority_client_metadata_first(
|
||||
self, oauth_provider, oauth_client_info
|
||||
):
|
||||
async def test_scope_priority_client_metadata_first(self, oauth_provider, oauth_client_info):
|
||||
"""Test that client metadata scope takes priority."""
|
||||
oauth_provider.client_metadata.scope = "read write"
|
||||
oauth_provider._client_info = oauth_client_info
|
||||
@@ -782,18 +704,13 @@ class TestOAuthClientProvider:
|
||||
# Apply scope logic from _perform_oauth_flow
|
||||
if oauth_provider.client_metadata.scope:
|
||||
auth_params["scope"] = oauth_provider.client_metadata.scope
|
||||
elif (
|
||||
hasattr(oauth_provider._client_info, "scope")
|
||||
and oauth_provider._client_info.scope
|
||||
):
|
||||
elif hasattr(oauth_provider._client_info, "scope") and oauth_provider._client_info.scope:
|
||||
auth_params["scope"] = oauth_provider._client_info.scope
|
||||
|
||||
assert auth_params["scope"] == "read write"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_scope_priority_no_client_metadata_scope(
|
||||
self, oauth_provider, oauth_client_info
|
||||
):
|
||||
async def test_scope_priority_no_client_metadata_scope(self, oauth_provider, oauth_client_info):
|
||||
"""Test that no scope parameter is set when client metadata has no scope."""
|
||||
oauth_provider.client_metadata.scope = None
|
||||
oauth_provider._client_info = oauth_client_info
|
||||
@@ -837,10 +754,7 @@ class TestOAuthClientProvider:
|
||||
# Apply scope logic from _perform_oauth_flow
|
||||
if oauth_provider.client_metadata.scope:
|
||||
auth_params["scope"] = oauth_provider.client_metadata.scope
|
||||
elif (
|
||||
hasattr(oauth_provider._client_info, "scope")
|
||||
and oauth_provider._client_info.scope
|
||||
):
|
||||
elif hasattr(oauth_provider._client_info, "scope") and oauth_provider._client_info.scope:
|
||||
auth_params["scope"] = oauth_provider._client_info.scope
|
||||
|
||||
# No scope should be set
|
||||
@@ -866,9 +780,7 @@ class TestOAuthClientProvider:
|
||||
oauth_provider.redirect_handler = mock_redirect_handler
|
||||
|
||||
# Patch secrets.compare_digest to verify it's being called
|
||||
with patch(
|
||||
"mcp.client.auth.secrets.compare_digest", return_value=False
|
||||
) as mock_compare:
|
||||
with patch("mcp.client.auth.secrets.compare_digest", return_value=False) as mock_compare:
|
||||
with pytest.raises(Exception, match="State parameter mismatch"):
|
||||
await oauth_provider._perform_oauth_flow()
|
||||
|
||||
@@ -876,9 +788,7 @@ class TestOAuthClientProvider:
|
||||
mock_compare.assert_called_once()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_state_parameter_validation_none_state(
|
||||
self, oauth_provider, oauth_metadata, oauth_client_info
|
||||
):
|
||||
async def test_state_parameter_validation_none_state(self, oauth_provider, oauth_metadata, oauth_client_info):
|
||||
"""Test that None state is handled correctly."""
|
||||
oauth_provider._metadata = oauth_metadata
|
||||
oauth_provider._client_info = oauth_client_info
|
||||
@@ -913,9 +823,7 @@ class TestOAuthClientProvider:
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with pytest.raises(Exception, match="Token exchange failed"):
|
||||
await oauth_provider._exchange_code_for_token(
|
||||
"invalid_auth_code", oauth_client_info
|
||||
)
|
||||
await oauth_provider._exchange_code_for_token("invalid_auth_code", oauth_client_info)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -968,9 +876,7 @@ def test_build_metadata(
|
||||
metadata = build_metadata(
|
||||
issuer_url=AnyHttpUrl(issuer_url),
|
||||
service_documentation_url=AnyHttpUrl(service_documentation_url),
|
||||
client_registration_options=ClientRegistrationOptions(
|
||||
enabled=True, valid_scopes=["read", "write", "admin"]
|
||||
),
|
||||
client_registration_options=ClientRegistrationOptions(enabled=True, valid_scopes=["read", "write", "admin"]),
|
||||
revocation_options=RevocationOptions(enabled=True),
|
||||
)
|
||||
|
||||
|
||||
@@ -44,9 +44,7 @@ def test_command_execution(mock_config_path: Path):
|
||||
|
||||
test_args = [command] + args + ["--help"]
|
||||
|
||||
result = subprocess.run(
|
||||
test_args, capture_output=True, text=True, timeout=5, check=False
|
||||
)
|
||||
result = subprocess.run(test_args, capture_output=True, text=True, timeout=5, check=False)
|
||||
|
||||
assert result.returncode == 0
|
||||
assert "usage" in result.stdout.lower()
|
||||
|
||||
@@ -182,9 +182,7 @@ async def test_list_resource_templates_cursor_parameter(stream_spy):
|
||||
|
||||
# Test without cursor parameter (omitted)
|
||||
_ = await client_session.list_resource_templates()
|
||||
list_templates_requests = spies.get_client_requests(
|
||||
method="resources/templates/list"
|
||||
)
|
||||
list_templates_requests = spies.get_client_requests(method="resources/templates/list")
|
||||
assert len(list_templates_requests) == 1
|
||||
assert list_templates_requests[0].params is None
|
||||
|
||||
@@ -192,9 +190,7 @@ async def test_list_resource_templates_cursor_parameter(stream_spy):
|
||||
|
||||
# Test with cursor=None
|
||||
_ = await client_session.list_resource_templates(cursor=None)
|
||||
list_templates_requests = spies.get_client_requests(
|
||||
method="resources/templates/list"
|
||||
)
|
||||
list_templates_requests = spies.get_client_requests(method="resources/templates/list")
|
||||
assert len(list_templates_requests) == 1
|
||||
assert list_templates_requests[0].params is None
|
||||
|
||||
@@ -202,9 +198,7 @@ async def test_list_resource_templates_cursor_parameter(stream_spy):
|
||||
|
||||
# Test with cursor as string
|
||||
_ = await client_session.list_resource_templates(cursor="some_cursor")
|
||||
list_templates_requests = spies.get_client_requests(
|
||||
method="resources/templates/list"
|
||||
)
|
||||
list_templates_requests = spies.get_client_requests(method="resources/templates/list")
|
||||
assert len(list_templates_requests) == 1
|
||||
assert list_templates_requests[0].params is not None
|
||||
assert list_templates_requests[0].params["cursor"] == "some_cursor"
|
||||
@@ -213,9 +207,7 @@ async def test_list_resource_templates_cursor_parameter(stream_spy):
|
||||
|
||||
# Test with empty string cursor
|
||||
_ = await client_session.list_resource_templates(cursor="")
|
||||
list_templates_requests = spies.get_client_requests(
|
||||
method="resources/templates/list"
|
||||
)
|
||||
list_templates_requests = spies.get_client_requests(method="resources/templates/list")
|
||||
assert len(list_templates_requests) == 1
|
||||
assert list_templates_requests[0].params is not None
|
||||
assert list_templates_requests[0].params["cursor"] == ""
|
||||
|
||||
@@ -41,13 +41,9 @@ async def test_list_roots_callback():
|
||||
return True
|
||||
|
||||
# Test with list_roots callback
|
||||
async with create_session(
|
||||
server._mcp_server, list_roots_callback=list_roots_callback
|
||||
) as client_session:
|
||||
async with create_session(server._mcp_server, list_roots_callback=list_roots_callback) as client_session:
|
||||
# Make a request to trigger sampling callback
|
||||
result = await client_session.call_tool(
|
||||
"test_list_roots", {"message": "test message"}
|
||||
)
|
||||
result = await client_session.call_tool("test_list_roots", {"message": "test message"})
|
||||
assert result.isError is False
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
assert result.content[0].text == "true"
|
||||
@@ -55,12 +51,7 @@ async def test_list_roots_callback():
|
||||
# Test without list_roots callback
|
||||
async with create_session(server._mcp_server) as client_session:
|
||||
# Make a request to trigger sampling callback
|
||||
result = await client_session.call_tool(
|
||||
"test_list_roots", {"message": "test message"}
|
||||
)
|
||||
result = await client_session.call_tool("test_list_roots", {"message": "test message"})
|
||||
assert result.isError is True
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
assert (
|
||||
result.content[0].text
|
||||
== "Error executing tool test_list_roots: List roots not supported"
|
||||
)
|
||||
assert result.content[0].text == "Error executing tool test_list_roots: List roots not supported"
|
||||
|
||||
@@ -49,9 +49,7 @@ async def test_logging_callback():
|
||||
|
||||
# Create a message handler to catch exceptions
|
||||
async def message_handler(
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult]
|
||||
| types.ServerNotification
|
||||
| Exception,
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
||||
) -> None:
|
||||
if isinstance(message, Exception):
|
||||
raise message
|
||||
|
||||
@@ -21,9 +21,7 @@ async def test_sampling_callback():
|
||||
|
||||
callback_return = CreateMessageResult(
|
||||
role="assistant",
|
||||
content=TextContent(
|
||||
type="text", text="This is a response from the sampling callback"
|
||||
),
|
||||
content=TextContent(type="text", text="This is a response from the sampling callback"),
|
||||
model="test-model",
|
||||
stopReason="endTurn",
|
||||
)
|
||||
@@ -37,24 +35,16 @@ async def test_sampling_callback():
|
||||
@server.tool("test_sampling")
|
||||
async def test_sampling_tool(message: str):
|
||||
value = await server.get_context().session.create_message(
|
||||
messages=[
|
||||
SamplingMessage(
|
||||
role="user", content=TextContent(type="text", text=message)
|
||||
)
|
||||
],
|
||||
messages=[SamplingMessage(role="user", content=TextContent(type="text", text=message))],
|
||||
max_tokens=100,
|
||||
)
|
||||
assert value == callback_return
|
||||
return True
|
||||
|
||||
# Test with sampling callback
|
||||
async with create_session(
|
||||
server._mcp_server, sampling_callback=sampling_callback
|
||||
) as client_session:
|
||||
async with create_session(server._mcp_server, sampling_callback=sampling_callback) as client_session:
|
||||
# Make a request to trigger sampling callback
|
||||
result = await client_session.call_tool(
|
||||
"test_sampling", {"message": "Test message for sampling"}
|
||||
)
|
||||
result = await client_session.call_tool("test_sampling", {"message": "Test message for sampling"})
|
||||
assert result.isError is False
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
assert result.content[0].text == "true"
|
||||
@@ -62,12 +52,7 @@ async def test_sampling_callback():
|
||||
# Test without sampling callback
|
||||
async with create_session(server._mcp_server) as client_session:
|
||||
# Make a request to trigger sampling callback
|
||||
result = await client_session.call_tool(
|
||||
"test_sampling", {"message": "Test message for sampling"}
|
||||
)
|
||||
result = await client_session.call_tool("test_sampling", {"message": "Test message for sampling"})
|
||||
assert result.isError is True
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
assert (
|
||||
result.content[0].text
|
||||
== "Error executing tool test_sampling: Sampling not supported"
|
||||
)
|
||||
assert result.content[0].text == "Error executing tool test_sampling: Sampling not supported"
|
||||
|
||||
@@ -28,12 +28,8 @@ from mcp.types import (
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_client_session_initialize():
|
||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
||||
SessionMessage
|
||||
](1)
|
||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
||||
SessionMessage
|
||||
](1)
|
||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
|
||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
|
||||
|
||||
initialized_notification = None
|
||||
|
||||
@@ -70,9 +66,7 @@ async def test_client_session_initialize():
|
||||
JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=jsonrpc_request.root.id,
|
||||
result=result.model_dump(
|
||||
by_alias=True, mode="json", exclude_none=True
|
||||
),
|
||||
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -81,16 +75,12 @@ async def test_client_session_initialize():
|
||||
jsonrpc_notification = session_notification.message
|
||||
assert isinstance(jsonrpc_notification.root, JSONRPCNotification)
|
||||
initialized_notification = ClientNotification.model_validate(
|
||||
jsonrpc_notification.model_dump(
|
||||
by_alias=True, mode="json", exclude_none=True
|
||||
)
|
||||
jsonrpc_notification.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
)
|
||||
|
||||
# Create a message handler to catch exceptions
|
||||
async def message_handler(
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult]
|
||||
| types.ServerNotification
|
||||
| Exception,
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
||||
) -> None:
|
||||
if isinstance(message, Exception):
|
||||
raise message
|
||||
@@ -124,12 +114,8 @@ async def test_client_session_initialize():
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_client_session_custom_client_info():
|
||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
||||
SessionMessage
|
||||
](1)
|
||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
||||
SessionMessage
|
||||
](1)
|
||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
|
||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
|
||||
|
||||
custom_client_info = Implementation(name="test-client", version="1.2.3")
|
||||
received_client_info = None
|
||||
@@ -161,9 +147,7 @@ async def test_client_session_custom_client_info():
|
||||
JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=jsonrpc_request.root.id,
|
||||
result=result.model_dump(
|
||||
by_alias=True, mode="json", exclude_none=True
|
||||
),
|
||||
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -192,12 +176,8 @@ async def test_client_session_custom_client_info():
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_client_session_default_client_info():
|
||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
||||
SessionMessage
|
||||
](1)
|
||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
||||
SessionMessage
|
||||
](1)
|
||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
|
||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
|
||||
|
||||
received_client_info = None
|
||||
|
||||
@@ -228,9 +208,7 @@ async def test_client_session_default_client_info():
|
||||
JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=jsonrpc_request.root.id,
|
||||
result=result.model_dump(
|
||||
by_alias=True, mode="json", exclude_none=True
|
||||
),
|
||||
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -259,12 +237,8 @@ async def test_client_session_default_client_info():
|
||||
@pytest.mark.anyio
|
||||
async def test_client_session_version_negotiation_success():
|
||||
"""Test successful version negotiation with supported version"""
|
||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
||||
SessionMessage
|
||||
](1)
|
||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
||||
SessionMessage
|
||||
](1)
|
||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
|
||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
|
||||
|
||||
async def mock_server():
|
||||
session_message = await client_to_server_receive.receive()
|
||||
@@ -294,9 +268,7 @@ async def test_client_session_version_negotiation_success():
|
||||
JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=jsonrpc_request.root.id,
|
||||
result=result.model_dump(
|
||||
by_alias=True, mode="json", exclude_none=True
|
||||
),
|
||||
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -327,12 +299,8 @@ async def test_client_session_version_negotiation_success():
|
||||
@pytest.mark.anyio
|
||||
async def test_client_session_version_negotiation_failure():
|
||||
"""Test version negotiation failure with unsupported version"""
|
||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
||||
SessionMessage
|
||||
](1)
|
||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
||||
SessionMessage
|
||||
](1)
|
||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
|
||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
|
||||
|
||||
async def mock_server():
|
||||
session_message = await client_to_server_receive.receive()
|
||||
@@ -359,9 +327,7 @@ async def test_client_session_version_negotiation_failure():
|
||||
JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=jsonrpc_request.root.id,
|
||||
result=result.model_dump(
|
||||
by_alias=True, mode="json", exclude_none=True
|
||||
),
|
||||
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -388,12 +354,8 @@ async def test_client_session_version_negotiation_failure():
|
||||
@pytest.mark.anyio
|
||||
async def test_client_capabilities_default():
|
||||
"""Test that client capabilities are properly set with default callbacks"""
|
||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
||||
SessionMessage
|
||||
](1)
|
||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
||||
SessionMessage
|
||||
](1)
|
||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
|
||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
|
||||
|
||||
received_capabilities = None
|
||||
|
||||
@@ -424,9 +386,7 @@ async def test_client_capabilities_default():
|
||||
JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=jsonrpc_request.root.id,
|
||||
result=result.model_dump(
|
||||
by_alias=True, mode="json", exclude_none=True
|
||||
),
|
||||
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -457,12 +417,8 @@ async def test_client_capabilities_default():
|
||||
@pytest.mark.anyio
|
||||
async def test_client_capabilities_with_custom_callbacks():
|
||||
"""Test that client capabilities are properly set with custom callbacks"""
|
||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
||||
SessionMessage
|
||||
](1)
|
||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
||||
SessionMessage
|
||||
](1)
|
||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
|
||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
|
||||
|
||||
received_capabilities = None
|
||||
|
||||
@@ -508,9 +464,7 @@ async def test_client_capabilities_with_custom_callbacks():
|
||||
JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=jsonrpc_request.root.id,
|
||||
result=result.model_dump(
|
||||
by_alias=True, mode="json", exclude_none=True
|
||||
),
|
||||
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -536,14 +490,8 @@ async def test_client_capabilities_with_custom_callbacks():
|
||||
|
||||
# Assert that capabilities are properly set with custom callbacks
|
||||
assert received_capabilities is not None
|
||||
assert (
|
||||
received_capabilities.sampling is not None
|
||||
) # Custom sampling callback provided
|
||||
assert received_capabilities.sampling is not None # Custom sampling callback provided
|
||||
assert isinstance(received_capabilities.sampling, types.SamplingCapability)
|
||||
assert (
|
||||
received_capabilities.roots is not None
|
||||
) # Custom list_roots callback provided
|
||||
assert received_capabilities.roots is not None # Custom list_roots callback provided
|
||||
assert isinstance(received_capabilities.roots, types.RootsCapability)
|
||||
assert (
|
||||
received_capabilities.roots.listChanged is True
|
||||
) # Should be True for custom callback
|
||||
assert received_capabilities.roots.listChanged is True # Should be True for custom callback
|
||||
|
||||
@@ -58,14 +58,10 @@ class TestClientSessionGroup:
|
||||
return f"{(server_info.name)}-{name}"
|
||||
|
||||
mcp_session_group = ClientSessionGroup(component_name_hook=hook)
|
||||
mcp_session_group._tools = {
|
||||
"server1-my_tool": types.Tool(name="my_tool", inputSchema={})
|
||||
}
|
||||
mcp_session_group._tools = {"server1-my_tool": types.Tool(name="my_tool", inputSchema={})}
|
||||
mcp_session_group._tool_to_session = {"server1-my_tool": mock_session}
|
||||
text_content = types.TextContent(type="text", text="OK")
|
||||
mock_session.call_tool.return_value = types.CallToolResult(
|
||||
content=[text_content]
|
||||
)
|
||||
mock_session.call_tool.return_value = types.CallToolResult(content=[text_content])
|
||||
|
||||
# --- Test Execution ---
|
||||
result = await mcp_session_group.call_tool(
|
||||
@@ -96,16 +92,12 @@ class TestClientSessionGroup:
|
||||
mock_prompt1 = mock.Mock(spec=types.Prompt)
|
||||
mock_prompt1.name = "prompt_c"
|
||||
mock_session.list_tools.return_value = mock.AsyncMock(tools=[mock_tool1])
|
||||
mock_session.list_resources.return_value = mock.AsyncMock(
|
||||
resources=[mock_resource1]
|
||||
)
|
||||
mock_session.list_resources.return_value = mock.AsyncMock(resources=[mock_resource1])
|
||||
mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[mock_prompt1])
|
||||
|
||||
# --- Test Execution ---
|
||||
group = ClientSessionGroup(exit_stack=mock_exit_stack)
|
||||
with mock.patch.object(
|
||||
group, "_establish_session", return_value=(mock_server_info, mock_session)
|
||||
):
|
||||
with mock.patch.object(group, "_establish_session", return_value=(mock_server_info, mock_session)):
|
||||
await group.connect_to_server(StdioServerParameters(command="test"))
|
||||
|
||||
# --- Assertions ---
|
||||
@@ -141,12 +133,8 @@ class TestClientSessionGroup:
|
||||
return f"{server_info.name}.{name}"
|
||||
|
||||
# --- Test Execution ---
|
||||
group = ClientSessionGroup(
|
||||
exit_stack=mock_exit_stack, component_name_hook=name_hook
|
||||
)
|
||||
with mock.patch.object(
|
||||
group, "_establish_session", return_value=(mock_server_info, mock_session)
|
||||
):
|
||||
group = ClientSessionGroup(exit_stack=mock_exit_stack, component_name_hook=name_hook)
|
||||
with mock.patch.object(group, "_establish_session", return_value=(mock_server_info, mock_session)):
|
||||
await group.connect_to_server(StdioServerParameters(command="test"))
|
||||
|
||||
# --- Assertions ---
|
||||
@@ -231,9 +219,7 @@ class TestClientSessionGroup:
|
||||
# Need a dummy session associated with the existing tool
|
||||
mock_session = mock.MagicMock(spec=mcp.ClientSession)
|
||||
group._tool_to_session[existing_tool_name] = mock_session
|
||||
group._session_exit_stacks[mock_session] = mock.Mock(
|
||||
spec=contextlib.AsyncExitStack
|
||||
)
|
||||
group._session_exit_stacks[mock_session] = mock.Mock(spec=contextlib.AsyncExitStack)
|
||||
|
||||
# --- Mock New Connection Attempt ---
|
||||
mock_server_info_new = mock.Mock(spec=types.Implementation)
|
||||
@@ -243,9 +229,7 @@ class TestClientSessionGroup:
|
||||
# Configure the new session to return a tool with the *same name*
|
||||
duplicate_tool = mock.Mock(spec=types.Tool)
|
||||
duplicate_tool.name = existing_tool_name
|
||||
mock_session_new.list_tools.return_value = mock.AsyncMock(
|
||||
tools=[duplicate_tool]
|
||||
)
|
||||
mock_session_new.list_tools.return_value = mock.AsyncMock(tools=[duplicate_tool])
|
||||
# Keep other lists empty for simplicity
|
||||
mock_session_new.list_resources.return_value = mock.AsyncMock(resources=[])
|
||||
mock_session_new.list_prompts.return_value = mock.AsyncMock(prompts=[])
|
||||
@@ -266,9 +250,7 @@ class TestClientSessionGroup:
|
||||
|
||||
# Verify the duplicate tool was *not* added again (state should be unchanged)
|
||||
assert len(group._tools) == 1 # Should still only have the original
|
||||
assert (
|
||||
group._tools[existing_tool_name] is not duplicate_tool
|
||||
) # Ensure it's the original mock
|
||||
assert group._tools[existing_tool_name] is not duplicate_tool # Ensure it's the original mock
|
||||
|
||||
# No patching needed here
|
||||
async def test_disconnect_non_existent_server(self):
|
||||
@@ -292,9 +274,7 @@ class TestClientSessionGroup:
|
||||
"mcp.client.session_group.sse_client",
|
||||
), # url, headers, timeout, sse_read_timeout
|
||||
(
|
||||
StreamableHttpParameters(
|
||||
url="http://test.com/stream", terminate_on_close=False
|
||||
),
|
||||
StreamableHttpParameters(url="http://test.com/stream", terminate_on_close=False),
|
||||
"streamablehttp",
|
||||
"mcp.client.session_group.streamablehttp_client",
|
||||
), # url, headers, timeout, sse_read_timeout, terminate_on_close
|
||||
@@ -306,13 +286,9 @@ class TestClientSessionGroup:
|
||||
client_type_name, # Just for clarity or conditional logic if needed
|
||||
patch_target_for_client_func,
|
||||
):
|
||||
with mock.patch(
|
||||
"mcp.client.session_group.mcp.ClientSession"
|
||||
) as mock_ClientSession_class:
|
||||
with mock.patch("mcp.client.session_group.mcp.ClientSession") as mock_ClientSession_class:
|
||||
with mock.patch(patch_target_for_client_func) as mock_specific_client_func:
|
||||
mock_client_cm_instance = mock.AsyncMock(
|
||||
name=f"{client_type_name}ClientCM"
|
||||
)
|
||||
mock_client_cm_instance = mock.AsyncMock(name=f"{client_type_name}ClientCM")
|
||||
mock_read_stream = mock.AsyncMock(name=f"{client_type_name}Read")
|
||||
mock_write_stream = mock.AsyncMock(name=f"{client_type_name}Write")
|
||||
|
||||
@@ -344,9 +320,7 @@ class TestClientSessionGroup:
|
||||
|
||||
# Mock session.initialize()
|
||||
mock_initialize_result = mock.AsyncMock(name="InitializeResult")
|
||||
mock_initialize_result.serverInfo = types.Implementation(
|
||||
name="foo", version="1"
|
||||
)
|
||||
mock_initialize_result.serverInfo = types.Implementation(name="foo", version="1")
|
||||
mock_entered_session.initialize.return_value = mock_initialize_result
|
||||
|
||||
# --- Test Execution ---
|
||||
@@ -364,9 +338,7 @@ class TestClientSessionGroup:
|
||||
# --- Assertions ---
|
||||
# 1. Assert the correct specific client function was called
|
||||
if client_type_name == "stdio":
|
||||
mock_specific_client_func.assert_called_once_with(
|
||||
server_params_instance
|
||||
)
|
||||
mock_specific_client_func.assert_called_once_with(server_params_instance)
|
||||
elif client_type_name == "sse":
|
||||
mock_specific_client_func.assert_called_once_with(
|
||||
url=server_params_instance.url,
|
||||
@@ -386,9 +358,7 @@ class TestClientSessionGroup:
|
||||
mock_client_cm_instance.__aenter__.assert_awaited_once()
|
||||
|
||||
# 2. Assert ClientSession was called correctly
|
||||
mock_ClientSession_class.assert_called_once_with(
|
||||
mock_read_stream, mock_write_stream
|
||||
)
|
||||
mock_ClientSession_class.assert_called_once_with(mock_read_stream, mock_write_stream)
|
||||
mock_raw_session_cm.__aenter__.assert_awaited_once()
|
||||
mock_entered_session.initialize.assert_awaited_once()
|
||||
|
||||
|
||||
@@ -50,20 +50,14 @@ async def test_stdio_client():
|
||||
break
|
||||
|
||||
assert len(read_messages) == 2
|
||||
assert read_messages[0] == JSONRPCMessage(
|
||||
root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")
|
||||
)
|
||||
assert read_messages[1] == JSONRPCMessage(
|
||||
root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})
|
||||
)
|
||||
assert read_messages[0] == JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping"))
|
||||
assert read_messages[1] == JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={}))
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_stdio_client_bad_path():
|
||||
"""Check that the connection doesn't hang if process errors."""
|
||||
server_params = StdioServerParameters(
|
||||
command="python", args=["-c", "non-existent-file.py"]
|
||||
)
|
||||
server_params = StdioServerParameters(command="python", args=["-c", "non-existent-file.py"])
|
||||
async with stdio_client(server_params) as (read_stream, write_stream):
|
||||
async with ClientSession(read_stream, write_stream) as session:
|
||||
# The session should raise an error when the connection closes
|
||||
|
||||
Reference in New Issue
Block a user