Fix the issue of get Authorization header fails during bearer auth (#637)

Co-authored-by: yangben <yangben@zhihu.com>
This commit is contained in:
yabea
2025-05-08 00:42:02 +08:00
committed by GitHub
parent 9d99aee014
commit a1307abded
2 changed files with 70 additions and 2 deletions

View File

@@ -34,8 +34,15 @@ class BearerAuthBackend(AuthenticationBackend):
self.provider = provider self.provider = provider
async def authenticate(self, conn: HTTPConnection): async def authenticate(self, conn: HTTPConnection):
auth_header = conn.headers.get("Authorization") auth_header = next(
if not auth_header or not auth_header.startswith("Bearer "): (
conn.headers.get(key)
for key in conn.headers
if key.lower() == "authorization"
),
None,
)
if not auth_header or not auth_header.lower().startswith("bearer "):
return None return None
token = auth_header[7:] # Remove "Bearer " prefix token = auth_header[7:] # Remove "Bearer " prefix

View File

@@ -7,6 +7,7 @@ from typing import Any, cast
import pytest import pytest
from starlette.authentication import AuthCredentials from starlette.authentication import AuthCredentials
from starlette.datastructures import Headers
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.requests import Request from starlette.requests import Request
from starlette.types import Message, Receive, Scope, Send from starlette.types import Message, Receive, Scope, Send
@@ -221,6 +222,66 @@ class TestBearerAuthBackend:
assert user.access_token == no_expiry_access_token assert user.access_token == no_expiry_access_token
assert user.scopes == ["read", "write"] assert user.scopes == ["read", "write"]
async def test_lowercase_bearer_prefix(
self,
mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any],
valid_access_token: AccessToken,
):
"""Test with lowercase 'bearer' prefix in Authorization header"""
backend = BearerAuthBackend(provider=mock_oauth_provider)
add_token_to_provider(mock_oauth_provider, "valid_token", valid_access_token)
headers = Headers({"Authorization": "bearer valid_token"})
scope = {"type": "http", "headers": headers.raw}
request = Request(scope)
result = await backend.authenticate(request)
assert result is not None
credentials, user = result
assert isinstance(credentials, AuthCredentials)
assert isinstance(user, AuthenticatedUser)
assert credentials.scopes == ["read", "write"]
assert user.display_name == "test_client"
assert user.access_token == valid_access_token
async def test_mixed_case_bearer_prefix(
self,
mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any],
valid_access_token: AccessToken,
):
"""Test with mixed 'BeArEr' prefix in Authorization header"""
backend = BearerAuthBackend(provider=mock_oauth_provider)
add_token_to_provider(mock_oauth_provider, "valid_token", valid_access_token)
headers = Headers({"authorization": "BeArEr valid_token"})
scope = {"type": "http", "headers": headers.raw}
request = Request(scope)
result = await backend.authenticate(request)
assert result is not None
credentials, user = result
assert isinstance(credentials, AuthCredentials)
assert isinstance(user, AuthenticatedUser)
assert credentials.scopes == ["read", "write"]
assert user.display_name == "test_client"
assert user.access_token == valid_access_token
async def test_mixed_case_authorization_header(
self,
mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any],
valid_access_token: AccessToken,
):
"""Test authentication with mixed 'Authorization' header."""
backend = BearerAuthBackend(provider=mock_oauth_provider)
add_token_to_provider(mock_oauth_provider, "valid_token", valid_access_token)
headers = Headers({"AuThOrIzAtIoN": "BeArEr valid_token"})
scope = {"type": "http", "headers": headers.raw}
request = Request(scope)
result = await backend.authenticate(request)
assert result is not None
credentials, user = result
assert isinstance(credentials, AuthCredentials)
assert isinstance(user, AuthenticatedUser)
assert credentials.scopes == ["read", "write"]
assert user.display_name == "test_client"
assert user.access_token == valid_access_token
@pytest.mark.anyio @pytest.mark.anyio
class TestRequireAuthMiddleware: class TestRequireAuthMiddleware: