""" Integration tests for MCP authorization components. """ import base64 import hashlib import secrets import time import unittest.mock from urllib.parse import parse_qs, urlparse import httpx import pytest from pydantic import AnyHttpUrl from starlette.applications import Starlette from mcp.server.auth.provider import ( AccessToken, AuthorizationCode, AuthorizationParams, OAuthAuthorizationServerProvider, RefreshToken, construct_redirect_uri, ) from mcp.server.auth.routes import ( ClientRegistrationOptions, RevocationOptions, create_auth_routes, ) from mcp.shared.auth import ( OAuthClientInformationFull, OAuthToken, ) # Mock OAuth provider for testing class MockOAuthProvider(OAuthAuthorizationServerProvider): def __init__(self): self.clients = {} self.auth_codes = {} # code -> {client_id, code_challenge, redirect_uri} self.tokens = {} # token -> {client_id, scopes, expires_at} self.refresh_tokens = {} # refresh_token -> access_token async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: return self.clients.get(client_id) async def register_client(self, client_info: OAuthClientInformationFull): self.clients[client_info.client_id] = client_info async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str: # toy authorize implementation which just immediately generates an authorization # code and completes the redirect code = AuthorizationCode( code=f"code_{int(time.time())}", client_id=client.client_id, code_challenge=params.code_challenge, redirect_uri=params.redirect_uri, redirect_uri_provided_explicitly=params.redirect_uri_provided_explicitly, expires_at=time.time() + 300, scopes=params.scopes or ["read", "write"], ) self.auth_codes[code.code] = code return construct_redirect_uri(str(params.redirect_uri), code=code.code, state=params.state) async def load_authorization_code( self, client: OAuthClientInformationFull, authorization_code: str ) -> AuthorizationCode | None: return self.auth_codes.get(authorization_code) async def exchange_authorization_code( self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode ) -> OAuthToken: assert authorization_code.code in self.auth_codes # Generate an access token and refresh token access_token = f"access_{secrets.token_hex(32)}" refresh_token = f"refresh_{secrets.token_hex(32)}" # Store the tokens self.tokens[access_token] = AccessToken( token=access_token, client_id=client.client_id, scopes=authorization_code.scopes, expires_at=int(time.time()) + 3600, ) self.refresh_tokens[refresh_token] = access_token # Remove the used code del self.auth_codes[authorization_code.code] return OAuthToken( access_token=access_token, token_type="Bearer", expires_in=3600, scope="read write", refresh_token=refresh_token, ) async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshToken | None: old_access_token = self.refresh_tokens.get(refresh_token) if old_access_token is None: return None token_info = self.tokens.get(old_access_token) if token_info is None: return None # Create a RefreshToken object that matches what is expected in later code refresh_obj = RefreshToken( token=refresh_token, client_id=token_info.client_id, scopes=token_info.scopes, expires_at=token_info.expires_at, ) return refresh_obj async def exchange_refresh_token( self, client: OAuthClientInformationFull, refresh_token: RefreshToken, scopes: list[str], ) -> OAuthToken: # Check if refresh token exists assert refresh_token.token in self.refresh_tokens old_access_token = self.refresh_tokens[refresh_token.token] # Check if the access token exists assert old_access_token in self.tokens # Check if the token was issued to this client token_info = self.tokens[old_access_token] assert token_info.client_id == client.client_id # Generate a new access token and refresh token new_access_token = f"access_{secrets.token_hex(32)}" new_refresh_token = f"refresh_{secrets.token_hex(32)}" # Store the new tokens self.tokens[new_access_token] = AccessToken( token=new_access_token, client_id=client.client_id, scopes=scopes or token_info.scopes, expires_at=int(time.time()) + 3600, ) self.refresh_tokens[new_refresh_token] = new_access_token # Remove the old tokens del self.refresh_tokens[refresh_token.token] del self.tokens[old_access_token] return OAuthToken( access_token=new_access_token, token_type="Bearer", expires_in=3600, scope=" ".join(scopes) if scopes else " ".join(token_info.scopes), refresh_token=new_refresh_token, ) async def load_access_token(self, token: str) -> AccessToken | None: token_info = self.tokens.get(token) # Check if token is expired # if token_info.expires_at < int(time.time()): # raise InvalidTokenError("Access token has expired") return token_info and AccessToken( token=token, client_id=token_info.client_id, scopes=token_info.scopes, expires_at=token_info.expires_at, ) async def revoke_token(self, token: AccessToken | RefreshToken) -> None: match token: case RefreshToken(): # Remove the refresh token del self.refresh_tokens[token.token] case AccessToken(): # Remove the access token del self.tokens[token.token] # Also remove any refresh tokens that point to this access token for refresh_token, access_token in list(self.refresh_tokens.items()): if access_token == token.token: del self.refresh_tokens[refresh_token] @pytest.fixture def mock_oauth_provider(): return MockOAuthProvider() @pytest.fixture def auth_app(mock_oauth_provider): # Create auth router auth_routes = create_auth_routes( mock_oauth_provider, AnyHttpUrl("https://auth.example.com"), AnyHttpUrl("https://docs.example.com"), client_registration_options=ClientRegistrationOptions( enabled=True, valid_scopes=["read", "write", "profile"], default_scopes=["read", "write"], ), revocation_options=RevocationOptions(enabled=True), ) # Create Starlette app app = Starlette(routes=auth_routes) return app @pytest.fixture async def test_client(auth_app): async with httpx.AsyncClient(transport=httpx.ASGITransport(app=auth_app), base_url="https://mcptest.com") as client: yield client @pytest.fixture async def registered_client(test_client: httpx.AsyncClient, request): """Create and register a test client. Parameters can be customized via indirect parameterization: @pytest.mark.parametrize("registered_client", [{"grant_types": ["authorization_code"]}], indirect=True) """ # Default client metadata client_metadata = { "redirect_uris": ["https://client.example.com/callback"], "client_name": "Test Client", "grant_types": ["authorization_code", "refresh_token"], } # Override with any parameters from the test if hasattr(request, "param") and request.param: client_metadata.update(request.param) response = await test_client.post("/register", json=client_metadata) assert response.status_code == 201, f"Failed to register client: {response.content}" client_info = response.json() return client_info @pytest.fixture def pkce_challenge(): """Create a PKCE challenge with code_verifier and code_challenge.""" code_verifier = "some_random_verifier_string" code_challenge = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()).decode().rstrip("=") return {"code_verifier": code_verifier, "code_challenge": code_challenge} @pytest.fixture async def auth_code(test_client, registered_client, pkce_challenge, request): """Get an authorization code. Parameters can be customized via indirect parameterization: @pytest.mark.parametrize("auth_code", [{"redirect_uri": "https://client.example.com/other-callback"}], indirect=True) """ # Default authorize params auth_params = { "response_type": "code", "client_id": registered_client["client_id"], "redirect_uri": "https://client.example.com/callback", "code_challenge": pkce_challenge["code_challenge"], "code_challenge_method": "S256", "state": "test_state", } # Override with any parameters from the test if hasattr(request, "param") and request.param: auth_params.update(request.param) response = await test_client.get("/authorize", params=auth_params) assert response.status_code == 302, f"Failed to get auth code: {response.content}" # Extract the authorization code redirect_url = response.headers["location"] parsed_url = urlparse(redirect_url) query_params = parse_qs(parsed_url.query) assert "code" in query_params, f"No code in response: {query_params}" auth_code = query_params["code"][0] return { "code": auth_code, "redirect_uri": auth_params["redirect_uri"], "state": query_params.get("state", [None])[0], } @pytest.fixture async def tokens(test_client, registered_client, auth_code, pkce_challenge, request): """Exchange authorization code for tokens. Parameters can be customized via indirect parameterization: @pytest.mark.parametrize("tokens", [{"code_verifier": "wrong_verifier"}], indirect=True) """ # Default token request params token_params = { "grant_type": "authorization_code", "client_id": registered_client["client_id"], "client_secret": registered_client["client_secret"], "code": auth_code["code"], "code_verifier": pkce_challenge["code_verifier"], "redirect_uri": auth_code["redirect_uri"], } # Override with any parameters from the test if hasattr(request, "param") and request.param: token_params.update(request.param) response = await test_client.post("/token", data=token_params) # Don't assert success here since some tests will intentionally cause errors return { "response": response, "params": token_params, } class TestAuthEndpoints: @pytest.mark.anyio async def test_metadata_endpoint(self, test_client: httpx.AsyncClient): """Test the OAuth 2.0 metadata endpoint.""" print("Sending request to metadata endpoint") response = await test_client.get("/.well-known/oauth-authorization-server") print(f"Got response: {response.status_code}") if response.status_code != 200: print(f"Response content: {response.content}") assert response.status_code == 200 metadata = response.json() assert metadata["issuer"] == "https://auth.example.com/" assert metadata["authorization_endpoint"] == "https://auth.example.com/authorize" assert metadata["token_endpoint"] == "https://auth.example.com/token" assert metadata["registration_endpoint"] == "https://auth.example.com/register" assert metadata["revocation_endpoint"] == "https://auth.example.com/revoke" assert metadata["response_types_supported"] == ["code"] assert metadata["code_challenge_methods_supported"] == ["S256"] assert metadata["token_endpoint_auth_methods_supported"] == ["client_secret_post"] assert metadata["grant_types_supported"] == [ "authorization_code", "refresh_token", ] assert metadata["service_documentation"] == "https://docs.example.com/" @pytest.mark.anyio async def test_token_validation_error(self, test_client: httpx.AsyncClient): """Test token endpoint error - validation error.""" # Missing required fields response = await test_client.post( "/token", data={ "grant_type": "authorization_code", # Missing code, code_verifier, client_id, etc. }, ) error_response = response.json() assert error_response["error"] == "invalid_request" assert "error_description" in error_response # Contains validation error messages @pytest.mark.anyio async def test_token_invalid_auth_code(self, test_client, registered_client, pkce_challenge): """Test token endpoint error - authorization code does not exist.""" # Try to use a non-existent authorization code response = await test_client.post( "/token", data={ "grant_type": "authorization_code", "client_id": registered_client["client_id"], "client_secret": registered_client["client_secret"], "code": "non_existent_auth_code", "code_verifier": pkce_challenge["code_verifier"], "redirect_uri": "https://client.example.com/callback", }, ) print(f"Status code: {response.status_code}") print(f"Response body: {response.content}") print(f"Response JSON: {response.json()}") assert response.status_code == 400 error_response = response.json() assert error_response["error"] == "invalid_grant" assert "authorization code does not exist" in error_response["error_description"] @pytest.mark.anyio async def test_token_expired_auth_code( self, test_client, registered_client, auth_code, pkce_challenge, mock_oauth_provider, ): """Test token endpoint error - authorization code has expired.""" # Get the current time for our time mocking current_time = time.time() # Find the auth code object code_value = auth_code["code"] found_code = None for code_obj in mock_oauth_provider.auth_codes.values(): if code_obj.code == code_value: found_code = code_obj break assert found_code is not None # Authorization codes are typically short-lived (5 minutes = 300 seconds) # So we'll mock time to be 10 minutes (600 seconds) in the future with unittest.mock.patch("time.time", return_value=current_time + 600): # Try to use the expired authorization code response = await test_client.post( "/token", data={ "grant_type": "authorization_code", "client_id": registered_client["client_id"], "client_secret": registered_client["client_secret"], "code": code_value, "code_verifier": pkce_challenge["code_verifier"], "redirect_uri": auth_code["redirect_uri"], }, ) assert response.status_code == 400 error_response = response.json() assert error_response["error"] == "invalid_grant" assert "authorization code has expired" in error_response["error_description"] @pytest.mark.anyio @pytest.mark.parametrize( "registered_client", [ { "redirect_uris": [ "https://client.example.com/callback", "https://client.example.com/other-callback", ] } ], indirect=True, ) async def test_token_redirect_uri_mismatch(self, test_client, registered_client, auth_code, pkce_challenge): """Test token endpoint error - redirect URI mismatch.""" # Try to use the code with a different redirect URI response = await test_client.post( "/token", data={ "grant_type": "authorization_code", "client_id": registered_client["client_id"], "client_secret": registered_client["client_secret"], "code": auth_code["code"], "code_verifier": pkce_challenge["code_verifier"], # Different from the one used in /authorize "redirect_uri": "https://client.example.com/other-callback", }, ) assert response.status_code == 400 error_response = response.json() assert error_response["error"] == "invalid_request" assert "redirect_uri did not match" in error_response["error_description"] @pytest.mark.anyio async def test_token_code_verifier_mismatch(self, test_client, registered_client, auth_code): """Test token endpoint error - PKCE code verifier mismatch.""" # Try to use the code with an incorrect code verifier response = await test_client.post( "/token", data={ "grant_type": "authorization_code", "client_id": registered_client["client_id"], "client_secret": registered_client["client_secret"], "code": auth_code["code"], # Different from the one used to create challenge "code_verifier": "incorrect_code_verifier", "redirect_uri": auth_code["redirect_uri"], }, ) assert response.status_code == 400 error_response = response.json() assert error_response["error"] == "invalid_grant" assert "incorrect code_verifier" in error_response["error_description"] @pytest.mark.anyio async def test_token_invalid_refresh_token(self, test_client, registered_client): """Test token endpoint error - refresh token does not exist.""" # Try to use a non-existent refresh token response = await test_client.post( "/token", data={ "grant_type": "refresh_token", "client_id": registered_client["client_id"], "client_secret": registered_client["client_secret"], "refresh_token": "non_existent_refresh_token", }, ) assert response.status_code == 400 error_response = response.json() assert error_response["error"] == "invalid_grant" assert "refresh token does not exist" in error_response["error_description"] @pytest.mark.anyio async def test_token_expired_refresh_token( self, test_client, registered_client, auth_code, pkce_challenge, mock_oauth_provider, ): """Test token endpoint error - refresh token has expired.""" # Step 1: First, let's create a token and refresh token at the current time current_time = time.time() # Exchange authorization code for tokens normally token_response = await test_client.post( "/token", data={ "grant_type": "authorization_code", "client_id": registered_client["client_id"], "client_secret": registered_client["client_secret"], "code": auth_code["code"], "code_verifier": pkce_challenge["code_verifier"], "redirect_uri": auth_code["redirect_uri"], }, ) assert token_response.status_code == 200 tokens = token_response.json() refresh_token = tokens["refresh_token"] # Step 2: Time travel forward 4 hours (tokens expire in 1 hour by default) # Mock the time.time() function to return a value 4 hours in the future with unittest.mock.patch("time.time", return_value=current_time + 14400): # 4 hours = 14400 seconds # Try to use the refresh token which should now be considered expired response = await test_client.post( "/token", data={ "grant_type": "refresh_token", "client_id": registered_client["client_id"], "client_secret": registered_client["client_secret"], "refresh_token": refresh_token, }, ) # In the "future", the token should be considered expired assert response.status_code == 400 error_response = response.json() assert error_response["error"] == "invalid_grant" assert "refresh token has expired" in error_response["error_description"] @pytest.mark.anyio async def test_token_invalid_scope(self, test_client, registered_client, auth_code, pkce_challenge): """Test token endpoint error - invalid scope in refresh token request.""" # Exchange authorization code for tokens token_response = await test_client.post( "/token", data={ "grant_type": "authorization_code", "client_id": registered_client["client_id"], "client_secret": registered_client["client_secret"], "code": auth_code["code"], "code_verifier": pkce_challenge["code_verifier"], "redirect_uri": auth_code["redirect_uri"], }, ) assert token_response.status_code == 200 tokens = token_response.json() refresh_token = tokens["refresh_token"] # Try to use refresh token with an invalid scope response = await test_client.post( "/token", data={ "grant_type": "refresh_token", "client_id": registered_client["client_id"], "client_secret": registered_client["client_secret"], "refresh_token": refresh_token, "scope": "read write invalid_scope", # Adding an invalid scope }, ) assert response.status_code == 400 error_response = response.json() assert error_response["error"] == "invalid_scope" assert "cannot request scope" in error_response["error_description"] @pytest.mark.anyio async def test_client_registration(self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider): """Test client registration.""" client_metadata = { "redirect_uris": ["https://client.example.com/callback"], "client_name": "Test Client", "client_uri": "https://client.example.com", } response = await test_client.post( "/register", json=client_metadata, ) assert response.status_code == 201, response.content client_info = response.json() assert "client_id" in client_info assert "client_secret" in client_info assert client_info["client_name"] == "Test Client" assert client_info["redirect_uris"] == ["https://client.example.com/callback"] # Verify that the client was registered # assert await mock_oauth_provider.clients_store.get_client( # client_info["client_id"] # ) is not None @pytest.mark.anyio async def test_client_registration_missing_required_fields(self, test_client: httpx.AsyncClient): """Test client registration with missing required fields.""" # Missing redirect_uris which is a required field client_metadata = { "client_name": "Test Client", "client_uri": "https://client.example.com", } response = await test_client.post( "/register", json=client_metadata, ) assert response.status_code == 400 error_data = response.json() assert "error" in error_data assert error_data["error"] == "invalid_client_metadata" assert error_data["error_description"] == "redirect_uris: Field required" @pytest.mark.anyio async def test_client_registration_invalid_uri(self, test_client: httpx.AsyncClient): """Test client registration with invalid URIs.""" # Invalid redirect_uri format client_metadata = { "redirect_uris": ["not-a-valid-uri"], "client_name": "Test Client", } response = await test_client.post( "/register", json=client_metadata, ) assert response.status_code == 400 error_data = response.json() assert "error" in error_data assert error_data["error"] == "invalid_client_metadata" assert error_data["error_description"] == ( "redirect_uris.0: Input should be a valid URL, " "relative URL without a base" ) @pytest.mark.anyio async def test_client_registration_empty_redirect_uris(self, test_client: httpx.AsyncClient): """Test client registration with empty redirect_uris array.""" client_metadata = { "redirect_uris": [], # Empty array "client_name": "Test Client", } response = await test_client.post( "/register", json=client_metadata, ) assert response.status_code == 400 error_data = response.json() assert "error" in error_data assert error_data["error"] == "invalid_client_metadata" assert ( error_data["error_description"] == "redirect_uris: List should have at least 1 item after validation, not 0" ) @pytest.mark.anyio async def test_authorize_form_post( self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge, ): """Test the authorization endpoint using POST with form-encoded data.""" # Register a client client_metadata = { "redirect_uris": ["https://client.example.com/callback"], "client_name": "Test Client", "grant_types": ["authorization_code", "refresh_token"], } response = await test_client.post( "/register", json=client_metadata, ) assert response.status_code == 201 client_info = response.json() # Use POST with form-encoded data for authorization response = await test_client.post( "/authorize", data={ "response_type": "code", "client_id": client_info["client_id"], "redirect_uri": "https://client.example.com/callback", "code_challenge": pkce_challenge["code_challenge"], "code_challenge_method": "S256", "state": "test_form_state", }, ) assert response.status_code == 302 # Extract the authorization code from the redirect URL redirect_url = response.headers["location"] parsed_url = urlparse(redirect_url) query_params = parse_qs(parsed_url.query) assert "code" in query_params assert query_params["state"][0] == "test_form_state" @pytest.mark.anyio async def test_authorization_get( self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge, ): """Test the full authorization flow.""" # 1. Register a client client_metadata = { "redirect_uris": ["https://client.example.com/callback"], "client_name": "Test Client", "grant_types": ["authorization_code", "refresh_token"], } response = await test_client.post( "/register", json=client_metadata, ) assert response.status_code == 201 client_info = response.json() # 2. Request authorization using GET with query params response = await test_client.get( "/authorize", params={ "response_type": "code", "client_id": client_info["client_id"], "redirect_uri": "https://client.example.com/callback", "code_challenge": pkce_challenge["code_challenge"], "code_challenge_method": "S256", "state": "test_state", }, ) assert response.status_code == 302 # 3. Extract the authorization code from the redirect URL redirect_url = response.headers["location"] parsed_url = urlparse(redirect_url) query_params = parse_qs(parsed_url.query) assert "code" in query_params assert query_params["state"][0] == "test_state" auth_code = query_params["code"][0] # 4. Exchange the authorization code for tokens response = await test_client.post( "/token", data={ "grant_type": "authorization_code", "client_id": client_info["client_id"], "client_secret": client_info["client_secret"], "code": auth_code, "code_verifier": pkce_challenge["code_verifier"], "redirect_uri": "https://client.example.com/callback", }, ) assert response.status_code == 200 token_response = response.json() assert "access_token" in token_response assert "token_type" in token_response assert "refresh_token" in token_response assert "expires_in" in token_response assert token_response["token_type"] == "Bearer" # 5. Verify the access token access_token = token_response["access_token"] refresh_token = token_response["refresh_token"] # Create a test client with the token auth_info = await mock_oauth_provider.load_access_token(access_token) assert auth_info assert auth_info.client_id == client_info["client_id"] assert "read" in auth_info.scopes assert "write" in auth_info.scopes # 6. Refresh the token response = await test_client.post( "/token", data={ "grant_type": "refresh_token", "client_id": client_info["client_id"], "client_secret": client_info["client_secret"], "refresh_token": refresh_token, "redirect_uri": "https://client.example.com/callback", }, ) assert response.status_code == 200 new_token_response = response.json() assert "access_token" in new_token_response assert "refresh_token" in new_token_response assert new_token_response["access_token"] != access_token assert new_token_response["refresh_token"] != refresh_token # 7. Revoke the token response = await test_client.post( "/revoke", data={ "client_id": client_info["client_id"], "client_secret": client_info["client_secret"], "token": new_token_response["access_token"], }, ) assert response.status_code == 200 # Verify that the token was revoked assert await mock_oauth_provider.load_access_token(new_token_response["access_token"]) is None @pytest.mark.anyio async def test_revoke_invalid_token(self, test_client, registered_client): """Test revoking an invalid token.""" response = await test_client.post( "/revoke", data={ "client_id": registered_client["client_id"], "client_secret": registered_client["client_secret"], "token": "invalid_token", }, ) # per RFC, this should return 200 even if the token is invalid assert response.status_code == 200 @pytest.mark.anyio async def test_revoke_with_malformed_token(self, test_client, registered_client): response = await test_client.post( "/revoke", data={ "client_id": registered_client["client_id"], "client_secret": registered_client["client_secret"], "token": 123, "token_type_hint": "asdf", }, ) assert response.status_code == 400 error_response = response.json() assert error_response["error"] == "invalid_request" assert "token_type_hint" in error_response["error_description"] @pytest.mark.anyio async def test_client_registration_disallowed_scopes(self, test_client: httpx.AsyncClient): """Test client registration with scopes that are not allowed.""" client_metadata = { "redirect_uris": ["https://client.example.com/callback"], "client_name": "Test Client", "scope": "read write profile admin", # 'admin' is not in valid_scopes } response = await test_client.post( "/register", json=client_metadata, ) assert response.status_code == 400 error_data = response.json() assert "error" in error_data assert error_data["error"] == "invalid_client_metadata" assert "scope" in error_data["error_description"] assert "admin" in error_data["error_description"] @pytest.mark.anyio async def test_client_registration_default_scopes( self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider ): client_metadata = { "redirect_uris": ["https://client.example.com/callback"], "client_name": "Test Client", # No scope specified } response = await test_client.post( "/register", json=client_metadata, ) assert response.status_code == 201 client_info = response.json() # Verify client was registered successfully assert client_info["scope"] == "read write" # Retrieve the client from the store to verify default scopes registered_client = await mock_oauth_provider.get_client(client_info["client_id"]) assert registered_client is not None # Check that default scopes were applied assert registered_client.scope == "read write" @pytest.mark.anyio async def test_client_registration_invalid_grant_type(self, test_client: httpx.AsyncClient): client_metadata = { "redirect_uris": ["https://client.example.com/callback"], "client_name": "Test Client", "grant_types": ["authorization_code"], } response = await test_client.post( "/register", json=client_metadata, ) assert response.status_code == 400 error_data = response.json() assert "error" in error_data assert error_data["error"] == "invalid_client_metadata" assert error_data["error_description"] == "grant_types must be authorization_code and refresh_token" class TestAuthorizeEndpointErrors: """Test error handling in the OAuth authorization endpoint.""" @pytest.mark.anyio async def test_authorize_missing_client_id(self, test_client: httpx.AsyncClient, pkce_challenge): """Test authorization endpoint with missing client_id. According to the OAuth2.0 spec, if client_id is missing, the server should inform the resource owner and NOT redirect. """ response = await test_client.get( "/authorize", params={ "response_type": "code", # Missing client_id "redirect_uri": "https://client.example.com/callback", "state": "test_state", "code_challenge": pkce_challenge["code_challenge"], "code_challenge_method": "S256", }, ) # Should NOT redirect, should show an error page assert response.status_code == 400 # The response should include an error message about missing client_id assert "client_id" in response.text.lower() @pytest.mark.anyio async def test_authorize_invalid_client_id(self, test_client: httpx.AsyncClient, pkce_challenge): """Test authorization endpoint with invalid client_id. According to the OAuth2.0 spec, if client_id is invalid, the server should inform the resource owner and NOT redirect. """ response = await test_client.get( "/authorize", params={ "response_type": "code", "client_id": "invalid_client_id_that_does_not_exist", "redirect_uri": "https://client.example.com/callback", "state": "test_state", "code_challenge": pkce_challenge["code_challenge"], "code_challenge_method": "S256", }, ) # Should NOT redirect, should show an error page assert response.status_code == 400 # The response should include an error message about invalid client_id assert "client" in response.text.lower() @pytest.mark.anyio async def test_authorize_missing_redirect_uri( self, test_client: httpx.AsyncClient, registered_client, pkce_challenge ): """Test authorization endpoint with missing redirect_uri. If client has only one registered redirect_uri, it can be omitted. """ response = await test_client.get( "/authorize", params={ "response_type": "code", "client_id": registered_client["client_id"], # Missing redirect_uri "code_challenge": pkce_challenge["code_challenge"], "code_challenge_method": "S256", "state": "test_state", }, ) # Should redirect to the registered redirect_uri assert response.status_code == 302, response.content redirect_url = response.headers["location"] assert redirect_url.startswith("https://client.example.com/callback") @pytest.mark.anyio async def test_authorize_invalid_redirect_uri( self, test_client: httpx.AsyncClient, registered_client, pkce_challenge ): """Test authorization endpoint with invalid redirect_uri. According to the OAuth2.0 spec, if redirect_uri is invalid or doesn't match, the server should inform the resource owner and NOT redirect. """ response = await test_client.get( "/authorize", params={ "response_type": "code", "client_id": registered_client["client_id"], # Non-matching URI "redirect_uri": "https://attacker.example.com/callback", "code_challenge": pkce_challenge["code_challenge"], "code_challenge_method": "S256", "state": "test_state", }, ) # Should NOT redirect, should show an error page assert response.status_code == 400, response.content # The response should include an error message about redirect_uri mismatch assert "redirect" in response.text.lower() @pytest.mark.anyio @pytest.mark.parametrize( "registered_client", [ { "redirect_uris": [ "https://client.example.com/callback", "https://client.example.com/other-callback", ] } ], indirect=True, ) async def test_authorize_missing_redirect_uri_multiple_registered( self, test_client: httpx.AsyncClient, registered_client, pkce_challenge ): """Test endpoint with missing redirect_uri with multiple registered URIs. If client has multiple registered redirect_uris, redirect_uri must be provided. """ response = await test_client.get( "/authorize", params={ "response_type": "code", "client_id": registered_client["client_id"], # Missing redirect_uri "code_challenge": pkce_challenge["code_challenge"], "code_challenge_method": "S256", "state": "test_state", }, ) # Should NOT redirect, should return a 400 error assert response.status_code == 400 # The response should include an error message about missing redirect_uri assert "redirect_uri" in response.text.lower() @pytest.mark.anyio async def test_authorize_unsupported_response_type( self, test_client: httpx.AsyncClient, registered_client, pkce_challenge ): """Test authorization endpoint with unsupported response_type. According to the OAuth2.0 spec, for other errors like unsupported_response_type, the server should redirect with error parameters. """ response = await test_client.get( "/authorize", params={ "response_type": "token", # Unsupported (we only support "code") "client_id": registered_client["client_id"], "redirect_uri": "https://client.example.com/callback", "code_challenge": pkce_challenge["code_challenge"], "code_challenge_method": "S256", "state": "test_state", }, ) # Should redirect with error parameters assert response.status_code == 302 redirect_url = response.headers["location"] parsed_url = urlparse(redirect_url) query_params = parse_qs(parsed_url.query) assert "error" in query_params assert query_params["error"][0] == "unsupported_response_type" # State should be preserved assert "state" in query_params assert query_params["state"][0] == "test_state" @pytest.mark.anyio async def test_authorize_missing_response_type( self, test_client: httpx.AsyncClient, registered_client, pkce_challenge ): """Test authorization endpoint with missing response_type. Missing required parameter should result in invalid_request error. """ response = await test_client.get( "/authorize", params={ # Missing response_type "client_id": registered_client["client_id"], "redirect_uri": "https://client.example.com/callback", "code_challenge": pkce_challenge["code_challenge"], "code_challenge_method": "S256", "state": "test_state", }, ) # Should redirect with error parameters assert response.status_code == 302 redirect_url = response.headers["location"] parsed_url = urlparse(redirect_url) query_params = parse_qs(parsed_url.query) assert "error" in query_params assert query_params["error"][0] == "invalid_request" # State should be preserved assert "state" in query_params assert query_params["state"][0] == "test_state" @pytest.mark.anyio async def test_authorize_missing_pkce_challenge(self, test_client: httpx.AsyncClient, registered_client): """Test authorization endpoint with missing PKCE code_challenge. Missing PKCE parameters should result in invalid_request error. """ response = await test_client.get( "/authorize", params={ "response_type": "code", "client_id": registered_client["client_id"], # Missing code_challenge "state": "test_state", # using default URL }, ) # Should redirect with error parameters assert response.status_code == 302 redirect_url = response.headers["location"] parsed_url = urlparse(redirect_url) query_params = parse_qs(parsed_url.query) assert "error" in query_params assert query_params["error"][0] == "invalid_request" # State should be preserved assert "state" in query_params assert query_params["state"][0] == "test_state" @pytest.mark.anyio async def test_authorize_invalid_scope(self, test_client: httpx.AsyncClient, registered_client, pkce_challenge): """Test authorization endpoint with invalid scope. Invalid scope should redirect with invalid_scope error. """ response = await test_client.get( "/authorize", params={ "response_type": "code", "client_id": registered_client["client_id"], "redirect_uri": "https://client.example.com/callback", "code_challenge": pkce_challenge["code_challenge"], "code_challenge_method": "S256", "scope": "invalid_scope_that_does_not_exist", "state": "test_state", }, ) # Should redirect with error parameters assert response.status_code == 302 redirect_url = response.headers["location"] parsed_url = urlparse(redirect_url) query_params = parse_qs(parsed_url.query) assert "error" in query_params assert query_params["error"][0] == "invalid_scope" # State should be preserved assert "state" in query_params assert query_params["state"][0] == "test_state"