mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
Fix the issue of get Authorization header fails during bearer auth (#637)
Co-authored-by: yangben <yangben@zhihu.com>
This commit is contained in:
@@ -34,8 +34,15 @@ class BearerAuthBackend(AuthenticationBackend):
|
||||
self.provider = provider
|
||||
|
||||
async def authenticate(self, conn: HTTPConnection):
|
||||
auth_header = conn.headers.get("Authorization")
|
||||
if not auth_header or not auth_header.startswith("Bearer "):
|
||||
auth_header = next(
|
||||
(
|
||||
conn.headers.get(key)
|
||||
for key in conn.headers
|
||||
if key.lower() == "authorization"
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not auth_header or not auth_header.lower().startswith("bearer "):
|
||||
return None
|
||||
|
||||
token = auth_header[7:] # Remove "Bearer " prefix
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Any, cast
|
||||
|
||||
import pytest
|
||||
from starlette.authentication import AuthCredentials
|
||||
from starlette.datastructures import Headers
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import Request
|
||||
from starlette.types import Message, Receive, Scope, Send
|
||||
@@ -221,6 +222,66 @@ class TestBearerAuthBackend:
|
||||
assert user.access_token == no_expiry_access_token
|
||||
assert user.scopes == ["read", "write"]
|
||||
|
||||
async def test_lowercase_bearer_prefix(
|
||||
self,
|
||||
mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any],
|
||||
valid_access_token: AccessToken,
|
||||
):
|
||||
"""Test with lowercase 'bearer' prefix in Authorization header"""
|
||||
backend = BearerAuthBackend(provider=mock_oauth_provider)
|
||||
add_token_to_provider(mock_oauth_provider, "valid_token", valid_access_token)
|
||||
headers = Headers({"Authorization": "bearer valid_token"})
|
||||
scope = {"type": "http", "headers": headers.raw}
|
||||
request = Request(scope)
|
||||
result = await backend.authenticate(request)
|
||||
assert result is not None
|
||||
credentials, user = result
|
||||
assert isinstance(credentials, AuthCredentials)
|
||||
assert isinstance(user, AuthenticatedUser)
|
||||
assert credentials.scopes == ["read", "write"]
|
||||
assert user.display_name == "test_client"
|
||||
assert user.access_token == valid_access_token
|
||||
|
||||
async def test_mixed_case_bearer_prefix(
|
||||
self,
|
||||
mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any],
|
||||
valid_access_token: AccessToken,
|
||||
):
|
||||
"""Test with mixed 'BeArEr' prefix in Authorization header"""
|
||||
backend = BearerAuthBackend(provider=mock_oauth_provider)
|
||||
add_token_to_provider(mock_oauth_provider, "valid_token", valid_access_token)
|
||||
headers = Headers({"authorization": "BeArEr valid_token"})
|
||||
scope = {"type": "http", "headers": headers.raw}
|
||||
request = Request(scope)
|
||||
result = await backend.authenticate(request)
|
||||
assert result is not None
|
||||
credentials, user = result
|
||||
assert isinstance(credentials, AuthCredentials)
|
||||
assert isinstance(user, AuthenticatedUser)
|
||||
assert credentials.scopes == ["read", "write"]
|
||||
assert user.display_name == "test_client"
|
||||
assert user.access_token == valid_access_token
|
||||
|
||||
async def test_mixed_case_authorization_header(
|
||||
self,
|
||||
mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any],
|
||||
valid_access_token: AccessToken,
|
||||
):
|
||||
"""Test authentication with mixed 'Authorization' header."""
|
||||
backend = BearerAuthBackend(provider=mock_oauth_provider)
|
||||
add_token_to_provider(mock_oauth_provider, "valid_token", valid_access_token)
|
||||
headers = Headers({"AuThOrIzAtIoN": "BeArEr valid_token"})
|
||||
scope = {"type": "http", "headers": headers.raw}
|
||||
request = Request(scope)
|
||||
result = await backend.authenticate(request)
|
||||
assert result is not None
|
||||
credentials, user = result
|
||||
assert isinstance(credentials, AuthCredentials)
|
||||
assert isinstance(user, AuthenticatedUser)
|
||||
assert credentials.scopes == ["read", "write"]
|
||||
assert user.display_name == "test_client"
|
||||
assert user.access_token == valid_access_token
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
class TestRequireAuthMiddleware:
|
||||
|
||||
Reference in New Issue
Block a user