Add support for serverside oauth (#255)

Co-authored-by: David Soria Parra <davidsp@anthropic.com>
Co-authored-by: Basil Hosmer <basil@anthropic.com>
Co-authored-by: ihrpr <inna@anthropic.com>
This commit is contained in:
Peter Raboud
2025-05-01 11:42:59 -07:00
committed by GitHub
parent 82bd8bc1d9
commit 2210c1be18
31 changed files with 4120 additions and 22 deletions

View File

@@ -0,0 +1,122 @@
"""
Tests for the AuthContext middleware components.
"""
import time
import pytest
from starlette.types import Message, Receive, Scope, Send
from mcp.server.auth.middleware.auth_context import (
AuthContextMiddleware,
auth_context_var,
get_access_token,
)
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser
from mcp.server.auth.provider import AccessToken
class MockApp:
"""Mock ASGI app for testing."""
def __init__(self):
self.called = False
self.scope: Scope | None = None
self.receive: Receive | None = None
self.send: Send | None = None
self.access_token_during_call: AccessToken | None = None
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
self.called = True
self.scope = scope
self.receive = receive
self.send = send
# Check the context during the call
self.access_token_during_call = get_access_token()
@pytest.fixture
def valid_access_token() -> AccessToken:
"""Create a valid access token."""
return AccessToken(
token="valid_token",
client_id="test_client",
scopes=["read", "write"],
expires_at=int(time.time()) + 3600, # 1 hour from now
)
@pytest.mark.anyio
class TestAuthContextMiddleware:
"""Tests for the AuthContextMiddleware class."""
async def test_with_authenticated_user(self, valid_access_token: AccessToken):
"""Test middleware with an authenticated user in scope."""
app = MockApp()
middleware = AuthContextMiddleware(app)
# Create an authenticated user
user = AuthenticatedUser(valid_access_token)
scope: Scope = {"type": "http", "user": user}
# Create dummy async functions for receive and send
async def receive() -> Message:
return {"type": "http.request"}
async def send(message: Message) -> None:
pass
# Verify context is empty before middleware
assert auth_context_var.get() is None
assert get_access_token() is None
# Run the middleware
await middleware(scope, receive, send)
# Verify the app was called
assert app.called
assert app.scope == scope
assert app.receive == receive
assert app.send == send
# Verify the access token was available during the call
assert app.access_token_during_call == valid_access_token
# Verify context is reset after middleware
assert auth_context_var.get() is None
assert get_access_token() is None
async def test_with_no_user(self):
"""Test middleware with no user in scope."""
app = MockApp()
middleware = AuthContextMiddleware(app)
scope: Scope = {"type": "http"} # No user
# Create dummy async functions for receive and send
async def receive() -> Message:
return {"type": "http.request"}
async def send(message: Message) -> None:
pass
# Verify context is empty before middleware
assert auth_context_var.get() is None
assert get_access_token() is None
# Run the middleware
await middleware(scope, receive, send)
# Verify the app was called
assert app.called
assert app.scope == scope
assert app.receive == receive
assert app.send == send
# Verify the access token was not available during the call
assert app.access_token_during_call is None
# Verify context is still empty after middleware
assert auth_context_var.get() is None
assert get_access_token() is None

View File

@@ -0,0 +1,391 @@
"""
Tests for the BearerAuth middleware components.
"""
import time
from typing import Any, cast
import pytest
from starlette.authentication import AuthCredentials
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.types import Message, Receive, Scope, Send
from mcp.server.auth.middleware.bearer_auth import (
AuthenticatedUser,
BearerAuthBackend,
RequireAuthMiddleware,
)
from mcp.server.auth.provider import (
AccessToken,
OAuthAuthorizationServerProvider,
)
class MockOAuthProvider:
"""Mock OAuth provider for testing.
This is a simplified version that only implements the methods needed for testing
the BearerAuthMiddleware components.
"""
def __init__(self):
self.tokens = {} # token -> AccessToken
def add_token(self, token: str, access_token: AccessToken) -> None:
"""Add a token to the provider."""
self.tokens[token] = access_token
async def load_access_token(self, token: str) -> AccessToken | None:
"""Load an access token."""
return self.tokens.get(token)
def add_token_to_provider(
provider: OAuthAuthorizationServerProvider[Any, Any, Any],
token: str,
access_token: AccessToken,
) -> None:
"""Helper function to add a token to a provider.
This is used to work around type checking issues with our mock provider.
"""
# We know this is actually a MockOAuthProvider
mock_provider = cast(MockOAuthProvider, provider)
mock_provider.add_token(token, access_token)
class MockApp:
"""Mock ASGI app for testing."""
def __init__(self):
self.called = False
self.scope: Scope | None = None
self.receive: Receive | None = None
self.send: Send | None = None
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
self.called = True
self.scope = scope
self.receive = receive
self.send = send
@pytest.fixture
def mock_oauth_provider() -> OAuthAuthorizationServerProvider[Any, Any, Any]:
"""Create a mock OAuth provider."""
# Use type casting to satisfy the type checker
return cast(OAuthAuthorizationServerProvider[Any, Any, Any], MockOAuthProvider())
@pytest.fixture
def valid_access_token() -> AccessToken:
"""Create a valid access token."""
return AccessToken(
token="valid_token",
client_id="test_client",
scopes=["read", "write"],
expires_at=int(time.time()) + 3600, # 1 hour from now
)
@pytest.fixture
def expired_access_token() -> AccessToken:
"""Create an expired access token."""
return AccessToken(
token="expired_token",
client_id="test_client",
scopes=["read"],
expires_at=int(time.time()) - 3600, # 1 hour ago
)
@pytest.fixture
def no_expiry_access_token() -> AccessToken:
"""Create an access token with no expiry."""
return AccessToken(
token="no_expiry_token",
client_id="test_client",
scopes=["read", "write"],
expires_at=None,
)
@pytest.mark.anyio
class TestBearerAuthBackend:
"""Tests for the BearerAuthBackend class."""
async def test_no_auth_header(
self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]
):
"""Test authentication with no Authorization header."""
backend = BearerAuthBackend(provider=mock_oauth_provider)
request = Request({"type": "http", "headers": []})
result = await backend.authenticate(request)
assert result is None
async def test_non_bearer_auth_header(
self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]
):
"""Test authentication with non-Bearer Authorization header."""
backend = BearerAuthBackend(provider=mock_oauth_provider)
request = Request(
{
"type": "http",
"headers": [(b"authorization", b"Basic dXNlcjpwYXNz")],
}
)
result = await backend.authenticate(request)
assert result is None
async def test_invalid_token(
self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]
):
"""Test authentication with invalid token."""
backend = BearerAuthBackend(provider=mock_oauth_provider)
request = Request(
{
"type": "http",
"headers": [(b"authorization", b"Bearer invalid_token")],
}
)
result = await backend.authenticate(request)
assert result is None
async def test_expired_token(
self,
mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any],
expired_access_token: AccessToken,
):
"""Test authentication with expired token."""
backend = BearerAuthBackend(provider=mock_oauth_provider)
add_token_to_provider(
mock_oauth_provider, "expired_token", expired_access_token
)
request = Request(
{
"type": "http",
"headers": [(b"authorization", b"Bearer expired_token")],
}
)
result = await backend.authenticate(request)
assert result is None
async def test_valid_token(
self,
mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any],
valid_access_token: AccessToken,
):
"""Test authentication with valid token."""
backend = BearerAuthBackend(provider=mock_oauth_provider)
add_token_to_provider(mock_oauth_provider, "valid_token", valid_access_token)
request = Request(
{
"type": "http",
"headers": [(b"authorization", b"Bearer valid_token")],
}
)
result = await backend.authenticate(request)
assert result is not None
credentials, user = result
assert isinstance(credentials, AuthCredentials)
assert isinstance(user, AuthenticatedUser)
assert credentials.scopes == ["read", "write"]
assert user.display_name == "test_client"
assert user.access_token == valid_access_token
assert user.scopes == ["read", "write"]
async def test_token_without_expiry(
self,
mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any],
no_expiry_access_token: AccessToken,
):
"""Test authentication with token that has no expiry."""
backend = BearerAuthBackend(provider=mock_oauth_provider)
add_token_to_provider(
mock_oauth_provider, "no_expiry_token", no_expiry_access_token
)
request = Request(
{
"type": "http",
"headers": [(b"authorization", b"Bearer no_expiry_token")],
}
)
result = await backend.authenticate(request)
assert result is not None
credentials, user = result
assert isinstance(credentials, AuthCredentials)
assert isinstance(user, AuthenticatedUser)
assert credentials.scopes == ["read", "write"]
assert user.display_name == "test_client"
assert user.access_token == no_expiry_access_token
assert user.scopes == ["read", "write"]
@pytest.mark.anyio
class TestRequireAuthMiddleware:
"""Tests for the RequireAuthMiddleware class."""
async def test_no_user(self):
"""Test middleware with no user in scope."""
app = MockApp()
middleware = RequireAuthMiddleware(app, required_scopes=["read"])
scope: Scope = {"type": "http"}
# Create dummy async functions for receive and send
async def receive() -> Message:
return {"type": "http.request"}
async def send(message: Message) -> None:
pass
with pytest.raises(HTTPException) as excinfo:
await middleware(scope, receive, send)
assert excinfo.value.status_code == 401
assert excinfo.value.detail == "Unauthorized"
assert not app.called
async def test_non_authenticated_user(self):
"""Test middleware with non-authenticated user in scope."""
app = MockApp()
middleware = RequireAuthMiddleware(app, required_scopes=["read"])
scope: Scope = {"type": "http", "user": object()}
# Create dummy async functions for receive and send
async def receive() -> Message:
return {"type": "http.request"}
async def send(message: Message) -> None:
pass
with pytest.raises(HTTPException) as excinfo:
await middleware(scope, receive, send)
assert excinfo.value.status_code == 401
assert excinfo.value.detail == "Unauthorized"
assert not app.called
async def test_missing_required_scope(self, valid_access_token: AccessToken):
"""Test middleware with user missing required scope."""
app = MockApp()
middleware = RequireAuthMiddleware(app, required_scopes=["admin"])
# Create a user with read/write scopes but not admin
user = AuthenticatedUser(valid_access_token)
auth = AuthCredentials(["read", "write"])
scope: Scope = {"type": "http", "user": user, "auth": auth}
# Create dummy async functions for receive and send
async def receive() -> Message:
return {"type": "http.request"}
async def send(message: Message) -> None:
pass
with pytest.raises(HTTPException) as excinfo:
await middleware(scope, receive, send)
assert excinfo.value.status_code == 403
assert excinfo.value.detail == "Insufficient scope"
assert not app.called
async def test_no_auth_credentials(self, valid_access_token: AccessToken):
"""Test middleware with no auth credentials in scope."""
app = MockApp()
middleware = RequireAuthMiddleware(app, required_scopes=["read"])
# Create a user with read/write scopes
user = AuthenticatedUser(valid_access_token)
scope: Scope = {"type": "http", "user": user} # No auth credentials
# Create dummy async functions for receive and send
async def receive() -> Message:
return {"type": "http.request"}
async def send(message: Message) -> None:
pass
with pytest.raises(HTTPException) as excinfo:
await middleware(scope, receive, send)
assert excinfo.value.status_code == 403
assert excinfo.value.detail == "Insufficient scope"
assert not app.called
async def test_has_required_scopes(self, valid_access_token: AccessToken):
"""Test middleware with user having all required scopes."""
app = MockApp()
middleware = RequireAuthMiddleware(app, required_scopes=["read"])
# Create a user with read/write scopes
user = AuthenticatedUser(valid_access_token)
auth = AuthCredentials(["read", "write"])
scope: Scope = {"type": "http", "user": user, "auth": auth}
# Create dummy async functions for receive and send
async def receive() -> Message:
return {"type": "http.request"}
async def send(message: Message) -> None:
pass
await middleware(scope, receive, send)
assert app.called
assert app.scope == scope
assert app.receive == receive
assert app.send == send
async def test_multiple_required_scopes(self, valid_access_token: AccessToken):
"""Test middleware with multiple required scopes."""
app = MockApp()
middleware = RequireAuthMiddleware(app, required_scopes=["read", "write"])
# Create a user with read/write scopes
user = AuthenticatedUser(valid_access_token)
auth = AuthCredentials(["read", "write"])
scope: Scope = {"type": "http", "user": user, "auth": auth}
# Create dummy async functions for receive and send
async def receive() -> Message:
return {"type": "http.request"}
async def send(message: Message) -> None:
pass
await middleware(scope, receive, send)
assert app.called
assert app.scope == scope
assert app.receive == receive
assert app.send == send
async def test_no_required_scopes(self, valid_access_token: AccessToken):
"""Test middleware with no required scopes."""
app = MockApp()
middleware = RequireAuthMiddleware(app, required_scopes=[])
# Create a user with read/write scopes
user = AuthenticatedUser(valid_access_token)
auth = AuthCredentials(["read", "write"])
scope: Scope = {"type": "http", "user": user, "auth": auth}
# Create dummy async functions for receive and send
async def receive() -> Message:
return {"type": "http.request"}
async def send(message: Message) -> None:
pass
await middleware(scope, receive, send)
assert app.called
assert app.scope == scope
assert app.receive == receive
assert app.send == send

