mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
relax validation (#879)
This commit is contained in:
@@ -214,7 +214,7 @@ class SimpleGitHubOAuthProvider(OAuthAuthorizationServerProvider):
|
|||||||
|
|
||||||
return OAuthToken(
|
return OAuthToken(
|
||||||
access_token=mcp_token,
|
access_token=mcp_token,
|
||||||
token_type="bearer",
|
token_type="Bearer",
|
||||||
expires_in=3600,
|
expires_in=3600,
|
||||||
scope=" ".join(authorization_code.scopes),
|
scope=" ".join(authorization_code.scopes),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field
|
from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, field_validator
|
||||||
|
|
||||||
|
|
||||||
class OAuthToken(BaseModel):
|
class OAuthToken(BaseModel):
|
||||||
@@ -9,11 +9,20 @@ class OAuthToken(BaseModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
access_token: str
|
access_token: str
|
||||||
token_type: Literal["bearer"] = "bearer"
|
token_type: Literal["Bearer"] = "Bearer"
|
||||||
expires_in: int | None = None
|
expires_in: int | None = None
|
||||||
scope: str | None = None
|
scope: str | None = None
|
||||||
refresh_token: str | None = None
|
refresh_token: str | None = None
|
||||||
|
|
||||||
|
@field_validator("token_type", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def normalize_token_type(cls, v: str | None) -> str | None:
|
||||||
|
if isinstance(v, str):
|
||||||
|
# Bearer is title-cased in the spec, so we normalize it
|
||||||
|
# https://datatracker.ietf.org/doc/html/rfc6750#section-4
|
||||||
|
return v.title()
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
class InvalidScopeError(Exception):
|
class InvalidScopeError(Exception):
|
||||||
def __init__(self, message: str):
|
def __init__(self, message: str):
|
||||||
@@ -111,27 +120,19 @@ class OAuthMetadata(BaseModel):
|
|||||||
token_endpoint: AnyHttpUrl
|
token_endpoint: AnyHttpUrl
|
||||||
registration_endpoint: AnyHttpUrl | None = None
|
registration_endpoint: AnyHttpUrl | None = None
|
||||||
scopes_supported: list[str] | None = None
|
scopes_supported: list[str] | None = None
|
||||||
response_types_supported: list[Literal["code"]] = ["code"]
|
response_types_supported: list[str] = ["code"]
|
||||||
response_modes_supported: list[Literal["query", "fragment"]] | None = None
|
response_modes_supported: list[Literal["query", "fragment"]] | None = None
|
||||||
grant_types_supported: (
|
grant_types_supported: list[str] | None = None
|
||||||
list[Literal["authorization_code", "refresh_token"]] | None
|
token_endpoint_auth_methods_supported: list[str] | None = None
|
||||||
) = None
|
|
||||||
token_endpoint_auth_methods_supported: (
|
|
||||||
list[Literal["none", "client_secret_post"]] | None
|
|
||||||
) = None
|
|
||||||
token_endpoint_auth_signing_alg_values_supported: None = None
|
token_endpoint_auth_signing_alg_values_supported: None = None
|
||||||
service_documentation: AnyHttpUrl | None = None
|
service_documentation: AnyHttpUrl | None = None
|
||||||
ui_locales_supported: list[str] | None = None
|
ui_locales_supported: list[str] | None = None
|
||||||
op_policy_uri: AnyHttpUrl | None = None
|
op_policy_uri: AnyHttpUrl | None = None
|
||||||
op_tos_uri: AnyHttpUrl | None = None
|
op_tos_uri: AnyHttpUrl | None = None
|
||||||
revocation_endpoint: AnyHttpUrl | None = None
|
revocation_endpoint: AnyHttpUrl | None = None
|
||||||
revocation_endpoint_auth_methods_supported: (
|
revocation_endpoint_auth_methods_supported: list[str] | None = None
|
||||||
list[Literal["client_secret_post"]] | None
|
|
||||||
) = None
|
|
||||||
revocation_endpoint_auth_signing_alg_values_supported: None = None
|
revocation_endpoint_auth_signing_alg_values_supported: None = None
|
||||||
introspection_endpoint: AnyHttpUrl | None = None
|
introspection_endpoint: AnyHttpUrl | None = None
|
||||||
introspection_endpoint_auth_methods_supported: (
|
introspection_endpoint_auth_methods_supported: list[str] | None = None
|
||||||
list[Literal["client_secret_post"]] | None
|
|
||||||
) = None
|
|
||||||
introspection_endpoint_auth_signing_alg_values_supported: None = None
|
introspection_endpoint_auth_signing_alg_values_supported: None = None
|
||||||
code_challenge_methods_supported: list[Literal["S256"]] | None = None
|
code_challenge_methods_supported: list[str] | None = None
|
||||||
|
|||||||
@@ -91,7 +91,7 @@ def oauth_client_info():
|
|||||||
def oauth_token():
|
def oauth_token():
|
||||||
return OAuthToken(
|
return OAuthToken(
|
||||||
access_token="test_access_token",
|
access_token="test_access_token",
|
||||||
token_type="bearer",
|
token_type="Bearer",
|
||||||
expires_in=3600,
|
expires_in=3600,
|
||||||
refresh_token="test_refresh_token",
|
refresh_token="test_refresh_token",
|
||||||
scope="read write",
|
scope="read write",
|
||||||
@@ -143,7 +143,8 @@ class TestOAuthClientProvider:
|
|||||||
verifiers = {oauth_provider._generate_code_verifier() for _ in range(10)}
|
verifiers = {oauth_provider._generate_code_verifier() for _ in range(10)}
|
||||||
assert len(verifiers) == 10
|
assert len(verifiers) == 10
|
||||||
|
|
||||||
def test_generate_code_challenge(self, oauth_provider):
|
@pytest.mark.anyio
|
||||||
|
async def test_generate_code_challenge(self, oauth_provider):
|
||||||
"""Test PKCE code challenge generation."""
|
"""Test PKCE code challenge generation."""
|
||||||
verifier = "test_code_verifier_123"
|
verifier = "test_code_verifier_123"
|
||||||
challenge = oauth_provider._generate_code_challenge(verifier)
|
challenge = oauth_provider._generate_code_challenge(verifier)
|
||||||
@@ -161,7 +162,8 @@ class TestOAuthClientProvider:
|
|||||||
assert "+" not in challenge
|
assert "+" not in challenge
|
||||||
assert "/" not in challenge
|
assert "/" not in challenge
|
||||||
|
|
||||||
def test_get_authorization_base_url(self, oauth_provider):
|
@pytest.mark.anyio
|
||||||
|
async def test_get_authorization_base_url(self, oauth_provider):
|
||||||
"""Test authorization base URL extraction."""
|
"""Test authorization base URL extraction."""
|
||||||
# Test with path
|
# Test with path
|
||||||
assert (
|
assert (
|
||||||
@@ -348,11 +350,13 @@ class TestOAuthClientProvider:
|
|||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_has_valid_token_no_token(self, oauth_provider):
|
@pytest.mark.anyio
|
||||||
|
async def test_has_valid_token_no_token(self, oauth_provider):
|
||||||
"""Test token validation with no token."""
|
"""Test token validation with no token."""
|
||||||
assert not oauth_provider._has_valid_token()
|
assert not oauth_provider._has_valid_token()
|
||||||
|
|
||||||
def test_has_valid_token_valid(self, oauth_provider, oauth_token):
|
@pytest.mark.anyio
|
||||||
|
async def test_has_valid_token_valid(self, oauth_provider, oauth_token):
|
||||||
"""Test token validation with valid token."""
|
"""Test token validation with valid token."""
|
||||||
oauth_provider._current_tokens = oauth_token
|
oauth_provider._current_tokens = oauth_token
|
||||||
oauth_provider._token_expiry_time = time.time() + 3600 # Future expiry
|
oauth_provider._token_expiry_time = time.time() + 3600 # Future expiry
|
||||||
@@ -370,7 +374,7 @@ class TestOAuthClientProvider:
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_validate_token_scopes_no_scope(self, oauth_provider):
|
async def test_validate_token_scopes_no_scope(self, oauth_provider):
|
||||||
"""Test scope validation with no scope returned."""
|
"""Test scope validation with no scope returned."""
|
||||||
token = OAuthToken(access_token="test", token_type="bearer")
|
token = OAuthToken(access_token="test", token_type="Bearer")
|
||||||
|
|
||||||
# Should not raise exception
|
# Should not raise exception
|
||||||
await oauth_provider._validate_token_scopes(token)
|
await oauth_provider._validate_token_scopes(token)
|
||||||
@@ -381,7 +385,7 @@ class TestOAuthClientProvider:
|
|||||||
oauth_provider.client_metadata = client_metadata
|
oauth_provider.client_metadata = client_metadata
|
||||||
token = OAuthToken(
|
token = OAuthToken(
|
||||||
access_token="test",
|
access_token="test",
|
||||||
token_type="bearer",
|
token_type="Bearer",
|
||||||
scope="read write",
|
scope="read write",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -394,7 +398,7 @@ class TestOAuthClientProvider:
|
|||||||
oauth_provider.client_metadata = client_metadata
|
oauth_provider.client_metadata = client_metadata
|
||||||
token = OAuthToken(
|
token = OAuthToken(
|
||||||
access_token="test",
|
access_token="test",
|
||||||
token_type="bearer",
|
token_type="Bearer",
|
||||||
scope="read",
|
scope="read",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -409,7 +413,7 @@ class TestOAuthClientProvider:
|
|||||||
oauth_provider.client_metadata = client_metadata
|
oauth_provider.client_metadata = client_metadata
|
||||||
token = OAuthToken(
|
token = OAuthToken(
|
||||||
access_token="test",
|
access_token="test",
|
||||||
token_type="bearer",
|
token_type="Bearer",
|
||||||
scope="read write admin", # Includes unauthorized "admin"
|
scope="read write admin", # Includes unauthorized "admin"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -423,7 +427,7 @@ class TestOAuthClientProvider:
|
|||||||
oauth_provider.client_metadata.scope = None
|
oauth_provider.client_metadata.scope = None
|
||||||
token = OAuthToken(
|
token = OAuthToken(
|
||||||
access_token="test",
|
access_token="test",
|
||||||
token_type="bearer",
|
token_type="Bearer",
|
||||||
scope="admin super",
|
scope="admin super",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -530,7 +534,7 @@ class TestOAuthClientProvider:
|
|||||||
|
|
||||||
new_token = OAuthToken(
|
new_token = OAuthToken(
|
||||||
access_token="new_access_token",
|
access_token="new_access_token",
|
||||||
token_type="bearer",
|
token_type="Bearer",
|
||||||
expires_in=3600,
|
expires_in=3600,
|
||||||
refresh_token="new_refresh_token",
|
refresh_token="new_refresh_token",
|
||||||
scope="read write",
|
scope="read write",
|
||||||
@@ -563,7 +567,7 @@ class TestOAuthClientProvider:
|
|||||||
"""Test token refresh with no refresh token."""
|
"""Test token refresh with no refresh token."""
|
||||||
oauth_provider._current_tokens = OAuthToken(
|
oauth_provider._current_tokens = OAuthToken(
|
||||||
access_token="test",
|
access_token="test",
|
||||||
token_type="bearer",
|
token_type="Bearer",
|
||||||
# No refresh_token
|
# No refresh_token
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -756,7 +760,8 @@ class TestOAuthClientProvider:
|
|||||||
# No Authorization header should be added if no token
|
# No Authorization header should be added if no token
|
||||||
assert "Authorization" not in updated_request.headers
|
assert "Authorization" not in updated_request.headers
|
||||||
|
|
||||||
def test_scope_priority_client_metadata_first(
|
@pytest.mark.anyio
|
||||||
|
async def test_scope_priority_client_metadata_first(
|
||||||
self, oauth_provider, oauth_client_info
|
self, oauth_provider, oauth_client_info
|
||||||
):
|
):
|
||||||
"""Test that client metadata scope takes priority."""
|
"""Test that client metadata scope takes priority."""
|
||||||
@@ -785,7 +790,8 @@ class TestOAuthClientProvider:
|
|||||||
|
|
||||||
assert auth_params["scope"] == "read write"
|
assert auth_params["scope"] == "read write"
|
||||||
|
|
||||||
def test_scope_priority_no_client_metadata_scope(
|
@pytest.mark.anyio
|
||||||
|
async def test_scope_priority_no_client_metadata_scope(
|
||||||
self, oauth_provider, oauth_client_info
|
self, oauth_provider, oauth_client_info
|
||||||
):
|
):
|
||||||
"""Test that no scope parameter is set when client metadata has no scope."""
|
"""Test that no scope parameter is set when client metadata has no scope."""
|
||||||
|
|||||||
@@ -96,7 +96,7 @@ class MockOAuthProvider(OAuthAuthorizationServerProvider):
|
|||||||
|
|
||||||
return OAuthToken(
|
return OAuthToken(
|
||||||
access_token=access_token,
|
access_token=access_token,
|
||||||
token_type="bearer",
|
token_type="Bearer",
|
||||||
expires_in=3600,
|
expires_in=3600,
|
||||||
scope="read write",
|
scope="read write",
|
||||||
refresh_token=refresh_token,
|
refresh_token=refresh_token,
|
||||||
@@ -160,7 +160,7 @@ class MockOAuthProvider(OAuthAuthorizationServerProvider):
|
|||||||
|
|
||||||
return OAuthToken(
|
return OAuthToken(
|
||||||
access_token=new_access_token,
|
access_token=new_access_token,
|
||||||
token_type="bearer",
|
token_type="Bearer",
|
||||||
expires_in=3600,
|
expires_in=3600,
|
||||||
scope=" ".join(scopes) if scopes else " ".join(token_info.scopes),
|
scope=" ".join(scopes) if scopes else " ".join(token_info.scopes),
|
||||||
refresh_token=new_refresh_token,
|
refresh_token=new_refresh_token,
|
||||||
@@ -831,7 +831,7 @@ class TestAuthEndpoints:
|
|||||||
assert "token_type" in token_response
|
assert "token_type" in token_response
|
||||||
assert "refresh_token" in token_response
|
assert "refresh_token" in token_response
|
||||||
assert "expires_in" in token_response
|
assert "expires_in" in token_response
|
||||||
assert token_response["token_type"] == "bearer"
|
assert token_response["token_type"] == "Bearer"
|
||||||
|
|
||||||
# 5. Verify the access token
|
# 5. Verify the access token
|
||||||
access_token = token_response["access_token"]
|
access_token = token_response["access_token"]
|
||||||
|
|||||||
Reference in New Issue
Block a user