mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
Add support for serverside oauth (#255)
Co-authored-by: David Soria Parra <davidsp@anthropic.com> Co-authored-by: Basil Hosmer <basil@anthropic.com> Co-authored-by: ihrpr <inna@anthropic.com>
This commit is contained in:
122
tests/server/auth/middleware/test_auth_context.py
Normal file
122
tests/server/auth/middleware/test_auth_context.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""
|
||||
Tests for the AuthContext middleware components.
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from starlette.types import Message, Receive, Scope, Send
|
||||
|
||||
from mcp.server.auth.middleware.auth_context import (
|
||||
AuthContextMiddleware,
|
||||
auth_context_var,
|
||||
get_access_token,
|
||||
)
|
||||
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser
|
||||
from mcp.server.auth.provider import AccessToken
|
||||
|
||||
|
||||
class MockApp:
|
||||
"""Mock ASGI app for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self.called = False
|
||||
self.scope: Scope | None = None
|
||||
self.receive: Receive | None = None
|
||||
self.send: Send | None = None
|
||||
self.access_token_during_call: AccessToken | None = None
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
self.called = True
|
||||
self.scope = scope
|
||||
self.receive = receive
|
||||
self.send = send
|
||||
# Check the context during the call
|
||||
self.access_token_during_call = get_access_token()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def valid_access_token() -> AccessToken:
|
||||
"""Create a valid access token."""
|
||||
return AccessToken(
|
||||
token="valid_token",
|
||||
client_id="test_client",
|
||||
scopes=["read", "write"],
|
||||
expires_at=int(time.time()) + 3600, # 1 hour from now
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
class TestAuthContextMiddleware:
|
||||
"""Tests for the AuthContextMiddleware class."""
|
||||
|
||||
async def test_with_authenticated_user(self, valid_access_token: AccessToken):
|
||||
"""Test middleware with an authenticated user in scope."""
|
||||
app = MockApp()
|
||||
middleware = AuthContextMiddleware(app)
|
||||
|
||||
# Create an authenticated user
|
||||
user = AuthenticatedUser(valid_access_token)
|
||||
|
||||
scope: Scope = {"type": "http", "user": user}
|
||||
|
||||
# Create dummy async functions for receive and send
|
||||
async def receive() -> Message:
|
||||
return {"type": "http.request"}
|
||||
|
||||
async def send(message: Message) -> None:
|
||||
pass
|
||||
|
||||
# Verify context is empty before middleware
|
||||
assert auth_context_var.get() is None
|
||||
assert get_access_token() is None
|
||||
|
||||
# Run the middleware
|
||||
await middleware(scope, receive, send)
|
||||
|
||||
# Verify the app was called
|
||||
assert app.called
|
||||
assert app.scope == scope
|
||||
assert app.receive == receive
|
||||
assert app.send == send
|
||||
|
||||
# Verify the access token was available during the call
|
||||
assert app.access_token_during_call == valid_access_token
|
||||
|
||||
# Verify context is reset after middleware
|
||||
assert auth_context_var.get() is None
|
||||
assert get_access_token() is None
|
||||
|
||||
async def test_with_no_user(self):
|
||||
"""Test middleware with no user in scope."""
|
||||
app = MockApp()
|
||||
middleware = AuthContextMiddleware(app)
|
||||
|
||||
scope: Scope = {"type": "http"} # No user
|
||||
|
||||
# Create dummy async functions for receive and send
|
||||
async def receive() -> Message:
|
||||
return {"type": "http.request"}
|
||||
|
||||
async def send(message: Message) -> None:
|
||||
pass
|
||||
|
||||
# Verify context is empty before middleware
|
||||
assert auth_context_var.get() is None
|
||||
assert get_access_token() is None
|
||||
|
||||
# Run the middleware
|
||||
await middleware(scope, receive, send)
|
||||
|
||||
# Verify the app was called
|
||||
assert app.called
|
||||
assert app.scope == scope
|
||||
assert app.receive == receive
|
||||
assert app.send == send
|
||||
|
||||
# Verify the access token was not available during the call
|
||||
assert app.access_token_during_call is None
|
||||
|
||||
# Verify context is still empty after middleware
|
||||
assert auth_context_var.get() is None
|
||||
assert get_access_token() is None
|
||||
391
tests/server/auth/middleware/test_bearer_auth.py
Normal file
391
tests/server/auth/middleware/test_bearer_auth.py
Normal file
@@ -0,0 +1,391 @@
|
||||
"""
|
||||
Tests for the BearerAuth middleware components.
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Any, cast
|
||||
|
||||
import pytest
|
||||
from starlette.authentication import AuthCredentials
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import Request
|
||||
from starlette.types import Message, Receive, Scope, Send
|
||||
|
||||
from mcp.server.auth.middleware.bearer_auth import (
|
||||
AuthenticatedUser,
|
||||
BearerAuthBackend,
|
||||
RequireAuthMiddleware,
|
||||
)
|
||||
from mcp.server.auth.provider import (
|
||||
AccessToken,
|
||||
OAuthAuthorizationServerProvider,
|
||||
)
|
||||
|
||||
|
||||
class MockOAuthProvider:
|
||||
"""Mock OAuth provider for testing.
|
||||
|
||||
This is a simplified version that only implements the methods needed for testing
|
||||
the BearerAuthMiddleware components.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.tokens = {} # token -> AccessToken
|
||||
|
||||
def add_token(self, token: str, access_token: AccessToken) -> None:
|
||||
"""Add a token to the provider."""
|
||||
self.tokens[token] = access_token
|
||||
|
||||
async def load_access_token(self, token: str) -> AccessToken | None:
|
||||
"""Load an access token."""
|
||||
return self.tokens.get(token)
|
||||
|
||||
|
||||
def add_token_to_provider(
|
||||
provider: OAuthAuthorizationServerProvider[Any, Any, Any],
|
||||
token: str,
|
||||
access_token: AccessToken,
|
||||
) -> None:
|
||||
"""Helper function to add a token to a provider.
|
||||
|
||||
This is used to work around type checking issues with our mock provider.
|
||||
"""
|
||||
# We know this is actually a MockOAuthProvider
|
||||
mock_provider = cast(MockOAuthProvider, provider)
|
||||
mock_provider.add_token(token, access_token)
|
||||
|
||||
|
||||
class MockApp:
|
||||
"""Mock ASGI app for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self.called = False
|
||||
self.scope: Scope | None = None
|
||||
self.receive: Receive | None = None
|
||||
self.send: Send | None = None
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
self.called = True
|
||||
self.scope = scope
|
||||
self.receive = receive
|
||||
self.send = send
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_oauth_provider() -> OAuthAuthorizationServerProvider[Any, Any, Any]:
|
||||
"""Create a mock OAuth provider."""
|
||||
# Use type casting to satisfy the type checker
|
||||
return cast(OAuthAuthorizationServerProvider[Any, Any, Any], MockOAuthProvider())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def valid_access_token() -> AccessToken:
|
||||
"""Create a valid access token."""
|
||||
return AccessToken(
|
||||
token="valid_token",
|
||||
client_id="test_client",
|
||||
scopes=["read", "write"],
|
||||
expires_at=int(time.time()) + 3600, # 1 hour from now
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expired_access_token() -> AccessToken:
|
||||
"""Create an expired access token."""
|
||||
return AccessToken(
|
||||
token="expired_token",
|
||||
client_id="test_client",
|
||||
scopes=["read"],
|
||||
expires_at=int(time.time()) - 3600, # 1 hour ago
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def no_expiry_access_token() -> AccessToken:
|
||||
"""Create an access token with no expiry."""
|
||||
return AccessToken(
|
||||
token="no_expiry_token",
|
||||
client_id="test_client",
|
||||
scopes=["read", "write"],
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
class TestBearerAuthBackend:
|
||||
"""Tests for the BearerAuthBackend class."""
|
||||
|
||||
async def test_no_auth_header(
|
||||
self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]
|
||||
):
|
||||
"""Test authentication with no Authorization header."""
|
||||
backend = BearerAuthBackend(provider=mock_oauth_provider)
|
||||
request = Request({"type": "http", "headers": []})
|
||||
result = await backend.authenticate(request)
|
||||
assert result is None
|
||||
|
||||
async def test_non_bearer_auth_header(
|
||||
self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]
|
||||
):
|
||||
"""Test authentication with non-Bearer Authorization header."""
|
||||
backend = BearerAuthBackend(provider=mock_oauth_provider)
|
||||
request = Request(
|
||||
{
|
||||
"type": "http",
|
||||
"headers": [(b"authorization", b"Basic dXNlcjpwYXNz")],
|
||||
}
|
||||
)
|
||||
result = await backend.authenticate(request)
|
||||
assert result is None
|
||||
|
||||
async def test_invalid_token(
|
||||
self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]
|
||||
):
|
||||
"""Test authentication with invalid token."""
|
||||
backend = BearerAuthBackend(provider=mock_oauth_provider)
|
||||
request = Request(
|
||||
{
|
||||
"type": "http",
|
||||
"headers": [(b"authorization", b"Bearer invalid_token")],
|
||||
}
|
||||
)
|
||||
result = await backend.authenticate(request)
|
||||
assert result is None
|
||||
|
||||
async def test_expired_token(
|
||||
self,
|
||||
mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any],
|
||||
expired_access_token: AccessToken,
|
||||
):
|
||||
"""Test authentication with expired token."""
|
||||
backend = BearerAuthBackend(provider=mock_oauth_provider)
|
||||
add_token_to_provider(
|
||||
mock_oauth_provider, "expired_token", expired_access_token
|
||||
)
|
||||
request = Request(
|
||||
{
|
||||
"type": "http",
|
||||
"headers": [(b"authorization", b"Bearer expired_token")],
|
||||
}
|
||||
)
|
||||
result = await backend.authenticate(request)
|
||||
assert result is None
|
||||
|
||||
async def test_valid_token(
|
||||
self,
|
||||
mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any],
|
||||
valid_access_token: AccessToken,
|
||||
):
|
||||
"""Test authentication with valid token."""
|
||||
backend = BearerAuthBackend(provider=mock_oauth_provider)
|
||||
add_token_to_provider(mock_oauth_provider, "valid_token", valid_access_token)
|
||||
request = Request(
|
||||
{
|
||||
"type": "http",
|
||||
"headers": [(b"authorization", b"Bearer valid_token")],
|
||||
}
|
||||
)
|
||||
result = await backend.authenticate(request)
|
||||
assert result is not None
|
||||
credentials, user = result
|
||||
assert isinstance(credentials, AuthCredentials)
|
||||
assert isinstance(user, AuthenticatedUser)
|
||||
assert credentials.scopes == ["read", "write"]
|
||||
assert user.display_name == "test_client"
|
||||
assert user.access_token == valid_access_token
|
||||
assert user.scopes == ["read", "write"]
|
||||
|
||||
async def test_token_without_expiry(
|
||||
self,
|
||||
mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any],
|
||||
no_expiry_access_token: AccessToken,
|
||||
):
|
||||
"""Test authentication with token that has no expiry."""
|
||||
backend = BearerAuthBackend(provider=mock_oauth_provider)
|
||||
add_token_to_provider(
|
||||
mock_oauth_provider, "no_expiry_token", no_expiry_access_token
|
||||
)
|
||||
request = Request(
|
||||
{
|
||||
"type": "http",
|
||||
"headers": [(b"authorization", b"Bearer no_expiry_token")],
|
||||
}
|
||||
)
|
||||
result = await backend.authenticate(request)
|
||||
assert result is not None
|
||||
credentials, user = result
|
||||
assert isinstance(credentials, AuthCredentials)
|
||||
assert isinstance(user, AuthenticatedUser)
|
||||
assert credentials.scopes == ["read", "write"]
|
||||
assert user.display_name == "test_client"
|
||||
assert user.access_token == no_expiry_access_token
|
||||
assert user.scopes == ["read", "write"]
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
class TestRequireAuthMiddleware:
|
||||
"""Tests for the RequireAuthMiddleware class."""
|
||||
|
||||
async def test_no_user(self):
|
||||
"""Test middleware with no user in scope."""
|
||||
app = MockApp()
|
||||
middleware = RequireAuthMiddleware(app, required_scopes=["read"])
|
||||
scope: Scope = {"type": "http"}
|
||||
|
||||
# Create dummy async functions for receive and send
|
||||
async def receive() -> Message:
|
||||
return {"type": "http.request"}
|
||||
|
||||
async def send(message: Message) -> None:
|
||||
pass
|
||||
|
||||
with pytest.raises(HTTPException) as excinfo:
|
||||
await middleware(scope, receive, send)
|
||||
|
||||
assert excinfo.value.status_code == 401
|
||||
assert excinfo.value.detail == "Unauthorized"
|
||||
assert not app.called
|
||||
|
||||
async def test_non_authenticated_user(self):
|
||||
"""Test middleware with non-authenticated user in scope."""
|
||||
app = MockApp()
|
||||
middleware = RequireAuthMiddleware(app, required_scopes=["read"])
|
||||
scope: Scope = {"type": "http", "user": object()}
|
||||
|
||||
# Create dummy async functions for receive and send
|
||||
async def receive() -> Message:
|
||||
return {"type": "http.request"}
|
||||
|
||||
async def send(message: Message) -> None:
|
||||
pass
|
||||
|
||||
with pytest.raises(HTTPException) as excinfo:
|
||||
await middleware(scope, receive, send)
|
||||
|
||||
assert excinfo.value.status_code == 401
|
||||
assert excinfo.value.detail == "Unauthorized"
|
||||
assert not app.called
|
||||
|
||||
async def test_missing_required_scope(self, valid_access_token: AccessToken):
|
||||
"""Test middleware with user missing required scope."""
|
||||
app = MockApp()
|
||||
middleware = RequireAuthMiddleware(app, required_scopes=["admin"])
|
||||
|
||||
# Create a user with read/write scopes but not admin
|
||||
user = AuthenticatedUser(valid_access_token)
|
||||
auth = AuthCredentials(["read", "write"])
|
||||
|
||||
scope: Scope = {"type": "http", "user": user, "auth": auth}
|
||||
|
||||
# Create dummy async functions for receive and send
|
||||
async def receive() -> Message:
|
||||
return {"type": "http.request"}
|
||||
|
||||
async def send(message: Message) -> None:
|
||||
pass
|
||||
|
||||
with pytest.raises(HTTPException) as excinfo:
|
||||
await middleware(scope, receive, send)
|
||||
|
||||
assert excinfo.value.status_code == 403
|
||||
assert excinfo.value.detail == "Insufficient scope"
|
||||
assert not app.called
|
||||
|
||||
async def test_no_auth_credentials(self, valid_access_token: AccessToken):
|
||||
"""Test middleware with no auth credentials in scope."""
|
||||
app = MockApp()
|
||||
middleware = RequireAuthMiddleware(app, required_scopes=["read"])
|
||||
|
||||
# Create a user with read/write scopes
|
||||
user = AuthenticatedUser(valid_access_token)
|
||||
|
||||
scope: Scope = {"type": "http", "user": user} # No auth credentials
|
||||
|
||||
# Create dummy async functions for receive and send
|
||||
async def receive() -> Message:
|
||||
return {"type": "http.request"}
|
||||
|
||||
async def send(message: Message) -> None:
|
||||
pass
|
||||
|
||||
with pytest.raises(HTTPException) as excinfo:
|
||||
await middleware(scope, receive, send)
|
||||
|
||||
assert excinfo.value.status_code == 403
|
||||
assert excinfo.value.detail == "Insufficient scope"
|
||||
assert not app.called
|
||||
|
||||
async def test_has_required_scopes(self, valid_access_token: AccessToken):
|
||||
"""Test middleware with user having all required scopes."""
|
||||
app = MockApp()
|
||||
middleware = RequireAuthMiddleware(app, required_scopes=["read"])
|
||||
|
||||
# Create a user with read/write scopes
|
||||
user = AuthenticatedUser(valid_access_token)
|
||||
auth = AuthCredentials(["read", "write"])
|
||||
|
||||
scope: Scope = {"type": "http", "user": user, "auth": auth}
|
||||
|
||||
# Create dummy async functions for receive and send
|
||||
async def receive() -> Message:
|
||||
return {"type": "http.request"}
|
||||
|
||||
async def send(message: Message) -> None:
|
||||
pass
|
||||
|
||||
await middleware(scope, receive, send)
|
||||
|
||||
assert app.called
|
||||
assert app.scope == scope
|
||||
assert app.receive == receive
|
||||
assert app.send == send
|
||||
|
||||
async def test_multiple_required_scopes(self, valid_access_token: AccessToken):
|
||||
"""Test middleware with multiple required scopes."""
|
||||
app = MockApp()
|
||||
middleware = RequireAuthMiddleware(app, required_scopes=["read", "write"])
|
||||
|
||||
# Create a user with read/write scopes
|
||||
user = AuthenticatedUser(valid_access_token)
|
||||
auth = AuthCredentials(["read", "write"])
|
||||
|
||||
scope: Scope = {"type": "http", "user": user, "auth": auth}
|
||||
|
||||
# Create dummy async functions for receive and send
|
||||
async def receive() -> Message:
|
||||
return {"type": "http.request"}
|
||||
|
||||
async def send(message: Message) -> None:
|
||||
pass
|
||||
|
||||
await middleware(scope, receive, send)
|
||||
|
||||
assert app.called
|
||||
assert app.scope == scope
|
||||
assert app.receive == receive
|
||||
assert app.send == send
|
||||
|
||||
async def test_no_required_scopes(self, valid_access_token: AccessToken):
|
||||
"""Test middleware with no required scopes."""
|
||||
app = MockApp()
|
||||
middleware = RequireAuthMiddleware(app, required_scopes=[])
|
||||
|
||||
# Create a user with read/write scopes
|
||||
user = AuthenticatedUser(valid_access_token)
|
||||
auth = AuthCredentials(["read", "write"])
|
||||
|
||||
scope: Scope = {"type": "http", "user": user, "auth": auth}
|
||||
|
||||
# Create dummy async functions for receive and send
|
||||
async def receive() -> Message:
|
||||
return {"type": "http.request"}
|
||||
|
||||
async def send(message: Message) -> None:
|
||||
pass
|
||||
|
||||
await middleware(scope, receive, send)
|
||||
|
||||
assert app.called
|
||||
assert app.scope == scope
|
||||
assert app.receive == receive
|
||||
assert app.send == send
|
||||
294
tests/server/auth/test_error_handling.py
Normal file
294
tests/server/auth/test_error_handling.py
Normal file
@@ -0,0 +1,294 @@
|
||||
"""
|
||||
Tests for OAuth error handling in the auth handlers.
|
||||
"""
|
||||
|
||||
import unittest.mock
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from httpx import ASGITransport
|
||||
from pydantic import AnyHttpUrl
|
||||
from starlette.applications import Starlette
|
||||
|
||||
from mcp.server.auth.provider import (
|
||||
AuthorizeError,
|
||||
RegistrationError,
|
||||
TokenError,
|
||||
)
|
||||
from mcp.server.auth.routes import create_auth_routes
|
||||
from tests.server.fastmcp.auth.test_auth_integration import (
|
||||
MockOAuthProvider,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oauth_provider():
|
||||
"""Return a MockOAuthProvider instance that can be configured to raise errors."""
|
||||
return MockOAuthProvider()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app(oauth_provider):
|
||||
from mcp.server.auth.settings import ClientRegistrationOptions, RevocationOptions
|
||||
|
||||
# Enable client registration
|
||||
client_registration_options = ClientRegistrationOptions(enabled=True)
|
||||
revocation_options = RevocationOptions(enabled=True)
|
||||
|
||||
# Create auth routes
|
||||
auth_routes = create_auth_routes(
|
||||
oauth_provider,
|
||||
issuer_url=AnyHttpUrl("http://localhost"),
|
||||
client_registration_options=client_registration_options,
|
||||
revocation_options=revocation_options,
|
||||
)
|
||||
|
||||
# Create Starlette app with routes directly
|
||||
return Starlette(routes=auth_routes)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(app):
|
||||
transport = ASGITransport(app=app)
|
||||
# Use base_url without a path since routes are directly on the app
|
||||
return httpx.AsyncClient(transport=transport, base_url="http://localhost")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pkce_challenge():
|
||||
"""Create a PKCE challenge with code_verifier and code_challenge."""
|
||||
import base64
|
||||
import hashlib
|
||||
import secrets
|
||||
|
||||
# Generate a code verifier
|
||||
code_verifier = secrets.token_urlsafe(64)[:128]
|
||||
|
||||
# Create code challenge using S256 method
|
||||
code_verifier_bytes = code_verifier.encode("ascii")
|
||||
sha256 = hashlib.sha256(code_verifier_bytes).digest()
|
||||
code_challenge = base64.urlsafe_b64encode(sha256).decode().rstrip("=")
|
||||
|
||||
return {"code_verifier": code_verifier, "code_challenge": code_challenge}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def registered_client(client):
|
||||
"""Create and register a test client."""
|
||||
# Default client metadata
|
||||
client_metadata = {
|
||||
"redirect_uris": ["https://client.example.com/callback"],
|
||||
"token_endpoint_auth_method": "client_secret_post",
|
||||
"grant_types": ["authorization_code", "refresh_token"],
|
||||
"response_types": ["code"],
|
||||
"client_name": "Test Client",
|
||||
}
|
||||
|
||||
response = await client.post("/register", json=client_metadata)
|
||||
assert response.status_code == 201, f"Failed to register client: {response.content}"
|
||||
|
||||
client_info = response.json()
|
||||
return client_info
|
||||
|
||||
|
||||
class TestRegistrationErrorHandling:
|
||||
@pytest.mark.anyio
|
||||
async def test_registration_error_handling(self, client, oauth_provider):
|
||||
# Mock the register_client method to raise a registration error
|
||||
with unittest.mock.patch.object(
|
||||
oauth_provider,
|
||||
"register_client",
|
||||
side_effect=RegistrationError(
|
||||
error="invalid_redirect_uri",
|
||||
error_description="The redirect URI is invalid",
|
||||
),
|
||||
):
|
||||
# Prepare a client registration request
|
||||
client_data = {
|
||||
"redirect_uris": ["https://client.example.com/callback"],
|
||||
"token_endpoint_auth_method": "client_secret_post",
|
||||
"grant_types": ["authorization_code", "refresh_token"],
|
||||
"response_types": ["code"],
|
||||
"client_name": "Test Client",
|
||||
}
|
||||
|
||||
# Send the registration request
|
||||
response = await client.post(
|
||||
"/register",
|
||||
json=client_data,
|
||||
)
|
||||
|
||||
# Verify the response
|
||||
assert response.status_code == 400, response.content
|
||||
data = response.json()
|
||||
assert data["error"] == "invalid_redirect_uri"
|
||||
assert data["error_description"] == "The redirect URI is invalid"
|
||||
|
||||
|
||||
class TestAuthorizeErrorHandling:
|
||||
@pytest.mark.anyio
|
||||
async def test_authorize_error_handling(
|
||||
self, client, oauth_provider, registered_client, pkce_challenge
|
||||
):
|
||||
# Mock the authorize method to raise an authorize error
|
||||
with unittest.mock.patch.object(
|
||||
oauth_provider,
|
||||
"authorize",
|
||||
side_effect=AuthorizeError(
|
||||
error="access_denied", error_description="The user denied the request"
|
||||
),
|
||||
):
|
||||
# Register the client
|
||||
client_id = registered_client["client_id"]
|
||||
redirect_uri = registered_client["redirect_uris"][0]
|
||||
|
||||
# Prepare an authorization request
|
||||
params = {
|
||||
"client_id": client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"response_type": "code",
|
||||
"code_challenge": pkce_challenge["code_challenge"],
|
||||
"code_challenge_method": "S256",
|
||||
"state": "test_state",
|
||||
}
|
||||
|
||||
# Send the authorization request
|
||||
response = await client.get("/authorize", params=params)
|
||||
|
||||
# Verify the response is a redirect with error parameters
|
||||
assert response.status_code == 302
|
||||
redirect_url = response.headers["location"]
|
||||
parsed_url = urlparse(redirect_url)
|
||||
query_params = parse_qs(parsed_url.query)
|
||||
|
||||
assert query_params["error"][0] == "access_denied"
|
||||
assert "error_description" in query_params
|
||||
assert query_params["state"][0] == "test_state"
|
||||
|
||||
|
||||
class TestTokenErrorHandling:
|
||||
@pytest.mark.anyio
|
||||
async def test_token_error_handling_auth_code(
|
||||
self, client, oauth_provider, registered_client, pkce_challenge
|
||||
):
|
||||
# Register the client and get an auth code
|
||||
client_id = registered_client["client_id"]
|
||||
client_secret = registered_client["client_secret"]
|
||||
redirect_uri = registered_client["redirect_uris"][0]
|
||||
|
||||
# First get an authorization code
|
||||
auth_response = await client.get(
|
||||
"/authorize",
|
||||
params={
|
||||
"client_id": client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"response_type": "code",
|
||||
"code_challenge": pkce_challenge["code_challenge"],
|
||||
"code_challenge_method": "S256",
|
||||
"state": "test_state",
|
||||
},
|
||||
)
|
||||
|
||||
redirect_url = auth_response.headers["location"]
|
||||
parsed_url = urlparse(redirect_url)
|
||||
query_params = parse_qs(parsed_url.query)
|
||||
code = query_params["code"][0]
|
||||
|
||||
# Mock the exchange_authorization_code method to raise a token error
|
||||
with unittest.mock.patch.object(
|
||||
oauth_provider,
|
||||
"exchange_authorization_code",
|
||||
side_effect=TokenError(
|
||||
error="invalid_grant",
|
||||
error_description="The authorization code is invalid",
|
||||
),
|
||||
):
|
||||
# Try to exchange the code for tokens
|
||||
token_response = await client.post(
|
||||
"/token",
|
||||
data={
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"redirect_uri": redirect_uri,
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
"code_verifier": pkce_challenge["code_verifier"],
|
||||
},
|
||||
)
|
||||
|
||||
# Verify the response
|
||||
assert token_response.status_code == 400
|
||||
data = token_response.json()
|
||||
assert data["error"] == "invalid_grant"
|
||||
assert data["error_description"] == "The authorization code is invalid"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_token_error_handling_refresh_token(
|
||||
self, client, oauth_provider, registered_client, pkce_challenge
|
||||
):
|
||||
# Register the client and get tokens
|
||||
client_id = registered_client["client_id"]
|
||||
client_secret = registered_client["client_secret"]
|
||||
redirect_uri = registered_client["redirect_uris"][0]
|
||||
|
||||
# First get an authorization code
|
||||
auth_response = await client.get(
|
||||
"/authorize",
|
||||
params={
|
||||
"client_id": client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"response_type": "code",
|
||||
"code_challenge": pkce_challenge["code_challenge"],
|
||||
"code_challenge_method": "S256",
|
||||
"state": "test_state",
|
||||
},
|
||||
)
|
||||
assert auth_response.status_code == 302, auth_response.content
|
||||
|
||||
redirect_url = auth_response.headers["location"]
|
||||
parsed_url = urlparse(redirect_url)
|
||||
query_params = parse_qs(parsed_url.query)
|
||||
code = query_params["code"][0]
|
||||
|
||||
# Exchange the code for tokens
|
||||
token_response = await client.post(
|
||||
"/token",
|
||||
data={
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"redirect_uri": redirect_uri,
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
"code_verifier": pkce_challenge["code_verifier"],
|
||||
},
|
||||
)
|
||||
|
||||
tokens = token_response.json()
|
||||
refresh_token = tokens["refresh_token"]
|
||||
|
||||
# Mock the exchange_refresh_token method to raise a token error
|
||||
with unittest.mock.patch.object(
|
||||
oauth_provider,
|
||||
"exchange_refresh_token",
|
||||
side_effect=TokenError(
|
||||
error="invalid_scope",
|
||||
error_description="The requested scope is invalid",
|
||||
),
|
||||
):
|
||||
# Try to use the refresh token
|
||||
refresh_response = await client.post(
|
||||
"/token",
|
||||
data={
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
},
|
||||
)
|
||||
|
||||
# Verify the response
|
||||
assert refresh_response.status_code == 400
|
||||
data = refresh_response.json()
|
||||
assert data["error"] == "invalid_scope"
|
||||
assert data["error_description"] == "The requested scope is invalid"
|
||||
3
tests/server/fastmcp/auth/__init__.py
Normal file
3
tests/server/fastmcp/auth/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
Tests for the MCP server auth components.
|
||||
"""
|
||||
1267
tests/server/fastmcp/auth/test_auth_integration.py
Normal file
1267
tests/server/fastmcp/auth/test_auth_integration.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user