Fix the issue of get Authorization header fails during bearer auth (#637)

Co-authored-by: yangben <yangben@zhihu.com>
This commit is contained in:
yabea
2025-05-08 00:42:02 +08:00
committed by GitHub
parent 9d99aee014
commit a1307abded
2 changed files with 70 additions and 2 deletions

View File

@@ -34,8 +34,15 @@ class BearerAuthBackend(AuthenticationBackend):
self.provider = provider
async def authenticate(self, conn: HTTPConnection):
auth_header = conn.headers.get("Authorization")
if not auth_header or not auth_header.startswith("Bearer "):
auth_header = next(
(
conn.headers.get(key)
for key in conn.headers
if key.lower() == "authorization"
),
None,
)
if not auth_header or not auth_header.lower().startswith("bearer "):
return None
token = auth_header[7:] # Remove "Bearer " prefix

View File

@@ -7,6 +7,7 @@ from typing import Any, cast
import pytest
from starlette.authentication import AuthCredentials
from starlette.datastructures import Headers
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.types import Message, Receive, Scope, Send
@@ -221,6 +222,66 @@ class TestBearerAuthBackend:
assert user.access_token == no_expiry_access_token
assert user.scopes == ["read", "write"]
async def test_lowercase_bearer_prefix(
self,
mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any],
valid_access_token: AccessToken,
):
"""Test with lowercase 'bearer' prefix in Authorization header"""
backend = BearerAuthBackend(provider=mock_oauth_provider)
add_token_to_provider(mock_oauth_provider, "valid_token", valid_access_token)
headers = Headers({"Authorization": "bearer valid_token"})
scope = {"type": "http", "headers": headers.raw}
request = Request(scope)
result = await backend.authenticate(request)
assert result is not None
credentials, user = result
assert isinstance(credentials, AuthCredentials)
assert isinstance(user, AuthenticatedUser)
assert credentials.scopes == ["read", "write"]
assert user.display_name == "test_client"
assert user.access_token == valid_access_token
async def test_mixed_case_bearer_prefix(
self,
mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any],
valid_access_token: AccessToken,
):
"""Test with mixed 'BeArEr' prefix in Authorization header"""
backend = BearerAuthBackend(provider=mock_oauth_provider)
add_token_to_provider(mock_oauth_provider, "valid_token", valid_access_token)
headers = Headers({"authorization": "BeArEr valid_token"})
scope = {"type": "http", "headers": headers.raw}
request = Request(scope)
result = await backend.authenticate(request)
assert result is not None
credentials, user = result
assert isinstance(credentials, AuthCredentials)
assert isinstance(user, AuthenticatedUser)
assert credentials.scopes == ["read", "write"]
assert user.display_name == "test_client"
assert user.access_token == valid_access_token
async def test_mixed_case_authorization_header(
self,
mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any],
valid_access_token: AccessToken,
):
"""Test authentication with mixed 'Authorization' header."""
backend = BearerAuthBackend(provider=mock_oauth_provider)
add_token_to_provider(mock_oauth_provider, "valid_token", valid_access_token)
headers = Headers({"AuThOrIzAtIoN": "BeArEr valid_token"})
scope = {"type": "http", "headers": headers.raw}
request = Request(scope)
result = await backend.authenticate(request)
assert result is not None
credentials, user = result
assert isinstance(credentials, AuthCredentials)
assert isinstance(user, AuthenticatedUser)
assert credentials.scopes == ["read", "write"]
assert user.display_name == "test_client"
assert user.access_token == valid_access_token
@pytest.mark.anyio
class TestRequireAuthMiddleware: