mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 23:04:25 +01:00
453 lines
16 KiB
Python
453 lines
16 KiB
Python
"""
|
|
Tests for the BearerAuth middleware components.
|
|
"""
|
|
|
|
import time
|
|
from typing import Any, cast
|
|
|
|
import pytest
|
|
from starlette.authentication import AuthCredentials
|
|
from starlette.datastructures import Headers
|
|
from starlette.exceptions import HTTPException
|
|
from starlette.requests import Request
|
|
from starlette.types import Message, Receive, Scope, Send
|
|
|
|
from mcp.server.auth.middleware.bearer_auth import (
|
|
AuthenticatedUser,
|
|
BearerAuthBackend,
|
|
RequireAuthMiddleware,
|
|
)
|
|
from mcp.server.auth.provider import (
|
|
AccessToken,
|
|
OAuthAuthorizationServerProvider,
|
|
)
|
|
|
|
|
|
class MockOAuthProvider:
    """In-memory stand-in for an OAuth provider, used only by these tests.

    It keeps a plain mapping from raw token strings to ``AccessToken``
    objects and implements only the methods the bearer-auth tests need;
    the rest of the provider interface is deliberately absent.
    """

    def __init__(self):
        # Maps the raw bearer token string to its AccessToken record.
        self.tokens = {}

    async def load_access_token(self, token: str) -> AccessToken | None:
        """Return the token stored under ``token``, or ``None`` if unknown."""
        try:
            return self.tokens[token]
        except KeyError:
            return None

    def add_token(self, token: str, access_token: AccessToken) -> None:
        """Register ``access_token`` under the key ``token``."""
        self.tokens[token] = access_token
|
|
|
|
|
|
def add_token_to_provider(
    provider: OAuthAuthorizationServerProvider[Any, Any, Any],
    token: str,
    access_token: AccessToken,
) -> None:
    """Register a token on a provider that is really a ``MockOAuthProvider``.

    The tests pass the mock around typed as the real provider interface, so
    it is cast back to the mock type here before ``add_token`` is called.
    The cast exists purely to satisfy the type checker.
    """
    cast(MockOAuthProvider, provider).add_token(token, access_token)
|
|
|
|
|
|
class MockApp:
    """Minimal ASGI callable that records how it was invoked."""

    def __init__(self):
        # Nothing has been captured until the app is awaited.
        self.scope: Scope | None = None
        self.receive: Receive | None = None
        self.send: Send | None = None
        self.called = False

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Store the three ASGI arguments and flag that the app ran."""
        self.scope, self.receive, self.send = scope, receive, send
        self.called = True
|
|
|
|
|
|
@pytest.fixture
def mock_oauth_provider() -> OAuthAuthorizationServerProvider[Any, Any, Any]:
    """Provide a fresh ``MockOAuthProvider`` typed as the real provider."""
    provider = MockOAuthProvider()
    # The mock does not inherit from the provider type, so cast it to keep
    # the type checker satisfied.
    return cast(OAuthAuthorizationServerProvider[Any, Any, Any], provider)
|
|
|
|
|
|
@pytest.fixture
def valid_access_token() -> AccessToken:
    """Build a read/write access token that expires one hour from now."""
    one_hour_ahead = int(time.time()) + 3600
    return AccessToken(
        token="valid_token",
        client_id="test_client",
        scopes=["read", "write"],
        expires_at=one_hour_ahead,
    )
|
|
|
|
|
|
@pytest.fixture
def expired_access_token() -> AccessToken:
    """Build a read-only access token whose expiry was one hour ago."""
    one_hour_back = int(time.time()) - 3600
    return AccessToken(
        token="expired_token",
        client_id="test_client",
        scopes=["read"],
        expires_at=one_hour_back,
    )
|
|
|
|
|
|
@pytest.fixture
def no_expiry_access_token() -> AccessToken:
    """Build a read/write access token whose ``expires_at`` is ``None``."""
    token = AccessToken(
        token="no_expiry_token",
        client_id="test_client",
        scopes=["read", "write"],
        expires_at=None,
    )
    return token
|
|
|
|
|
|
@pytest.mark.anyio
class TestBearerAuthBackend:
    """Exercise ``BearerAuthBackend.authenticate`` against the mock provider."""

    async def test_no_auth_header(
        self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]
    ):
        """A request with no Authorization header yields ``None``."""
        scope = {"type": "http", "headers": []}
        auth_backend = BearerAuthBackend(provider=mock_oauth_provider)
        outcome = await auth_backend.authenticate(Request(scope))
        assert outcome is None

    async def test_non_bearer_auth_header(
        self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]
    ):
        """An Authorization header using the Basic scheme yields ``None``."""
        scope = {
            "type": "http",
            "headers": [(b"authorization", b"Basic dXNlcjpwYXNz")],
        }
        auth_backend = BearerAuthBackend(provider=mock_oauth_provider)
        outcome = await auth_backend.authenticate(Request(scope))
        assert outcome is None

    async def test_invalid_token(
        self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]
    ):
        """A bearer token the provider does not know yields ``None``."""
        scope = {
            "type": "http",
            "headers": [(b"authorization", b"Bearer invalid_token")],
        }
        auth_backend = BearerAuthBackend(provider=mock_oauth_provider)
        outcome = await auth_backend.authenticate(Request(scope))
        assert outcome is None

    async def test_expired_token(
        self,
        mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any],
        expired_access_token: AccessToken,
    ):
        """A known token whose expiry is in the past yields ``None``."""
        auth_backend = BearerAuthBackend(provider=mock_oauth_provider)
        add_token_to_provider(
            mock_oauth_provider, "expired_token", expired_access_token
        )
        scope = {
            "type": "http",
            "headers": [(b"authorization", b"Bearer expired_token")],
        }
        outcome = await auth_backend.authenticate(Request(scope))
        assert outcome is None

    async def test_valid_token(
        self,
        mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any],
        valid_access_token: AccessToken,
    ):
        """A known, unexpired token yields credentials and a user."""
        auth_backend = BearerAuthBackend(provider=mock_oauth_provider)
        add_token_to_provider(mock_oauth_provider, "valid_token", valid_access_token)
        scope = {
            "type": "http",
            "headers": [(b"authorization", b"Bearer valid_token")],
        }
        outcome = await auth_backend.authenticate(Request(scope))
        assert outcome is not None

        creds, authed_user = outcome
        assert isinstance(creds, AuthCredentials)
        assert isinstance(authed_user, AuthenticatedUser)
        assert creds.scopes == ["read", "write"]
        assert authed_user.display_name == "test_client"
        assert authed_user.access_token == valid_access_token
        assert authed_user.scopes == ["read", "write"]

    async def test_token_without_expiry(
        self,
        mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any],
        no_expiry_access_token: AccessToken,
    ):
        """A known token with ``expires_at=None`` yields credentials and a user."""
        auth_backend = BearerAuthBackend(provider=mock_oauth_provider)
        add_token_to_provider(
            mock_oauth_provider, "no_expiry_token", no_expiry_access_token
        )
        scope = {
            "type": "http",
            "headers": [(b"authorization", b"Bearer no_expiry_token")],
        }
        outcome = await auth_backend.authenticate(Request(scope))
        assert outcome is not None

        creds, authed_user = outcome
        assert isinstance(creds, AuthCredentials)
        assert isinstance(authed_user, AuthenticatedUser)
        assert creds.scopes == ["read", "write"]
        assert authed_user.display_name == "test_client"
        assert authed_user.access_token == no_expiry_access_token
        assert authed_user.scopes == ["read", "write"]

    async def test_lowercase_bearer_prefix(
        self,
        mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any],
        valid_access_token: AccessToken,
    ):
        """A lowercase ``bearer`` scheme prefix is accepted."""
        auth_backend = BearerAuthBackend(provider=mock_oauth_provider)
        add_token_to_provider(mock_oauth_provider, "valid_token", valid_access_token)
        raw_headers = Headers({"Authorization": "bearer valid_token"}).raw
        scope = {"type": "http", "headers": raw_headers}
        outcome = await auth_backend.authenticate(Request(scope))
        assert outcome is not None

        creds, authed_user = outcome
        assert isinstance(creds, AuthCredentials)
        assert isinstance(authed_user, AuthenticatedUser)
        assert creds.scopes == ["read", "write"]
        assert authed_user.display_name == "test_client"
        assert authed_user.access_token == valid_access_token

    async def test_mixed_case_bearer_prefix(
        self,
        mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any],
        valid_access_token: AccessToken,
    ):
        """A mixed-case ``BeArEr`` scheme prefix is accepted."""
        auth_backend = BearerAuthBackend(provider=mock_oauth_provider)
        add_token_to_provider(mock_oauth_provider, "valid_token", valid_access_token)
        raw_headers = Headers({"authorization": "BeArEr valid_token"}).raw
        scope = {"type": "http", "headers": raw_headers}
        outcome = await auth_backend.authenticate(Request(scope))
        assert outcome is not None

        creds, authed_user = outcome
        assert isinstance(creds, AuthCredentials)
        assert isinstance(authed_user, AuthenticatedUser)
        assert creds.scopes == ["read", "write"]
        assert authed_user.display_name == "test_client"
        assert authed_user.access_token == valid_access_token

    async def test_mixed_case_authorization_header(
        self,
        mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any],
        valid_access_token: AccessToken,
    ):
        """A mixed-case ``AuThOrIzAtIoN`` header name is accepted."""
        auth_backend = BearerAuthBackend(provider=mock_oauth_provider)
        add_token_to_provider(mock_oauth_provider, "valid_token", valid_access_token)
        raw_headers = Headers({"AuThOrIzAtIoN": "BeArEr valid_token"}).raw
        scope = {"type": "http", "headers": raw_headers}
        outcome = await auth_backend.authenticate(Request(scope))
        assert outcome is not None

        creds, authed_user = outcome
        assert isinstance(creds, AuthCredentials)
        assert isinstance(authed_user, AuthenticatedUser)
        assert creds.scopes == ["read", "write"]
        assert authed_user.display_name == "test_client"
        assert authed_user.access_token == valid_access_token
|
|
|
|
|
|
@pytest.mark.anyio
class TestRequireAuthMiddleware:
    """Exercise ``RequireAuthMiddleware`` wrapped around a recording mock app."""

    async def test_no_user(self):
        """A scope without a ``user`` entry raises 401 and skips the app."""
        downstream = MockApp()
        guard = RequireAuthMiddleware(downstream, required_scopes=["read"])
        asgi_scope: Scope = {"type": "http"}

        # Placeholder ASGI receive/send callables.
        async def fake_receive() -> Message:
            return {"type": "http.request"}

        async def fake_send(message: Message) -> None:
            return None

        with pytest.raises(HTTPException) as caught:
            await guard(asgi_scope, fake_receive, fake_send)

        error = caught.value
        assert error.status_code == 401
        assert error.detail == "Unauthorized"
        assert not downstream.called

    async def test_non_authenticated_user(self):
        """A ``user`` that is a bare ``object()`` raises 401 and skips the app."""
        downstream = MockApp()
        guard = RequireAuthMiddleware(downstream, required_scopes=["read"])
        asgi_scope: Scope = {"type": "http", "user": object()}

        # Placeholder ASGI receive/send callables.
        async def fake_receive() -> Message:
            return {"type": "http.request"}

        async def fake_send(message: Message) -> None:
            return None

        with pytest.raises(HTTPException) as caught:
            await guard(asgi_scope, fake_receive, fake_send)

        error = caught.value
        assert error.status_code == 401
        assert error.detail == "Unauthorized"
        assert not downstream.called

    async def test_missing_required_scope(self, valid_access_token: AccessToken):
        """Credentials lacking the required ``admin`` scope raise 403."""
        downstream = MockApp()
        guard = RequireAuthMiddleware(downstream, required_scopes=["admin"])

        # The credentials carry read/write only, so "admin" is missing.
        authed_user = AuthenticatedUser(valid_access_token)
        creds = AuthCredentials(["read", "write"])
        asgi_scope: Scope = {"type": "http", "user": authed_user, "auth": creds}

        # Placeholder ASGI receive/send callables.
        async def fake_receive() -> Message:
            return {"type": "http.request"}

        async def fake_send(message: Message) -> None:
            return None

        with pytest.raises(HTTPException) as caught:
            await guard(asgi_scope, fake_receive, fake_send)

        error = caught.value
        assert error.status_code == 403
        assert error.detail == "Insufficient scope"
        assert not downstream.called

    async def test_no_auth_credentials(self, valid_access_token: AccessToken):
        """An authenticated user with no ``auth`` entry in scope raises 403."""
        downstream = MockApp()
        guard = RequireAuthMiddleware(downstream, required_scopes=["read"])

        # Only the user is placed in the scope; "auth" is deliberately absent.
        authed_user = AuthenticatedUser(valid_access_token)
        asgi_scope: Scope = {"type": "http", "user": authed_user}

        # Placeholder ASGI receive/send callables.
        async def fake_receive() -> Message:
            return {"type": "http.request"}

        async def fake_send(message: Message) -> None:
            return None

        with pytest.raises(HTTPException) as caught:
            await guard(asgi_scope, fake_receive, fake_send)

        error = caught.value
        assert error.status_code == 403
        assert error.detail == "Insufficient scope"
        assert not downstream.called

    async def test_has_required_scopes(self, valid_access_token: AccessToken):
        """With the single required scope present, the wrapped app is called."""
        downstream = MockApp()
        guard = RequireAuthMiddleware(downstream, required_scopes=["read"])

        authed_user = AuthenticatedUser(valid_access_token)
        creds = AuthCredentials(["read", "write"])
        asgi_scope: Scope = {"type": "http", "user": authed_user, "auth": creds}

        # Placeholder ASGI callables; compared against what the app received.
        async def fake_receive() -> Message:
            return {"type": "http.request"}

        async def fake_send(message: Message) -> None:
            return None

        await guard(asgi_scope, fake_receive, fake_send)

        assert downstream.called
        assert downstream.scope == asgi_scope
        assert downstream.receive == fake_receive
        assert downstream.send == fake_send

    async def test_multiple_required_scopes(self, valid_access_token: AccessToken):
        """With both required scopes present, the wrapped app is called."""
        downstream = MockApp()
        guard = RequireAuthMiddleware(downstream, required_scopes=["read", "write"])

        authed_user = AuthenticatedUser(valid_access_token)
        creds = AuthCredentials(["read", "write"])
        asgi_scope: Scope = {"type": "http", "user": authed_user, "auth": creds}

        # Placeholder ASGI callables; compared against what the app received.
        async def fake_receive() -> Message:
            return {"type": "http.request"}

        async def fake_send(message: Message) -> None:
            return None

        await guard(asgi_scope, fake_receive, fake_send)

        assert downstream.called
        assert downstream.scope == asgi_scope
        assert downstream.receive == fake_receive
        assert downstream.send == fake_send

    async def test_no_required_scopes(self, valid_access_token: AccessToken):
        """With an empty required-scope list, the wrapped app is called."""
        downstream = MockApp()
        guard = RequireAuthMiddleware(downstream, required_scopes=[])

        authed_user = AuthenticatedUser(valid_access_token)
        creds = AuthCredentials(["read", "write"])
        asgi_scope: Scope = {"type": "http", "user": authed_user, "auth": creds}

        # Placeholder ASGI callables; compared against what the app received.
        async def fake_receive() -> Message:
            return {"type": "http.request"}

        async def fake_send(message: Message) -> None:
            return None

        await guard(asgi_scope, fake_receive, fake_send)

        assert downstream.called
        assert downstream.scope == asgi_scope
        assert downstream.receive == fake_receive
        assert downstream.send == fake_send
|