Use 120 characters instead of 88 (#856)

This commit is contained in:
Marcelo Trylesinski
2025-06-11 02:45:50 -07:00
committed by GitHub
parent f7265f7b91
commit 543961968c
90 changed files with 687 additions and 2142 deletions

View File

@@ -47,9 +47,7 @@ class MockOAuthProvider(OAuthAuthorizationServerProvider):
async def register_client(self, client_info: OAuthClientInformationFull):
self.clients[client_info.client_id] = client_info
async def authorize(
self, client: OAuthClientInformationFull, params: AuthorizationParams
) -> str:
async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str:
# toy authorize implementation which just immediately generates an authorization
# code and completes the redirect
code = AuthorizationCode(
@@ -63,9 +61,7 @@ class MockOAuthProvider(OAuthAuthorizationServerProvider):
)
self.auth_codes[code.code] = code
return construct_redirect_uri(
str(params.redirect_uri), code=code.code, state=params.state
)
return construct_redirect_uri(str(params.redirect_uri), code=code.code, state=params.state)
async def load_authorization_code(
self, client: OAuthClientInformationFull, authorization_code: str
@@ -102,9 +98,7 @@ class MockOAuthProvider(OAuthAuthorizationServerProvider):
refresh_token=refresh_token,
)
async def load_refresh_token(
self, client: OAuthClientInformationFull, refresh_token: str
) -> RefreshToken | None:
async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshToken | None:
old_access_token = self.refresh_tokens.get(refresh_token)
if old_access_token is None:
return None
@@ -224,9 +218,7 @@ def auth_app(mock_oauth_provider):
@pytest.fixture
async def test_client(auth_app):
async with httpx.AsyncClient(
transport=httpx.ASGITransport(app=auth_app), base_url="https://mcptest.com"
) as client:
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=auth_app), base_url="https://mcptest.com") as client:
yield client
@@ -261,11 +253,7 @@ async def registered_client(test_client: httpx.AsyncClient, request):
def pkce_challenge():
"""Create a PKCE challenge with code_verifier and code_challenge."""
code_verifier = "some_random_verifier_string"
code_challenge = (
base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest())
.decode()
.rstrip("=")
)
code_challenge = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()).decode().rstrip("=")
return {"code_verifier": code_verifier, "code_challenge": code_challenge}
@@ -356,17 +344,13 @@ class TestAuthEndpoints:
metadata = response.json()
assert metadata["issuer"] == "https://auth.example.com/"
assert (
metadata["authorization_endpoint"] == "https://auth.example.com/authorize"
)
assert metadata["authorization_endpoint"] == "https://auth.example.com/authorize"
assert metadata["token_endpoint"] == "https://auth.example.com/token"
assert metadata["registration_endpoint"] == "https://auth.example.com/register"
assert metadata["revocation_endpoint"] == "https://auth.example.com/revoke"
assert metadata["response_types_supported"] == ["code"]
assert metadata["code_challenge_methods_supported"] == ["S256"]
assert metadata["token_endpoint_auth_methods_supported"] == [
"client_secret_post"
]
assert metadata["token_endpoint_auth_methods_supported"] == ["client_secret_post"]
assert metadata["grant_types_supported"] == [
"authorization_code",
"refresh_token",
@@ -386,14 +370,10 @@ class TestAuthEndpoints:
)
error_response = response.json()
assert error_response["error"] == "invalid_request"
assert (
"error_description" in error_response
) # Contains validation error messages
assert "error_description" in error_response # Contains validation error messages
@pytest.mark.anyio
async def test_token_invalid_auth_code(
self, test_client, registered_client, pkce_challenge
):
async def test_token_invalid_auth_code(self, test_client, registered_client, pkce_challenge):
"""Test token endpoint error - authorization code does not exist."""
# Try to use a non-existent authorization code
response = await test_client.post(
@@ -413,9 +393,7 @@ class TestAuthEndpoints:
assert response.status_code == 400
error_response = response.json()
assert error_response["error"] == "invalid_grant"
assert (
"authorization code does not exist" in error_response["error_description"]
)
assert "authorization code does not exist" in error_response["error_description"]
@pytest.mark.anyio
async def test_token_expired_auth_code(
@@ -458,9 +436,7 @@ class TestAuthEndpoints:
assert response.status_code == 400
error_response = response.json()
assert error_response["error"] == "invalid_grant"
assert (
"authorization code has expired" in error_response["error_description"]
)
assert "authorization code has expired" in error_response["error_description"]
@pytest.mark.anyio
@pytest.mark.parametrize(
@@ -475,9 +451,7 @@ class TestAuthEndpoints:
],
indirect=True,
)
async def test_token_redirect_uri_mismatch(
self, test_client, registered_client, auth_code, pkce_challenge
):
async def test_token_redirect_uri_mismatch(self, test_client, registered_client, auth_code, pkce_challenge):
"""Test token endpoint error - redirect URI mismatch."""
# Try to use the code with a different redirect URI
response = await test_client.post(
@@ -498,9 +472,7 @@ class TestAuthEndpoints:
assert "redirect_uri did not match" in error_response["error_description"]
@pytest.mark.anyio
async def test_token_code_verifier_mismatch(
self, test_client, registered_client, auth_code
):
async def test_token_code_verifier_mismatch(self, test_client, registered_client, auth_code):
"""Test token endpoint error - PKCE code verifier mismatch."""
# Try to use the code with an incorrect code verifier
response = await test_client.post(
@@ -569,9 +541,7 @@ class TestAuthEndpoints:
# Step 2: Time travel forward 4 hours (tokens expire in 1 hour by default)
# Mock the time.time() function to return a value 4 hours in the future
with unittest.mock.patch(
"time.time", return_value=current_time + 14400
): # 4 hours = 14400 seconds
with unittest.mock.patch("time.time", return_value=current_time + 14400): # 4 hours = 14400 seconds
# Try to use the refresh token which should now be considered expired
response = await test_client.post(
"/token",
@@ -590,9 +560,7 @@ class TestAuthEndpoints:
assert "refresh token has expired" in error_response["error_description"]
@pytest.mark.anyio
async def test_token_invalid_scope(
self, test_client, registered_client, auth_code, pkce_challenge
):
async def test_token_invalid_scope(self, test_client, registered_client, auth_code, pkce_challenge):
"""Test token endpoint error - invalid scope in refresh token request."""
# Exchange authorization code for tokens
token_response = await test_client.post(
@@ -628,9 +596,7 @@ class TestAuthEndpoints:
assert "cannot request scope" in error_response["error_description"]
@pytest.mark.anyio
async def test_client_registration(
self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider
):
async def test_client_registration(self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider):
"""Test client registration."""
client_metadata = {
"redirect_uris": ["https://client.example.com/callback"],
@@ -656,9 +622,7 @@ class TestAuthEndpoints:
# ) is not None
@pytest.mark.anyio
async def test_client_registration_missing_required_fields(
self, test_client: httpx.AsyncClient
):
async def test_client_registration_missing_required_fields(self, test_client: httpx.AsyncClient):
"""Test client registration with missing required fields."""
# Missing redirect_uris which is a required field
client_metadata = {
@@ -677,9 +641,7 @@ class TestAuthEndpoints:
assert error_data["error_description"] == "redirect_uris: Field required"
@pytest.mark.anyio
async def test_client_registration_invalid_uri(
self, test_client: httpx.AsyncClient
):
async def test_client_registration_invalid_uri(self, test_client: httpx.AsyncClient):
"""Test client registration with invalid URIs."""
# Invalid redirect_uri format
client_metadata = {
@@ -696,14 +658,11 @@ class TestAuthEndpoints:
assert "error" in error_data
assert error_data["error"] == "invalid_client_metadata"
assert error_data["error_description"] == (
"redirect_uris.0: Input should be a valid URL, "
"relative URL without a base"
"redirect_uris.0: Input should be a valid URL, " "relative URL without a base"
)
@pytest.mark.anyio
async def test_client_registration_empty_redirect_uris(
self, test_client: httpx.AsyncClient
):
async def test_client_registration_empty_redirect_uris(self, test_client: httpx.AsyncClient):
"""Test client registration with empty redirect_uris array."""
client_metadata = {
"redirect_uris": [], # Empty array
@@ -719,8 +678,7 @@ class TestAuthEndpoints:
assert "error" in error_data
assert error_data["error"] == "invalid_client_metadata"
assert (
error_data["error_description"]
== "redirect_uris: List should have at least 1 item after validation, not 0"
error_data["error_description"] == "redirect_uris: List should have at least 1 item after validation, not 0"
)
@pytest.mark.anyio
@@ -875,12 +833,7 @@ class TestAuthEndpoints:
assert response.status_code == 200
# Verify that the token was revoked
assert (
await mock_oauth_provider.load_access_token(
new_token_response["access_token"]
)
is None
)
assert await mock_oauth_provider.load_access_token(new_token_response["access_token"]) is None
@pytest.mark.anyio
async def test_revoke_invalid_token(self, test_client, registered_client):
@@ -913,9 +866,7 @@ class TestAuthEndpoints:
assert "token_type_hint" in error_response["error_description"]
@pytest.mark.anyio
async def test_client_registration_disallowed_scopes(
self, test_client: httpx.AsyncClient
):
async def test_client_registration_disallowed_scopes(self, test_client: httpx.AsyncClient):
"""Test client registration with scopes that are not allowed."""
client_metadata = {
"redirect_uris": ["https://client.example.com/callback"],
@@ -955,18 +906,14 @@ class TestAuthEndpoints:
assert client_info["scope"] == "read write"
# Retrieve the client from the store to verify default scopes
registered_client = await mock_oauth_provider.get_client(
client_info["client_id"]
)
registered_client = await mock_oauth_provider.get_client(client_info["client_id"])
assert registered_client is not None
# Check that default scopes were applied
assert registered_client.scope == "read write"
@pytest.mark.anyio
async def test_client_registration_invalid_grant_type(
self, test_client: httpx.AsyncClient
):
async def test_client_registration_invalid_grant_type(self, test_client: httpx.AsyncClient):
client_metadata = {
"redirect_uris": ["https://client.example.com/callback"],
"client_name": "Test Client",
@@ -981,19 +928,14 @@ class TestAuthEndpoints:
error_data = response.json()
assert "error" in error_data
assert error_data["error"] == "invalid_client_metadata"
assert (
error_data["error_description"]
== "grant_types must be authorization_code and refresh_token"
)
assert error_data["error_description"] == "grant_types must be authorization_code and refresh_token"
class TestAuthorizeEndpointErrors:
"""Test error handling in the OAuth authorization endpoint."""
@pytest.mark.anyio
async def test_authorize_missing_client_id(
self, test_client: httpx.AsyncClient, pkce_challenge
):
async def test_authorize_missing_client_id(self, test_client: httpx.AsyncClient, pkce_challenge):
"""Test authorization endpoint with missing client_id.
According to the OAuth2.0 spec, if client_id is missing, the server should
@@ -1017,9 +959,7 @@ class TestAuthorizeEndpointErrors:
assert "client_id" in response.text.lower()
@pytest.mark.anyio
async def test_authorize_invalid_client_id(
self, test_client: httpx.AsyncClient, pkce_challenge
):
async def test_authorize_invalid_client_id(self, test_client: httpx.AsyncClient, pkce_challenge):
"""Test authorization endpoint with invalid client_id.
According to the OAuth2.0 spec, if client_id is invalid, the server should
@@ -1202,9 +1142,7 @@ class TestAuthorizeEndpointErrors:
assert query_params["state"][0] == "test_state"
@pytest.mark.anyio
async def test_authorize_missing_pkce_challenge(
self, test_client: httpx.AsyncClient, registered_client
):
async def test_authorize_missing_pkce_challenge(self, test_client: httpx.AsyncClient, registered_client):
"""Test authorization endpoint with missing PKCE code_challenge.
Missing PKCE parameters should result in invalid_request error.
@@ -1233,9 +1171,7 @@ class TestAuthorizeEndpointErrors:
assert query_params["state"][0] == "test_state"
@pytest.mark.anyio
async def test_authorize_invalid_scope(
self, test_client: httpx.AsyncClient, registered_client, pkce_challenge
):
async def test_authorize_invalid_scope(self, test_client: httpx.AsyncClient, registered_client, pkce_challenge):
"""Test authorization endpoint with invalid scope.
Invalid scope should redirect with invalid_scope error.