Use 120 characters instead of 88 (#856)

This commit is contained in:
Marcelo Trylesinski
2025-06-11 02:45:50 -07:00
committed by GitHub
parent f7265f7b91
commit 543961968c
90 changed files with 687 additions and 2142 deletions

View File

@@ -134,9 +134,7 @@ class TestOAuthClientProvider:
assert len(verifier) == 128
# Check charset (RFC 7636: A-Z, a-z, 0-9, "-", ".", "_", "~")
allowed_chars = set(
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
)
allowed_chars = set("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~")
assert set(verifier) <= allowed_chars
# Check uniqueness (generate multiple and ensure they're different)
@@ -151,9 +149,7 @@ class TestOAuthClientProvider:
# Manually calculate expected challenge
expected_digest = hashlib.sha256(verifier.encode()).digest()
expected_challenge = (
base64.urlsafe_b64encode(expected_digest).decode().rstrip("=")
)
expected_challenge = base64.urlsafe_b64encode(expected_digest).decode().rstrip("=")
assert challenge == expected_challenge
@@ -166,29 +162,19 @@ class TestOAuthClientProvider:
async def test_get_authorization_base_url(self, oauth_provider):
"""Test authorization base URL extraction."""
# Test with path
assert (
oauth_provider._get_authorization_base_url("https://api.example.com/v1/mcp")
== "https://api.example.com"
)
assert oauth_provider._get_authorization_base_url("https://api.example.com/v1/mcp") == "https://api.example.com"
# Test with no path
assert (
oauth_provider._get_authorization_base_url("https://api.example.com")
== "https://api.example.com"
)
assert oauth_provider._get_authorization_base_url("https://api.example.com") == "https://api.example.com"
# Test with port
assert (
oauth_provider._get_authorization_base_url(
"https://api.example.com:8080/path/to/mcp"
)
oauth_provider._get_authorization_base_url("https://api.example.com:8080/path/to/mcp")
== "https://api.example.com:8080"
)
@pytest.mark.anyio
async def test_discover_oauth_metadata_success(
self, oauth_provider, oauth_metadata
):
async def test_discover_oauth_metadata_success(self, oauth_provider, oauth_metadata):
"""Test successful OAuth metadata discovery."""
metadata_response = oauth_metadata.model_dump(by_alias=True, mode="json")
@@ -201,23 +187,16 @@ class TestOAuthClientProvider:
mock_response.json.return_value = metadata_response
mock_client.get.return_value = mock_response
result = await oauth_provider._discover_oauth_metadata(
"https://api.example.com/v1/mcp"
)
result = await oauth_provider._discover_oauth_metadata("https://api.example.com/v1/mcp")
assert result is not None
assert (
result.authorization_endpoint == oauth_metadata.authorization_endpoint
)
assert result.authorization_endpoint == oauth_metadata.authorization_endpoint
assert result.token_endpoint == oauth_metadata.token_endpoint
# Verify correct URL was called
mock_client.get.assert_called_once()
call_args = mock_client.get.call_args[0]
assert (
call_args[0]
== "https://api.example.com/.well-known/oauth-authorization-server"
)
assert call_args[0] == "https://api.example.com/.well-known/oauth-authorization-server"
@pytest.mark.anyio
async def test_discover_oauth_metadata_not_found(self, oauth_provider):
@@ -230,16 +209,12 @@ class TestOAuthClientProvider:
mock_response.status_code = 404
mock_client.get.return_value = mock_response
result = await oauth_provider._discover_oauth_metadata(
"https://api.example.com/v1/mcp"
)
result = await oauth_provider._discover_oauth_metadata("https://api.example.com/v1/mcp")
assert result is None
@pytest.mark.anyio
async def test_discover_oauth_metadata_cors_fallback(
self, oauth_provider, oauth_metadata
):
async def test_discover_oauth_metadata_cors_fallback(self, oauth_provider, oauth_metadata):
"""Test OAuth metadata discovery with CORS fallback."""
metadata_response = oauth_metadata.model_dump(by_alias=True, mode="json")
@@ -257,17 +232,13 @@ class TestOAuthClientProvider:
mock_response_success, # Second call succeeds
]
result = await oauth_provider._discover_oauth_metadata(
"https://api.example.com/v1/mcp"
)
result = await oauth_provider._discover_oauth_metadata("https://api.example.com/v1/mcp")
assert result is not None
assert mock_client.get.call_count == 2
@pytest.mark.anyio
async def test_register_oauth_client_success(
self, oauth_provider, oauth_metadata, oauth_client_info
):
async def test_register_oauth_client_success(self, oauth_provider, oauth_metadata, oauth_client_info):
"""Test successful OAuth client registration."""
registration_response = oauth_client_info.model_dump(by_alias=True, mode="json")
@@ -295,9 +266,7 @@ class TestOAuthClientProvider:
assert call_args[0][0] == str(oauth_metadata.registration_endpoint)
@pytest.mark.anyio
async def test_register_oauth_client_fallback_endpoint(
self, oauth_provider, oauth_client_info
):
async def test_register_oauth_client_fallback_endpoint(self, oauth_provider, oauth_client_info):
"""Test OAuth client registration with fallback endpoint."""
registration_response = oauth_client_info.model_dump(by_alias=True, mode="json")
@@ -311,9 +280,7 @@ class TestOAuthClientProvider:
mock_client.post.return_value = mock_response
# Mock metadata discovery to return None (fallback)
with patch.object(
oauth_provider, "_discover_oauth_metadata", return_value=None
):
with patch.object(oauth_provider, "_discover_oauth_metadata", return_value=None):
result = await oauth_provider._register_oauth_client(
"https://api.example.com/v1/mcp",
oauth_provider.client_metadata,
@@ -340,9 +307,7 @@ class TestOAuthClientProvider:
mock_client.post.return_value = mock_response
# Mock metadata discovery to return None (fallback)
with patch.object(
oauth_provider, "_discover_oauth_metadata", return_value=None
):
with patch.object(oauth_provider, "_discover_oauth_metadata", return_value=None):
with pytest.raises(httpx.HTTPStatusError):
await oauth_provider._register_oauth_client(
"https://api.example.com/v1/mcp",
@@ -406,9 +371,7 @@ class TestOAuthClientProvider:
await oauth_provider._validate_token_scopes(token)
@pytest.mark.anyio
async def test_validate_token_scopes_unauthorized(
self, oauth_provider, client_metadata
):
async def test_validate_token_scopes_unauthorized(self, oauth_provider, client_metadata):
"""Test scope validation with unauthorized scopes."""
oauth_provider.client_metadata = client_metadata
token = OAuthToken(
@@ -436,9 +399,7 @@ class TestOAuthClientProvider:
await oauth_provider._validate_token_scopes(token)
@pytest.mark.anyio
async def test_initialize(
self, oauth_provider, mock_storage, oauth_token, oauth_client_info
):
async def test_initialize(self, oauth_provider, mock_storage, oauth_token, oauth_client_info):
"""Test initialization loading from storage."""
mock_storage._tokens = oauth_token
mock_storage._client_info = oauth_client_info
@@ -449,9 +410,7 @@ class TestOAuthClientProvider:
assert oauth_provider._client_info == oauth_client_info
@pytest.mark.anyio
async def test_get_or_register_client_existing(
self, oauth_provider, oauth_client_info
):
async def test_get_or_register_client_existing(self, oauth_provider, oauth_client_info):
"""Test getting existing client info."""
oauth_provider._client_info = oauth_client_info
@@ -460,13 +419,9 @@ class TestOAuthClientProvider:
assert result == oauth_client_info
@pytest.mark.anyio
async def test_get_or_register_client_register_new(
self, oauth_provider, oauth_client_info
):
async def test_get_or_register_client_register_new(self, oauth_provider, oauth_client_info):
"""Test registering new client."""
with patch.object(
oauth_provider, "_register_oauth_client", return_value=oauth_client_info
) as mock_register:
with patch.object(oauth_provider, "_register_oauth_client", return_value=oauth_client_info) as mock_register:
result = await oauth_provider._get_or_register_client()
assert result == oauth_client_info
@@ -474,9 +429,7 @@ class TestOAuthClientProvider:
mock_register.assert_called_once()
@pytest.mark.anyio
async def test_exchange_code_for_token_success(
self, oauth_provider, oauth_client_info, oauth_token
):
async def test_exchange_code_for_token_success(self, oauth_provider, oauth_client_info, oauth_token):
"""Test successful code exchange for token."""
oauth_provider._code_verifier = "test_verifier"
token_response = oauth_token.model_dump(by_alias=True, mode="json")
@@ -490,23 +443,14 @@ class TestOAuthClientProvider:
mock_response.json.return_value = token_response
mock_client.post.return_value = mock_response
with patch.object(
oauth_provider, "_validate_token_scopes"
) as mock_validate:
await oauth_provider._exchange_code_for_token(
"test_auth_code", oauth_client_info
)
with patch.object(oauth_provider, "_validate_token_scopes") as mock_validate:
await oauth_provider._exchange_code_for_token("test_auth_code", oauth_client_info)
assert (
oauth_provider._current_tokens.access_token
== oauth_token.access_token
)
assert oauth_provider._current_tokens.access_token == oauth_token.access_token
mock_validate.assert_called_once()
@pytest.mark.anyio
async def test_exchange_code_for_token_failure(
self, oauth_provider, oauth_client_info
):
async def test_exchange_code_for_token_failure(self, oauth_provider, oauth_client_info):
"""Test failed code exchange for token."""
oauth_provider._code_verifier = "test_verifier"
@@ -520,14 +464,10 @@ class TestOAuthClientProvider:
mock_client.post.return_value = mock_response
with pytest.raises(Exception, match="Token exchange failed"):
await oauth_provider._exchange_code_for_token(
"invalid_auth_code", oauth_client_info
)
await oauth_provider._exchange_code_for_token("invalid_auth_code", oauth_client_info)
@pytest.mark.anyio
async def test_refresh_access_token_success(
self, oauth_provider, oauth_client_info, oauth_token
):
async def test_refresh_access_token_success(self, oauth_provider, oauth_client_info, oauth_token):
"""Test successful token refresh."""
oauth_provider._current_tokens = oauth_token
oauth_provider._client_info = oauth_client_info
@@ -550,16 +490,11 @@ class TestOAuthClientProvider:
mock_response.json.return_value = token_response
mock_client.post.return_value = mock_response
with patch.object(
oauth_provider, "_validate_token_scopes"
) as mock_validate:
with patch.object(oauth_provider, "_validate_token_scopes") as mock_validate:
result = await oauth_provider._refresh_access_token()
assert result is True
assert (
oauth_provider._current_tokens.access_token
== new_token.access_token
)
assert oauth_provider._current_tokens.access_token == new_token.access_token
mock_validate.assert_called_once()
@pytest.mark.anyio
@@ -575,9 +510,7 @@ class TestOAuthClientProvider:
assert result is False
@pytest.mark.anyio
async def test_refresh_access_token_failure(
self, oauth_provider, oauth_client_info, oauth_token
):
async def test_refresh_access_token_failure(self, oauth_provider, oauth_client_info, oauth_token):
"""Test failed token refresh."""
oauth_provider._current_tokens = oauth_token
oauth_provider._client_info = oauth_client_info
@@ -594,9 +527,7 @@ class TestOAuthClientProvider:
assert result is False
@pytest.mark.anyio
async def test_perform_oauth_flow_success(
self, oauth_provider, oauth_metadata, oauth_client_info
):
async def test_perform_oauth_flow_success(self, oauth_provider, oauth_metadata, oauth_client_info):
"""Test successful OAuth flow."""
oauth_provider._metadata = oauth_metadata
oauth_provider._client_info = oauth_client_info
@@ -640,9 +571,7 @@ class TestOAuthClientProvider:
mock_exchange.assert_called_once_with("test_auth_code", oauth_client_info)
@pytest.mark.anyio
async def test_perform_oauth_flow_state_mismatch(
self, oauth_provider, oauth_metadata, oauth_client_info
):
async def test_perform_oauth_flow_state_mismatch(self, oauth_provider, oauth_metadata, oauth_client_info):
"""Test OAuth flow with state parameter mismatch."""
oauth_provider._metadata = oauth_metadata
oauth_provider._client_info = oauth_client_info
@@ -678,9 +607,7 @@ class TestOAuthClientProvider:
oauth_provider._current_tokens = oauth_token
oauth_provider._token_expiry_time = time.time() - 3600 # Expired
with patch.object(
oauth_provider, "_refresh_access_token", return_value=True
) as mock_refresh:
with patch.object(oauth_provider, "_refresh_access_token", return_value=True) as mock_refresh:
await oauth_provider.ensure_token()
mock_refresh.assert_called_once()
@@ -707,10 +634,7 @@ class TestOAuthClientProvider:
auth_flow = oauth_provider.async_auth_flow(request)
updated_request = await auth_flow.__anext__()
assert (
updated_request.headers["Authorization"]
== f"Bearer {oauth_token.access_token}"
)
assert updated_request.headers["Authorization"] == f"Bearer {oauth_token.access_token}"
# Send mock response
try:
@@ -761,9 +685,7 @@ class TestOAuthClientProvider:
assert "Authorization" not in updated_request.headers
@pytest.mark.anyio
async def test_scope_priority_client_metadata_first(
self, oauth_provider, oauth_client_info
):
async def test_scope_priority_client_metadata_first(self, oauth_provider, oauth_client_info):
"""Test that client metadata scope takes priority."""
oauth_provider.client_metadata.scope = "read write"
oauth_provider._client_info = oauth_client_info
@@ -782,18 +704,13 @@ class TestOAuthClientProvider:
# Apply scope logic from _perform_oauth_flow
if oauth_provider.client_metadata.scope:
auth_params["scope"] = oauth_provider.client_metadata.scope
elif (
hasattr(oauth_provider._client_info, "scope")
and oauth_provider._client_info.scope
):
elif hasattr(oauth_provider._client_info, "scope") and oauth_provider._client_info.scope:
auth_params["scope"] = oauth_provider._client_info.scope
assert auth_params["scope"] == "read write"
@pytest.mark.anyio
async def test_scope_priority_no_client_metadata_scope(
self, oauth_provider, oauth_client_info
):
async def test_scope_priority_no_client_metadata_scope(self, oauth_provider, oauth_client_info):
"""Test that no scope parameter is set when client metadata has no scope."""
oauth_provider.client_metadata.scope = None
oauth_provider._client_info = oauth_client_info
@@ -837,10 +754,7 @@ class TestOAuthClientProvider:
# Apply scope logic from _perform_oauth_flow
if oauth_provider.client_metadata.scope:
auth_params["scope"] = oauth_provider.client_metadata.scope
elif (
hasattr(oauth_provider._client_info, "scope")
and oauth_provider._client_info.scope
):
elif hasattr(oauth_provider._client_info, "scope") and oauth_provider._client_info.scope:
auth_params["scope"] = oauth_provider._client_info.scope
# No scope should be set
@@ -866,9 +780,7 @@ class TestOAuthClientProvider:
oauth_provider.redirect_handler = mock_redirect_handler
# Patch secrets.compare_digest to verify it's being called
with patch(
"mcp.client.auth.secrets.compare_digest", return_value=False
) as mock_compare:
with patch("mcp.client.auth.secrets.compare_digest", return_value=False) as mock_compare:
with pytest.raises(Exception, match="State parameter mismatch"):
await oauth_provider._perform_oauth_flow()
@@ -876,9 +788,7 @@ class TestOAuthClientProvider:
mock_compare.assert_called_once()
@pytest.mark.anyio
async def test_state_parameter_validation_none_state(
self, oauth_provider, oauth_metadata, oauth_client_info
):
async def test_state_parameter_validation_none_state(self, oauth_provider, oauth_metadata, oauth_client_info):
"""Test that None state is handled correctly."""
oauth_provider._metadata = oauth_metadata
oauth_provider._client_info = oauth_client_info
@@ -913,9 +823,7 @@ class TestOAuthClientProvider:
mock_client.post.return_value = mock_response
with pytest.raises(Exception, match="Token exchange failed"):
await oauth_provider._exchange_code_for_token(
"invalid_auth_code", oauth_client_info
)
await oauth_provider._exchange_code_for_token("invalid_auth_code", oauth_client_info)
@pytest.mark.parametrize(
@@ -968,9 +876,7 @@ def test_build_metadata(
metadata = build_metadata(
issuer_url=AnyHttpUrl(issuer_url),
service_documentation_url=AnyHttpUrl(service_documentation_url),
client_registration_options=ClientRegistrationOptions(
enabled=True, valid_scopes=["read", "write", "admin"]
),
client_registration_options=ClientRegistrationOptions(enabled=True, valid_scopes=["read", "write", "admin"]),
revocation_options=RevocationOptions(enabled=True),
)