View File

@@ -0,0 +1,294 @@
"""
Tests for OAuth error handling in the auth handlers.
"""
import unittest.mock
from urllib.parse import parse_qs, urlparse
import httpx
import pytest
from httpx import ASGITransport
from pydantic import AnyHttpUrl
from starlette.applications import Starlette
from mcp.server.auth.provider import (
AuthorizeError,
RegistrationError,
TokenError,
)
from mcp.server.auth.routes import create_auth_routes
from tests.server.fastmcp.auth.test_auth_integration import (
MockOAuthProvider,
)
@pytest.fixture
def oauth_provider():
"""Return a MockOAuthProvider instance that can be configured to raise errors."""
return MockOAuthProvider()
@pytest.fixture
def app(oauth_provider):
from mcp.server.auth.settings import ClientRegistrationOptions, RevocationOptions
# Enable client registration
client_registration_options = ClientRegistrationOptions(enabled=True)
revocation_options = RevocationOptions(enabled=True)
# Create auth routes
auth_routes = create_auth_routes(
oauth_provider,
issuer_url=AnyHttpUrl("http://localhost"),
client_registration_options=client_registration_options,
revocation_options=revocation_options,
)
# Create Starlette app with routes directly
return Starlette(routes=auth_routes)
@pytest.fixture
def client(app):
transport = ASGITransport(app=app)
# Use base_url without a path since routes are directly on the app
return httpx.AsyncClient(transport=transport, base_url="http://localhost")
@pytest.fixture
def pkce_challenge():
"""Create a PKCE challenge with code_verifier and code_challenge."""
import base64
import hashlib
import secrets
# Generate a code verifier
code_verifier = secrets.token_urlsafe(64)[:128]
# Create code challenge using S256 method
code_verifier_bytes = code_verifier.encode("ascii")
sha256 = hashlib.sha256(code_verifier_bytes).digest()
code_challenge = base64.urlsafe_b64encode(sha256).decode().rstrip("=")
return {"code_verifier": code_verifier, "code_challenge": code_challenge}
@pytest.fixture
async def registered_client(client):
"""Create and register a test client."""
# Default client metadata
client_metadata = {
"redirect_uris": ["https://client.example.com/callback"],
"token_endpoint_auth_method": "client_secret_post",
"grant_types": ["authorization_code", "refresh_token"],
"response_types": ["code"],
"client_name": "Test Client",
}
response = await client.post("/register", json=client_metadata)
assert response.status_code == 201, f"Failed to register client: {response.content}"
client_info = response.json()
return client_info
class TestRegistrationErrorHandling:
@pytest.mark.anyio
async def test_registration_error_handling(self, client, oauth_provider):
# Mock the register_client method to raise a registration error
with unittest.mock.patch.object(
oauth_provider,
"register_client",
side_effect=RegistrationError(
error="invalid_redirect_uri",
error_description="The redirect URI is invalid",
),
):
# Prepare a client registration request
client_data = {
"redirect_uris": ["https://client.example.com/callback"],
"token_endpoint_auth_method": "client_secret_post",
"grant_types": ["authorization_code", "refresh_token"],
"response_types": ["code"],
"client_name": "Test Client",
}
# Send the registration request
response = await client.post(
"/register",
json=client_data,
)
# Verify the response
assert response.status_code == 400, response.content
data = response.json()
assert data["error"] == "invalid_redirect_uri"
assert data["error_description"] == "The redirect URI is invalid"
class TestAuthorizeErrorHandling:
@pytest.mark.anyio
async def test_authorize_error_handling(
self, client, oauth_provider, registered_client, pkce_challenge
):
# Mock the authorize method to raise an authorize error
with unittest.mock.patch.object(
oauth_provider,
"authorize",
side_effect=AuthorizeError(
error="access_denied", error_description="The user denied the request"
),
):
# Register the client
client_id = registered_client["client_id"]
redirect_uri = registered_client["redirect_uris"][0]
# Prepare an authorization request
params = {
"client_id": client_id,
"redirect_uri": redirect_uri,
"response_type": "code",
"code_challenge": pkce_challenge["code_challenge"],
"code_challenge_method": "S256",
"state": "test_state",
}
# Send the authorization request
response = await client.get("/authorize", params=params)
# Verify the response is a redirect with error parameters
assert response.status_code == 302
redirect_url = response.headers["location"]
parsed_url = urlparse(redirect_url)
query_params = parse_qs(parsed_url.query)
assert query_params["error"][0] == "access_denied"
assert "error_description" in query_params
assert query_params["state"][0] == "test_state"
class TestTokenErrorHandling:
@pytest.mark.anyio
async def test_token_error_handling_auth_code(
self, client, oauth_provider, registered_client, pkce_challenge
):
# Register the client and get an auth code
client_id = registered_client["client_id"]
client_secret = registered_client["client_secret"]
redirect_uri = registered_client["redirect_uris"][0]
# First get an authorization code
auth_response = await client.get(
"/authorize",
params={
"client_id": client_id,
"redirect_uri": redirect_uri,
"response_type": "code",
"code_challenge": pkce_challenge["code_challenge"],
"code_challenge_method": "S256",
"state": "test_state",
},
)
redirect_url = auth_response.headers["location"]
parsed_url = urlparse(redirect_url)
query_params = parse_qs(parsed_url.query)
code = query_params["code"][0]
# Mock the exchange_authorization_code method to raise a token error
with unittest.mock.patch.object(
oauth_provider,
"exchange_authorization_code",
side_effect=TokenError(
error="invalid_grant",
error_description="The authorization code is invalid",
),
):
# Try to exchange the code for tokens
token_response = await client.post(
"/token",
data={
"grant_type": "authorization_code",
"code": code,
"redirect_uri": redirect_uri,
"client_id": client_id,
"client_secret": client_secret,
"code_verifier": pkce_challenge["code_verifier"],
},
)
# Verify the response
assert token_response.status_code == 400
data = token_response.json()
assert data["error"] == "invalid_grant"
assert data["error_description"] == "The authorization code is invalid"
@pytest.mark.anyio
async def test_token_error_handling_refresh_token(
self, client, oauth_provider, registered_client, pkce_challenge
):
# Register the client and get tokens
client_id = registered_client["client_id"]
client_secret = registered_client["client_secret"]
redirect_uri = registered_client["redirect_uris"][0]
# First get an authorization code
auth_response = await client.get(
"/authorize",
params={
"client_id": client_id,
"redirect_uri": redirect_uri,
"response_type": "code",
"code_challenge": pkce_challenge["code_challenge"],
"code_challenge_method": "S256",
"state": "test_state",
},
)
assert auth_response.status_code == 302, auth_response.content
redirect_url = auth_response.headers["location"]
parsed_url = urlparse(redirect_url)
query_params = parse_qs(parsed_url.query)
code = query_params["code"][0]
# Exchange the code for tokens
token_response = await client.post(
"/token",
data={
"grant_type": "authorization_code",
"code": code,
"redirect_uri": redirect_uri,
"client_id": client_id,
"client_secret": client_secret,
"code_verifier": pkce_challenge["code_verifier"],
},
)
tokens = token_response.json()
refresh_token = tokens["refresh_token"]
# Mock the exchange_refresh_token method to raise a token error
with unittest.mock.patch.object(
oauth_provider,
"exchange_refresh_token",
side_effect=TokenError(
error="invalid_scope",
error_description="The requested scope is invalid",
),
):
# Try to use the refresh token
refresh_response = await client.post(
"/token",
data={
"grant_type": "refresh_token",
"refresh_token": refresh_token,
"client_id": client_id,
"client_secret": client_secret,
},
)
# Verify the response
assert refresh_response.status_code == 400
data = refresh_response.json()
assert data["error"] == "invalid_scope"
assert data["error_description"] == "The requested scope is invalid"

View File

@@ -0,0 +1,3 @@
"""
Tests for the MCP server auth components.
"""

File diff suppressed because it is too large Load Diff