mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-20 07:14:24 +01:00
Use 120 characters instead of 88 (#856)
This commit is contained in:
committed by
GitHub
parent
f7265f7b91
commit
543961968c
@@ -47,9 +47,7 @@ class MockOAuthProvider(OAuthAuthorizationServerProvider):
|
||||
async def register_client(self, client_info: OAuthClientInformationFull):
|
||||
self.clients[client_info.client_id] = client_info
|
||||
|
||||
async def authorize(
|
||||
self, client: OAuthClientInformationFull, params: AuthorizationParams
|
||||
) -> str:
|
||||
async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str:
|
||||
# toy authorize implementation which just immediately generates an authorization
|
||||
# code and completes the redirect
|
||||
code = AuthorizationCode(
|
||||
@@ -63,9 +61,7 @@ class MockOAuthProvider(OAuthAuthorizationServerProvider):
|
||||
)
|
||||
self.auth_codes[code.code] = code
|
||||
|
||||
return construct_redirect_uri(
|
||||
str(params.redirect_uri), code=code.code, state=params.state
|
||||
)
|
||||
return construct_redirect_uri(str(params.redirect_uri), code=code.code, state=params.state)
|
||||
|
||||
async def load_authorization_code(
|
||||
self, client: OAuthClientInformationFull, authorization_code: str
|
||||
@@ -102,9 +98,7 @@ class MockOAuthProvider(OAuthAuthorizationServerProvider):
|
||||
refresh_token=refresh_token,
|
||||
)
|
||||
|
||||
async def load_refresh_token(
|
||||
self, client: OAuthClientInformationFull, refresh_token: str
|
||||
) -> RefreshToken | None:
|
||||
async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshToken | None:
|
||||
old_access_token = self.refresh_tokens.get(refresh_token)
|
||||
if old_access_token is None:
|
||||
return None
|
||||
@@ -224,9 +218,7 @@ def auth_app(mock_oauth_provider):
|
||||
|
||||
@pytest.fixture
|
||||
async def test_client(auth_app):
|
||||
async with httpx.AsyncClient(
|
||||
transport=httpx.ASGITransport(app=auth_app), base_url="https://mcptest.com"
|
||||
) as client:
|
||||
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=auth_app), base_url="https://mcptest.com") as client:
|
||||
yield client
|
||||
|
||||
|
||||
@@ -261,11 +253,7 @@ async def registered_client(test_client: httpx.AsyncClient, request):
|
||||
def pkce_challenge():
|
||||
"""Create a PKCE challenge with code_verifier and code_challenge."""
|
||||
code_verifier = "some_random_verifier_string"
|
||||
code_challenge = (
|
||||
base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest())
|
||||
.decode()
|
||||
.rstrip("=")
|
||||
)
|
||||
code_challenge = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()).decode().rstrip("=")
|
||||
|
||||
return {"code_verifier": code_verifier, "code_challenge": code_challenge}
|
||||
|
||||
@@ -356,17 +344,13 @@ class TestAuthEndpoints:
|
||||
|
||||
metadata = response.json()
|
||||
assert metadata["issuer"] == "https://auth.example.com/"
|
||||
assert (
|
||||
metadata["authorization_endpoint"] == "https://auth.example.com/authorize"
|
||||
)
|
||||
assert metadata["authorization_endpoint"] == "https://auth.example.com/authorize"
|
||||
assert metadata["token_endpoint"] == "https://auth.example.com/token"
|
||||
assert metadata["registration_endpoint"] == "https://auth.example.com/register"
|
||||
assert metadata["revocation_endpoint"] == "https://auth.example.com/revoke"
|
||||
assert metadata["response_types_supported"] == ["code"]
|
||||
assert metadata["code_challenge_methods_supported"] == ["S256"]
|
||||
assert metadata["token_endpoint_auth_methods_supported"] == [
|
||||
"client_secret_post"
|
||||
]
|
||||
assert metadata["token_endpoint_auth_methods_supported"] == ["client_secret_post"]
|
||||
assert metadata["grant_types_supported"] == [
|
||||
"authorization_code",
|
||||
"refresh_token",
|
||||
@@ -386,14 +370,10 @@ class TestAuthEndpoints:
|
||||
)
|
||||
error_response = response.json()
|
||||
assert error_response["error"] == "invalid_request"
|
||||
assert (
|
||||
"error_description" in error_response
|
||||
) # Contains validation error messages
|
||||
assert "error_description" in error_response # Contains validation error messages
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_token_invalid_auth_code(
|
||||
self, test_client, registered_client, pkce_challenge
|
||||
):
|
||||
async def test_token_invalid_auth_code(self, test_client, registered_client, pkce_challenge):
|
||||
"""Test token endpoint error - authorization code does not exist."""
|
||||
# Try to use a non-existent authorization code
|
||||
response = await test_client.post(
|
||||
@@ -413,9 +393,7 @@ class TestAuthEndpoints:
|
||||
assert response.status_code == 400
|
||||
error_response = response.json()
|
||||
assert error_response["error"] == "invalid_grant"
|
||||
assert (
|
||||
"authorization code does not exist" in error_response["error_description"]
|
||||
)
|
||||
assert "authorization code does not exist" in error_response["error_description"]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_token_expired_auth_code(
|
||||
@@ -458,9 +436,7 @@ class TestAuthEndpoints:
|
||||
assert response.status_code == 400
|
||||
error_response = response.json()
|
||||
assert error_response["error"] == "invalid_grant"
|
||||
assert (
|
||||
"authorization code has expired" in error_response["error_description"]
|
||||
)
|
||||
assert "authorization code has expired" in error_response["error_description"]
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.parametrize(
|
||||
@@ -475,9 +451,7 @@ class TestAuthEndpoints:
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
async def test_token_redirect_uri_mismatch(
|
||||
self, test_client, registered_client, auth_code, pkce_challenge
|
||||
):
|
||||
async def test_token_redirect_uri_mismatch(self, test_client, registered_client, auth_code, pkce_challenge):
|
||||
"""Test token endpoint error - redirect URI mismatch."""
|
||||
# Try to use the code with a different redirect URI
|
||||
response = await test_client.post(
|
||||
@@ -498,9 +472,7 @@ class TestAuthEndpoints:
|
||||
assert "redirect_uri did not match" in error_response["error_description"]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_token_code_verifier_mismatch(
|
||||
self, test_client, registered_client, auth_code
|
||||
):
|
||||
async def test_token_code_verifier_mismatch(self, test_client, registered_client, auth_code):
|
||||
"""Test token endpoint error - PKCE code verifier mismatch."""
|
||||
# Try to use the code with an incorrect code verifier
|
||||
response = await test_client.post(
|
||||
@@ -569,9 +541,7 @@ class TestAuthEndpoints:
|
||||
|
||||
# Step 2: Time travel forward 4 hours (tokens expire in 1 hour by default)
|
||||
# Mock the time.time() function to return a value 4 hours in the future
|
||||
with unittest.mock.patch(
|
||||
"time.time", return_value=current_time + 14400
|
||||
): # 4 hours = 14400 seconds
|
||||
with unittest.mock.patch("time.time", return_value=current_time + 14400): # 4 hours = 14400 seconds
|
||||
# Try to use the refresh token which should now be considered expired
|
||||
response = await test_client.post(
|
||||
"/token",
|
||||
@@ -590,9 +560,7 @@ class TestAuthEndpoints:
|
||||
assert "refresh token has expired" in error_response["error_description"]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_token_invalid_scope(
|
||||
self, test_client, registered_client, auth_code, pkce_challenge
|
||||
):
|
||||
async def test_token_invalid_scope(self, test_client, registered_client, auth_code, pkce_challenge):
|
||||
"""Test token endpoint error - invalid scope in refresh token request."""
|
||||
# Exchange authorization code for tokens
|
||||
token_response = await test_client.post(
|
||||
@@ -628,9 +596,7 @@ class TestAuthEndpoints:
|
||||
assert "cannot request scope" in error_response["error_description"]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_client_registration(
|
||||
self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider
|
||||
):
|
||||
async def test_client_registration(self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider):
|
||||
"""Test client registration."""
|
||||
client_metadata = {
|
||||
"redirect_uris": ["https://client.example.com/callback"],
|
||||
@@ -656,9 +622,7 @@ class TestAuthEndpoints:
|
||||
# ) is not None
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_client_registration_missing_required_fields(
|
||||
self, test_client: httpx.AsyncClient
|
||||
):
|
||||
async def test_client_registration_missing_required_fields(self, test_client: httpx.AsyncClient):
|
||||
"""Test client registration with missing required fields."""
|
||||
# Missing redirect_uris which is a required field
|
||||
client_metadata = {
|
||||
@@ -677,9 +641,7 @@ class TestAuthEndpoints:
|
||||
assert error_data["error_description"] == "redirect_uris: Field required"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_client_registration_invalid_uri(
|
||||
self, test_client: httpx.AsyncClient
|
||||
):
|
||||
async def test_client_registration_invalid_uri(self, test_client: httpx.AsyncClient):
|
||||
"""Test client registration with invalid URIs."""
|
||||
# Invalid redirect_uri format
|
||||
client_metadata = {
|
||||
@@ -696,14 +658,11 @@ class TestAuthEndpoints:
|
||||
assert "error" in error_data
|
||||
assert error_data["error"] == "invalid_client_metadata"
|
||||
assert error_data["error_description"] == (
|
||||
"redirect_uris.0: Input should be a valid URL, "
|
||||
"relative URL without a base"
|
||||
"redirect_uris.0: Input should be a valid URL, " "relative URL without a base"
|
||||
)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_client_registration_empty_redirect_uris(
|
||||
self, test_client: httpx.AsyncClient
|
||||
):
|
||||
async def test_client_registration_empty_redirect_uris(self, test_client: httpx.AsyncClient):
|
||||
"""Test client registration with empty redirect_uris array."""
|
||||
client_metadata = {
|
||||
"redirect_uris": [], # Empty array
|
||||
@@ -719,8 +678,7 @@ class TestAuthEndpoints:
|
||||
assert "error" in error_data
|
||||
assert error_data["error"] == "invalid_client_metadata"
|
||||
assert (
|
||||
error_data["error_description"]
|
||||
== "redirect_uris: List should have at least 1 item after validation, not 0"
|
||||
error_data["error_description"] == "redirect_uris: List should have at least 1 item after validation, not 0"
|
||||
)
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -875,12 +833,7 @@ class TestAuthEndpoints:
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify that the token was revoked
|
||||
assert (
|
||||
await mock_oauth_provider.load_access_token(
|
||||
new_token_response["access_token"]
|
||||
)
|
||||
is None
|
||||
)
|
||||
assert await mock_oauth_provider.load_access_token(new_token_response["access_token"]) is None
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_revoke_invalid_token(self, test_client, registered_client):
|
||||
@@ -913,9 +866,7 @@ class TestAuthEndpoints:
|
||||
assert "token_type_hint" in error_response["error_description"]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_client_registration_disallowed_scopes(
|
||||
self, test_client: httpx.AsyncClient
|
||||
):
|
||||
async def test_client_registration_disallowed_scopes(self, test_client: httpx.AsyncClient):
|
||||
"""Test client registration with scopes that are not allowed."""
|
||||
client_metadata = {
|
||||
"redirect_uris": ["https://client.example.com/callback"],
|
||||
@@ -955,18 +906,14 @@ class TestAuthEndpoints:
|
||||
assert client_info["scope"] == "read write"
|
||||
|
||||
# Retrieve the client from the store to verify default scopes
|
||||
registered_client = await mock_oauth_provider.get_client(
|
||||
client_info["client_id"]
|
||||
)
|
||||
registered_client = await mock_oauth_provider.get_client(client_info["client_id"])
|
||||
assert registered_client is not None
|
||||
|
||||
# Check that default scopes were applied
|
||||
assert registered_client.scope == "read write"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_client_registration_invalid_grant_type(
|
||||
self, test_client: httpx.AsyncClient
|
||||
):
|
||||
async def test_client_registration_invalid_grant_type(self, test_client: httpx.AsyncClient):
|
||||
client_metadata = {
|
||||
"redirect_uris": ["https://client.example.com/callback"],
|
||||
"client_name": "Test Client",
|
||||
@@ -981,19 +928,14 @@ class TestAuthEndpoints:
|
||||
error_data = response.json()
|
||||
assert "error" in error_data
|
||||
assert error_data["error"] == "invalid_client_metadata"
|
||||
assert (
|
||||
error_data["error_description"]
|
||||
== "grant_types must be authorization_code and refresh_token"
|
||||
)
|
||||
assert error_data["error_description"] == "grant_types must be authorization_code and refresh_token"
|
||||
|
||||
|
||||
class TestAuthorizeEndpointErrors:
|
||||
"""Test error handling in the OAuth authorization endpoint."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_authorize_missing_client_id(
|
||||
self, test_client: httpx.AsyncClient, pkce_challenge
|
||||
):
|
||||
async def test_authorize_missing_client_id(self, test_client: httpx.AsyncClient, pkce_challenge):
|
||||
"""Test authorization endpoint with missing client_id.
|
||||
|
||||
According to the OAuth2.0 spec, if client_id is missing, the server should
|
||||
@@ -1017,9 +959,7 @@ class TestAuthorizeEndpointErrors:
|
||||
assert "client_id" in response.text.lower()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_authorize_invalid_client_id(
|
||||
self, test_client: httpx.AsyncClient, pkce_challenge
|
||||
):
|
||||
async def test_authorize_invalid_client_id(self, test_client: httpx.AsyncClient, pkce_challenge):
|
||||
"""Test authorization endpoint with invalid client_id.
|
||||
|
||||
According to the OAuth2.0 spec, if client_id is invalid, the server should
|
||||
@@ -1202,9 +1142,7 @@ class TestAuthorizeEndpointErrors:
|
||||
assert query_params["state"][0] == "test_state"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_authorize_missing_pkce_challenge(
|
||||
self, test_client: httpx.AsyncClient, registered_client
|
||||
):
|
||||
async def test_authorize_missing_pkce_challenge(self, test_client: httpx.AsyncClient, registered_client):
|
||||
"""Test authorization endpoint with missing PKCE code_challenge.
|
||||
|
||||
Missing PKCE parameters should result in invalid_request error.
|
||||
@@ -1233,9 +1171,7 @@ class TestAuthorizeEndpointErrors:
|
||||
assert query_params["state"][0] == "test_state"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_authorize_invalid_scope(
|
||||
self, test_client: httpx.AsyncClient, registered_client, pkce_challenge
|
||||
):
|
||||
async def test_authorize_invalid_scope(self, test_client: httpx.AsyncClient, registered_client, pkce_challenge):
|
||||
"""Test authorization endpoint with invalid scope.
|
||||
|
||||
Invalid scope should redirect with invalid_scope error.
|
||||
|
||||
Reference in New Issue
Block a user