Add support for serverside oauth (#255)

Co-authored-by: David Soria Parra <davidsp@anthropic.com>
Co-authored-by: Basil Hosmer <basil@anthropic.com>
Co-authored-by: ihrpr <inna@anthropic.com>
This commit is contained in:
Peter Raboud
2025-05-01 11:42:59 -07:00
committed by GitHub
parent 82bd8bc1d9
commit 2210c1be18
31 changed files with 4120 additions and 22 deletions

1
.gitignore vendored
View File

@@ -166,4 +166,5 @@ cython_debug/
# vscode # vscode
.vscode/ .vscode/
.windsurfrules
**/CLAUDE.local.md **/CLAUDE.local.md

View File

@@ -19,7 +19,7 @@ This document contains critical information about working with this codebase. Fo
- Line length: 88 chars maximum - Line length: 88 chars maximum
3. Testing Requirements 3. Testing Requirements
- Framework: `uv run pytest` - Framework: `uv run --frozen pytest`
- Async testing: use anyio, not asyncio - Async testing: use anyio, not asyncio
- Coverage: test edge cases and errors - Coverage: test edge cases and errors
- New features require tests - New features require tests
@@ -54,9 +54,9 @@ This document contains critical information about working with this codebase. Fo
## Code Formatting ## Code Formatting
1. Ruff 1. Ruff
- Format: `uv run ruff format .` - Format: `uv run --frozen ruff format .`
- Check: `uv run ruff check .` - Check: `uv run --frozen ruff check .`
- Fix: `uv run ruff check . --fix` - Fix: `uv run --frozen ruff check . --fix`
- Critical issues: - Critical issues:
- Line length (88 chars) - Line length (88 chars)
- Import sorting (I001) - Import sorting (I001)
@@ -67,7 +67,7 @@ This document contains critical information about working with this codebase. Fo
- Imports: split into multiple lines - Imports: split into multiple lines
2. Type Checking 2. Type Checking
- Tool: `uv run pyright` - Tool: `uv run --frozen pyright`
- Requirements: - Requirements:
- Explicit None checks for Optional - Explicit None checks for Optional
- Type narrowing for strings - Type narrowing for strings
@@ -104,6 +104,10 @@ This document contains critical information about working with this codebase. Fo
- Add None checks - Add None checks
- Narrow string types - Narrow string types
- Match existing patterns - Match existing patterns
- Pytest:
- If the tests aren't finding the anyio pytest mark, try adding PYTEST_DISABLE_PLUGIN_AUTOLOAD=""
to the start of the pytest run command eg:
`PYTEST_DISABLE_PLUGIN_AUTOLOAD="" uv run --frozen pytest`
3. Best Practices 3. Best Practices
- Check git status before commits - Check git status before commits

View File

@@ -309,6 +309,33 @@ async def long_task(files: list[str], ctx: Context) -> str:
return "Processing complete" return "Processing complete"
``` ```
### Authentication
Authentication can be used by servers that want to expose tools accessing protected resources.
`mcp.server.auth` implements an OAuth 2.0 server interface, which servers can use by
providing an implementation of the `OAuthServerProvider` protocol.
```
mcp = FastMCP("My App",
auth_provider=MyOAuthServerProvider(),
auth=AuthSettings(
issuer_url="https://myapp.com",
revocation_options=RevocationOptions(
enabled=True,
),
client_registration_options=ClientRegistrationOptions(
enabled=True,
valid_scopes=["myscope", "myotherscope"],
default_scopes=["myscope"],
),
required_scopes=["myscope"],
),
)
```
See [OAuthServerProvider](mcp/server/auth/provider.py) for more details.
## Running Your Server ## Running Your Server
### Development Mode ### Development Mode

View File

@@ -323,8 +323,7 @@ class ChatSession:
total = result["total"] total = result["total"]
percentage = (progress / total) * 100 percentage = (progress / total) * 100
logging.info( logging.info(
f"Progress: {progress}/{total} " f"Progress: {progress}/{total} ({percentage:.1f}%)"
f"({percentage:.1f}%)"
) )
return f"Tool execution result: {result}" return f"Tool execution result: {result}"

View File

@@ -27,6 +27,7 @@ dependencies = [
"httpx-sse>=0.4", "httpx-sse>=0.4",
"pydantic>=2.7.2,<3.0.0", "pydantic>=2.7.2,<3.0.0",
"starlette>=0.27", "starlette>=0.27",
"python-multipart>=0.0.9",
"sse-starlette>=1.6.1", "sse-starlette>=1.6.1",
"pydantic-settings>=2.5.2", "pydantic-settings>=2.5.2",
"uvicorn>=0.23.1; sys_platform != 'emscripten'", "uvicorn>=0.23.1; sys_platform != 'emscripten'",

View File

@@ -0,0 +1,3 @@
"""
MCP OAuth server authorization components.
"""

View File

@@ -0,0 +1,8 @@
from pydantic import ValidationError
def stringify_pydantic_error(validation_error: ValidationError) -> str:
return "\n".join(
f"{'.'.join(str(loc) for loc in e['loc'])}: {e['msg']}"
for e in validation_error.errors()
)

View File

@@ -0,0 +1,3 @@
"""
Request handlers for MCP authorization endpoints.
"""

View File

@@ -0,0 +1,244 @@
import logging
from dataclasses import dataclass
from typing import Any, Literal
from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, RootModel, ValidationError
from starlette.datastructures import FormData, QueryParams
from starlette.requests import Request
from starlette.responses import RedirectResponse, Response
from mcp.server.auth.errors import (
stringify_pydantic_error,
)
from mcp.server.auth.json_response import PydanticJSONResponse
from mcp.server.auth.provider import (
AuthorizationErrorCode,
AuthorizationParams,
AuthorizeError,
OAuthAuthorizationServerProvider,
construct_redirect_uri,
)
from mcp.shared.auth import (
InvalidRedirectUriError,
InvalidScopeError,
)
logger = logging.getLogger(__name__)
class AuthorizationRequest(BaseModel):
# See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1
client_id: str = Field(..., description="The client ID")
redirect_uri: AnyHttpUrl | None = Field(
None, description="URL to redirect to after authorization"
)
# see OAuthClientMetadata; we only support `code`
response_type: Literal["code"] = Field(
..., description="Must be 'code' for authorization code flow"
)
code_challenge: str = Field(..., description="PKCE code challenge")
code_challenge_method: Literal["S256"] = Field(
"S256", description="PKCE code challenge method, must be S256"
)
state: str | None = Field(None, description="Optional state parameter")
scope: str | None = Field(
None,
description="Optional scope; if specified, should be "
"a space-separated list of scope strings",
)
class AuthorizationErrorResponse(BaseModel):
error: AuthorizationErrorCode
error_description: str | None
error_uri: AnyUrl | None = None
# must be set if provided in the request
state: str | None = None
def best_effort_extract_string(
key: str, params: None | FormData | QueryParams
) -> str | None:
if params is None:
return None
value = params.get(key)
if isinstance(value, str):
return value
return None
class AnyHttpUrlModel(RootModel[AnyHttpUrl]):
root: AnyHttpUrl
@dataclass
class AuthorizationHandler:
provider: OAuthAuthorizationServerProvider[Any, Any, Any]
async def handle(self, request: Request) -> Response:
# implements authorization requests for grant_type=code;
# see https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1
state = None
redirect_uri = None
client = None
params = None
async def error_response(
error: AuthorizationErrorCode,
error_description: str | None,
attempt_load_client: bool = True,
):
# Error responses take two different formats:
# 1. The request has a valid client ID & redirect_uri: we issue a redirect
# back to the redirect_uri with the error response fields as query
# parameters. This allows the client to be notified of the error.
# 2. Otherwise, we return an error response directly to the end user;
# we choose to do so in JSON, but this is left undefined in the
# specification.
# See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1
#
# This logic is a bit awkward to handle, because the error might be thrown
# very early in request validation, before we've done the usual Pydantic
# validation, loaded the client, etc. To handle this, error_response()
# contains fallback logic which attempts to load the parameters directly
# from the request.
nonlocal client, redirect_uri, state
if client is None and attempt_load_client:
# make last-ditch attempt to load the client
client_id = best_effort_extract_string("client_id", params)
client = client_id and await self.provider.get_client(client_id)
if redirect_uri is None and client:
# make last-ditch effort to load the redirect uri
try:
if params is not None and "redirect_uri" not in params:
raw_redirect_uri = None
else:
raw_redirect_uri = AnyHttpUrlModel.model_validate(
best_effort_extract_string("redirect_uri", params)
).root
redirect_uri = client.validate_redirect_uri(raw_redirect_uri)
except (ValidationError, InvalidRedirectUriError):
# if the redirect URI is invalid, ignore it & just return the
# initial error
pass
# the error response MUST contain the state specified by the client, if any
if state is None:
# make last-ditch effort to load state
state = best_effort_extract_string("state", params)
error_resp = AuthorizationErrorResponse(
error=error,
error_description=error_description,
state=state,
)
if redirect_uri and client:
return RedirectResponse(
url=construct_redirect_uri(
str(redirect_uri), **error_resp.model_dump(exclude_none=True)
),
status_code=302,
headers={"Cache-Control": "no-store"},
)
else:
return PydanticJSONResponse(
status_code=400,
content=error_resp,
headers={"Cache-Control": "no-store"},
)
try:
# Parse request parameters
if request.method == "GET":
# Convert query_params to dict for pydantic validation
params = request.query_params
else:
# Parse form data for POST requests
params = await request.form()
# Save state if it exists, even before validation
state = best_effort_extract_string("state", params)
try:
auth_request = AuthorizationRequest.model_validate(params)
state = auth_request.state # Update with validated state
except ValidationError as validation_error:
error: AuthorizationErrorCode = "invalid_request"
for e in validation_error.errors():
if e["loc"] == ("response_type",) and e["type"] == "literal_error":
error = "unsupported_response_type"
break
return await error_response(
error, stringify_pydantic_error(validation_error)
)
# Get client information
client = await self.provider.get_client(
auth_request.client_id,
)
if not client:
# For client_id validation errors, return direct error (no redirect)
return await error_response(
error="invalid_request",
error_description=f"Client ID '{auth_request.client_id}' not found",
attempt_load_client=False,
)
# Validate redirect_uri against client's registered URIs
try:
redirect_uri = client.validate_redirect_uri(auth_request.redirect_uri)
except InvalidRedirectUriError as validation_error:
# For redirect_uri validation errors, return direct error (no redirect)
return await error_response(
error="invalid_request",
error_description=validation_error.message,
)
# Validate scope - for scope errors, we can redirect
try:
scopes = client.validate_scope(auth_request.scope)
except InvalidScopeError as validation_error:
# For scope errors, redirect with error parameters
return await error_response(
error="invalid_scope",
error_description=validation_error.message,
)
# Setup authorization parameters
auth_params = AuthorizationParams(
state=state,
scopes=scopes,
code_challenge=auth_request.code_challenge,
redirect_uri=redirect_uri,
redirect_uri_provided_explicitly=auth_request.redirect_uri is not None,
)
try:
# Let the provider pick the next URI to redirect to
return RedirectResponse(
url=await self.provider.authorize(
client,
auth_params,
),
status_code=302,
headers={"Cache-Control": "no-store"},
)
except AuthorizeError as e:
# Handle authorization errors as defined in RFC 6749 Section 4.1.2.1
return await error_response(
error=e.error,
error_description=e.error_description,
)
except Exception as validation_error:
# Catch-all for unexpected errors
logger.exception(
"Unexpected error in authorization_handler", exc_info=validation_error
)
return await error_response(
error="server_error", error_description="An unexpected error occurred"
)

View File

@@ -0,0 +1,18 @@
from dataclasses import dataclass
from starlette.requests import Request
from starlette.responses import Response
from mcp.server.auth.json_response import PydanticJSONResponse
from mcp.shared.auth import OAuthMetadata
@dataclass
class MetadataHandler:
metadata: OAuthMetadata
async def handle(self, request: Request) -> Response:
return PydanticJSONResponse(
content=self.metadata,
headers={"Cache-Control": "public, max-age=3600"}, # Cache for 1 hour
)

View File

@@ -0,0 +1,129 @@
import secrets
import time
from dataclasses import dataclass
from typing import Any
from uuid import uuid4
from pydantic import BaseModel, RootModel, ValidationError
from starlette.requests import Request
from starlette.responses import Response
from mcp.server.auth.errors import stringify_pydantic_error
from mcp.server.auth.json_response import PydanticJSONResponse
from mcp.server.auth.provider import (
OAuthAuthorizationServerProvider,
RegistrationError,
RegistrationErrorCode,
)
from mcp.server.auth.settings import ClientRegistrationOptions
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata
class RegistrationRequest(RootModel[OAuthClientMetadata]):
# this wrapper is a no-op; it's just to separate out the types exposed to the
# provider from what we use in the HTTP handler
root: OAuthClientMetadata
class RegistrationErrorResponse(BaseModel):
error: RegistrationErrorCode
error_description: str | None
@dataclass
class RegistrationHandler:
provider: OAuthAuthorizationServerProvider[Any, Any, Any]
options: ClientRegistrationOptions
async def handle(self, request: Request) -> Response:
# Implements dynamic client registration as defined in https://datatracker.ietf.org/doc/html/rfc7591#section-3.1
try:
# Parse request body as JSON
body = await request.json()
client_metadata = OAuthClientMetadata.model_validate(body)
# Scope validation is handled below
except ValidationError as validation_error:
return PydanticJSONResponse(
content=RegistrationErrorResponse(
error="invalid_client_metadata",
error_description=stringify_pydantic_error(validation_error),
),
status_code=400,
)
client_id = str(uuid4())
client_secret = None
if client_metadata.token_endpoint_auth_method != "none":
# cryptographically secure random 32-byte hex string
client_secret = secrets.token_hex(32)
if client_metadata.scope is None and self.options.default_scopes is not None:
client_metadata.scope = " ".join(self.options.default_scopes)
elif (
client_metadata.scope is not None and self.options.valid_scopes is not None
):
requested_scopes = set(client_metadata.scope.split())
valid_scopes = set(self.options.valid_scopes)
if not requested_scopes.issubset(valid_scopes):
return PydanticJSONResponse(
content=RegistrationErrorResponse(
error="invalid_client_metadata",
error_description="Requested scopes are not valid: "
f"{', '.join(requested_scopes - valid_scopes)}",
),
status_code=400,
)
if set(client_metadata.grant_types) != {"authorization_code", "refresh_token"}:
return PydanticJSONResponse(
content=RegistrationErrorResponse(
error="invalid_client_metadata",
error_description="grant_types must be authorization_code "
"and refresh_token",
),
status_code=400,
)
client_id_issued_at = int(time.time())
client_secret_expires_at = (
client_id_issued_at + self.options.client_secret_expiry_seconds
if self.options.client_secret_expiry_seconds is not None
else None
)
client_info = OAuthClientInformationFull(
client_id=client_id,
client_id_issued_at=client_id_issued_at,
client_secret=client_secret,
client_secret_expires_at=client_secret_expires_at,
# passthrough information from the client request
redirect_uris=client_metadata.redirect_uris,
token_endpoint_auth_method=client_metadata.token_endpoint_auth_method,
grant_types=client_metadata.grant_types,
response_types=client_metadata.response_types,
client_name=client_metadata.client_name,
client_uri=client_metadata.client_uri,
logo_uri=client_metadata.logo_uri,
scope=client_metadata.scope,
contacts=client_metadata.contacts,
tos_uri=client_metadata.tos_uri,
policy_uri=client_metadata.policy_uri,
jwks_uri=client_metadata.jwks_uri,
jwks=client_metadata.jwks,
software_id=client_metadata.software_id,
software_version=client_metadata.software_version,
)
try:
# Register client
await self.provider.register_client(client_info)
# Return client information
return PydanticJSONResponse(content=client_info, status_code=201)
except RegistrationError as e:
# Handle registration errors as defined in RFC 7591 Section 3.2.2
return PydanticJSONResponse(
content=RegistrationErrorResponse(
error=e.error, error_description=e.error_description
),
status_code=400,
)

View File

@@ -0,0 +1,101 @@
from dataclasses import dataclass
from functools import partial
from typing import Any, Literal
from pydantic import BaseModel, ValidationError
from starlette.requests import Request
from starlette.responses import Response
from mcp.server.auth.errors import (
stringify_pydantic_error,
)
from mcp.server.auth.json_response import PydanticJSONResponse
from mcp.server.auth.middleware.client_auth import (
AuthenticationError,
ClientAuthenticator,
)
from mcp.server.auth.provider import (
AccessToken,
OAuthAuthorizationServerProvider,
RefreshToken,
)
class RevocationRequest(BaseModel):
"""
# See https://datatracker.ietf.org/doc/html/rfc7009#section-2.1
"""
token: str
token_type_hint: Literal["access_token", "refresh_token"] | None = None
client_id: str
client_secret: str | None
class RevocationErrorResponse(BaseModel):
error: Literal["invalid_request", "unauthorized_client"]
error_description: str | None = None
@dataclass
class RevocationHandler:
provider: OAuthAuthorizationServerProvider[Any, Any, Any]
client_authenticator: ClientAuthenticator
async def handle(self, request: Request) -> Response:
"""
Handler for the OAuth 2.0 Token Revocation endpoint.
"""
try:
form_data = await request.form()
revocation_request = RevocationRequest.model_validate(dict(form_data))
except ValidationError as e:
return PydanticJSONResponse(
status_code=400,
content=RevocationErrorResponse(
error="invalid_request",
error_description=stringify_pydantic_error(e),
),
)
# Authenticate client
try:
client = await self.client_authenticator.authenticate(
revocation_request.client_id, revocation_request.client_secret
)
except AuthenticationError as e:
return PydanticJSONResponse(
status_code=401,
content=RevocationErrorResponse(
error="unauthorized_client",
error_description=e.message,
),
)
loaders = [
self.provider.load_access_token,
partial(self.provider.load_refresh_token, client),
]
if revocation_request.token_type_hint == "refresh_token":
loaders = reversed(loaders)
token: None | AccessToken | RefreshToken = None
for loader in loaders:
token = await loader(revocation_request.token)
if token is not None:
break
# if token is not found, just return HTTP 200 per the RFC
if token and token.client_id == client.client_id:
# Revoke token; provider is not meant to be able to do validation
# at this point that would result in an error
await self.provider.revoke_token(token)
# Return successful empty response
return Response(
status_code=200,
headers={
"Cache-Control": "no-store",
"Pragma": "no-cache",
},
)

View File

@@ -0,0 +1,264 @@
import base64
import hashlib
import time
from dataclasses import dataclass
from typing import Annotated, Any, Literal
from pydantic import AnyHttpUrl, BaseModel, Field, RootModel, ValidationError
from starlette.requests import Request
from mcp.server.auth.errors import (
stringify_pydantic_error,
)
from mcp.server.auth.json_response import PydanticJSONResponse
from mcp.server.auth.middleware.client_auth import (
AuthenticationError,
ClientAuthenticator,
)
from mcp.server.auth.provider import (
OAuthAuthorizationServerProvider,
TokenError,
TokenErrorCode,
)
from mcp.shared.auth import OAuthToken
class AuthorizationCodeRequest(BaseModel):
# See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3
grant_type: Literal["authorization_code"]
code: str = Field(..., description="The authorization code")
redirect_uri: AnyHttpUrl | None = Field(
None, description="Must be the same as redirect URI provided in /authorize"
)
client_id: str
# we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1
client_secret: str | None = None
# See https://datatracker.ietf.org/doc/html/rfc7636#section-4.5
code_verifier: str = Field(..., description="PKCE code verifier")
class RefreshTokenRequest(BaseModel):
# See https://datatracker.ietf.org/doc/html/rfc6749#section-6
grant_type: Literal["refresh_token"]
refresh_token: str = Field(..., description="The refresh token")
scope: str | None = Field(None, description="Optional scope parameter")
client_id: str
# we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1
client_secret: str | None = None
class TokenRequest(
RootModel[
Annotated[
AuthorizationCodeRequest | RefreshTokenRequest,
Field(discriminator="grant_type"),
]
]
):
root: Annotated[
AuthorizationCodeRequest | RefreshTokenRequest,
Field(discriminator="grant_type"),
]
class TokenErrorResponse(BaseModel):
"""
See https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
"""
error: TokenErrorCode
error_description: str | None = None
error_uri: AnyHttpUrl | None = None
class TokenSuccessResponse(RootModel[OAuthToken]):
# this is just a wrapper over OAuthToken; the only reason we do this
# is to have some separation between the HTTP response type, and the
# type returned by the provider
root: OAuthToken
@dataclass
class TokenHandler:
provider: OAuthAuthorizationServerProvider[Any, Any, Any]
client_authenticator: ClientAuthenticator
def response(self, obj: TokenSuccessResponse | TokenErrorResponse):
status_code = 200
if isinstance(obj, TokenErrorResponse):
status_code = 400
return PydanticJSONResponse(
content=obj,
status_code=status_code,
headers={
"Cache-Control": "no-store",
"Pragma": "no-cache",
},
)
async def handle(self, request: Request):
try:
form_data = await request.form()
token_request = TokenRequest.model_validate(dict(form_data)).root
except ValidationError as validation_error:
return self.response(
TokenErrorResponse(
error="invalid_request",
error_description=stringify_pydantic_error(validation_error),
)
)
try:
client_info = await self.client_authenticator.authenticate(
client_id=token_request.client_id,
client_secret=token_request.client_secret,
)
except AuthenticationError as e:
return self.response(
TokenErrorResponse(
error="unauthorized_client",
error_description=e.message,
)
)
if token_request.grant_type not in client_info.grant_types:
return self.response(
TokenErrorResponse(
error="unsupported_grant_type",
error_description=(
f"Unsupported grant type (supported grant types are "
f"{client_info.grant_types})"
),
)
)
tokens: OAuthToken
match token_request:
case AuthorizationCodeRequest():
auth_code = await self.provider.load_authorization_code(
client_info, token_request.code
)
if auth_code is None or auth_code.client_id != token_request.client_id:
# if code belongs to different client, pretend it doesn't exist
return self.response(
TokenErrorResponse(
error="invalid_grant",
error_description="authorization code does not exist",
)
)
# make auth codes expire after a deadline
# see https://datatracker.ietf.org/doc/html/rfc6749#section-10.5
if auth_code.expires_at < time.time():
return self.response(
TokenErrorResponse(
error="invalid_grant",
error_description="authorization code has expired",
)
)
# verify redirect_uri doesn't change between /authorize and /tokens
# see https://datatracker.ietf.org/doc/html/rfc6749#section-10.6
if auth_code.redirect_uri_provided_explicitly:
authorize_request_redirect_uri = auth_code.redirect_uri
else:
authorize_request_redirect_uri = None
if token_request.redirect_uri != authorize_request_redirect_uri:
return self.response(
TokenErrorResponse(
error="invalid_request",
error_description=(
"redirect_uri did not match the one "
"used when creating auth code"
),
)
)
# Verify PKCE code verifier
sha256 = hashlib.sha256(token_request.code_verifier.encode()).digest()
hashed_code_verifier = (
base64.urlsafe_b64encode(sha256).decode().rstrip("=")
)
if hashed_code_verifier != auth_code.code_challenge:
# see https://datatracker.ietf.org/doc/html/rfc7636#section-4.6
return self.response(
TokenErrorResponse(
error="invalid_grant",
error_description="incorrect code_verifier",
)
)
try:
# Exchange authorization code for tokens
tokens = await self.provider.exchange_authorization_code(
client_info, auth_code
)
except TokenError as e:
return self.response(
TokenErrorResponse(
error=e.error,
error_description=e.error_description,
)
)
case RefreshTokenRequest():
refresh_token = await self.provider.load_refresh_token(
client_info, token_request.refresh_token
)
if (
refresh_token is None
or refresh_token.client_id != token_request.client_id
):
# if token belongs to different client, pretend it doesn't exist
return self.response(
TokenErrorResponse(
error="invalid_grant",
error_description="refresh token does not exist",
)
)
if refresh_token.expires_at and refresh_token.expires_at < time.time():
# if the refresh token has expired, pretend it doesn't exist
return self.response(
TokenErrorResponse(
error="invalid_grant",
error_description="refresh token has expired",
)
)
# Parse scopes if provided
scopes = (
token_request.scope.split(" ")
if token_request.scope
else refresh_token.scopes
)
for scope in scopes:
if scope not in refresh_token.scopes:
return self.response(
TokenErrorResponse(
error="invalid_scope",
error_description=(
f"cannot request scope `{scope}` "
"not provided by refresh token"
),
)
)
try:
# Exchange refresh token for new tokens
tokens = await self.provider.exchange_refresh_token(
client_info, refresh_token, scopes
)
except TokenError as e:
return self.response(
TokenErrorResponse(
error=e.error,
error_description=e.error_description,
)
)
return self.response(TokenSuccessResponse(root=tokens))

View File

@@ -0,0 +1,10 @@
from typing import Any
from starlette.responses import JSONResponse
class PydanticJSONResponse(JSONResponse):
# use pydantic json serialization instead of the stock `json.dumps`,
# so that we can handle serializing pydantic models like AnyHttpUrl
def render(self, content: Any) -> bytes:
return content.model_dump_json(exclude_none=True).encode("utf-8")

View File

@@ -0,0 +1,3 @@
"""
Middleware for MCP authorization.
"""

View File

@@ -0,0 +1,50 @@
import contextvars
from starlette.types import ASGIApp, Receive, Scope, Send
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser
from mcp.server.auth.provider import AccessToken
# Create a contextvar to store the authenticated user
# The default is None, indicating no authenticated user is present
auth_context_var = contextvars.ContextVar[AuthenticatedUser | None](
"auth_context", default=None
)
def get_access_token() -> AccessToken | None:
"""
Get the access token from the current context.
Returns:
The access token if an authenticated user is available, None otherwise.
"""
auth_user = auth_context_var.get()
return auth_user.access_token if auth_user else None
class AuthContextMiddleware:
"""
Middleware that extracts the authenticated user from the request
and sets it in a contextvar for easy access throughout the request lifecycle.
This middleware should be added after the AuthenticationMiddleware in the
middleware stack to ensure that the user is properly authenticated before
being stored in the context.
"""
def __init__(self, app: ASGIApp):
self.app = app
async def __call__(self, scope: Scope, receive: Receive, send: Send):
user = scope.get("user")
if isinstance(user, AuthenticatedUser):
# Set the authenticated user in the contextvar
token = auth_context_var.set(user)
try:
await self.app(scope, receive, send)
finally:
auth_context_var.reset(token)
else:
# No authenticated user, just process the request
await self.app(scope, receive, send)

View File

@@ -0,0 +1,89 @@
import time
from typing import Any
from starlette.authentication import (
AuthCredentials,
AuthenticationBackend,
SimpleUser,
)
from starlette.exceptions import HTTPException
from starlette.requests import HTTPConnection
from starlette.types import Receive, Scope, Send
from mcp.server.auth.provider import AccessToken, OAuthAuthorizationServerProvider
class AuthenticatedUser(SimpleUser):
"""User with authentication info."""
def __init__(self, auth_info: AccessToken):
super().__init__(auth_info.client_id)
self.access_token = auth_info
self.scopes = auth_info.scopes
class BearerAuthBackend(AuthenticationBackend):
"""
Authentication backend that validates Bearer tokens.
"""
def __init__(
self,
provider: OAuthAuthorizationServerProvider[Any, Any, Any],
):
self.provider = provider
async def authenticate(self, conn: HTTPConnection):
auth_header = conn.headers.get("Authorization")
if not auth_header or not auth_header.startswith("Bearer "):
return None
token = auth_header[7:] # Remove "Bearer " prefix
# Validate the token with the provider
auth_info = await self.provider.load_access_token(token)
if not auth_info:
return None
if auth_info.expires_at and auth_info.expires_at < int(time.time()):
return None
return AuthCredentials(auth_info.scopes), AuthenticatedUser(auth_info)
class RequireAuthMiddleware:
"""
Middleware that requires a valid Bearer token in the Authorization header.
This will validate the token with the auth provider and store the resulting
auth info in the request state.
"""
def __init__(self, app: Any, required_scopes: list[str]):
"""
Initialize the middleware.
Args:
app: ASGI application
provider: Authentication provider to validate tokens
required_scopes: Optional list of scopes that the token must have
"""
self.app = app
self.required_scopes = required_scopes
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
auth_user = scope.get("user")
if not isinstance(auth_user, AuthenticatedUser):
raise HTTPException(status_code=401, detail="Unauthorized")
auth_credentials = scope.get("auth")
for required_scope in self.required_scopes:
# auth_credentials should always be provided; this is just paranoia
if (
auth_credentials is None
or required_scope not in auth_credentials.scopes
):
raise HTTPException(status_code=403, detail="Insufficient scope")
await self.app(scope, receive, send)

View File

@@ -0,0 +1,56 @@
import time
from typing import Any
from mcp.server.auth.provider import OAuthAuthorizationServerProvider
from mcp.shared.auth import OAuthClientInformationFull
class AuthenticationError(Exception):
def __init__(self, message: str):
self.message = message
class ClientAuthenticator:
"""
ClientAuthenticator is a callable which validates requests from a client
application, used to verify /token calls.
If, during registration, the client requested to be issued a secret, the
authenticator asserts that /token calls must be authenticated with
that same token.
NOTE: clients can opt for no authentication during registration, in which case this
logic is skipped.
"""
def __init__(self, provider: OAuthAuthorizationServerProvider[Any, Any, Any]):
"""
Initialize the dependency.
Args:
provider: Provider to look up client information
"""
self.provider = provider
async def authenticate(
self, client_id: str, client_secret: str | None
) -> OAuthClientInformationFull:
# Look up client information
client = await self.provider.get_client(client_id)
if not client:
raise AuthenticationError("Invalid client_id")
# If client from the store expects a secret, validate that the request provides
# that secret
if client.client_secret:
if not client_secret:
raise AuthenticationError("Client secret is required")
if client.client_secret != client_secret:
raise AuthenticationError("Invalid client_secret")
if (
client.client_secret_expires_at
and client.client_secret_expires_at < int(time.time())
):
raise AuthenticationError("Client secret has expired")
return client

View File

@@ -0,0 +1,289 @@
from dataclasses import dataclass
from typing import Generic, Literal, Protocol, TypeVar
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
from pydantic import AnyHttpUrl, BaseModel
from mcp.shared.auth import (
OAuthClientInformationFull,
OAuthToken,
)
class AuthorizationParams(BaseModel):
state: str | None
scopes: list[str] | None
code_challenge: str
redirect_uri: AnyHttpUrl
redirect_uri_provided_explicitly: bool
class AuthorizationCode(BaseModel):
code: str
scopes: list[str]
expires_at: float
client_id: str
code_challenge: str
redirect_uri: AnyHttpUrl
redirect_uri_provided_explicitly: bool
class RefreshToken(BaseModel):
token: str
client_id: str
scopes: list[str]
expires_at: int | None = None
class AccessToken(BaseModel):
token: str
client_id: str
scopes: list[str]
expires_at: int | None = None
RegistrationErrorCode = Literal[
"invalid_redirect_uri",
"invalid_client_metadata",
"invalid_software_statement",
"unapproved_software_statement",
]
@dataclass(frozen=True)
class RegistrationError(Exception):
error: RegistrationErrorCode
error_description: str | None = None
AuthorizationErrorCode = Literal[
"invalid_request",
"unauthorized_client",
"access_denied",
"unsupported_response_type",
"invalid_scope",
"server_error",
"temporarily_unavailable",
]
@dataclass(frozen=True)
class AuthorizeError(Exception):
error: AuthorizationErrorCode
error_description: str | None = None
TokenErrorCode = Literal[
"invalid_request",
"invalid_client",
"invalid_grant",
"unauthorized_client",
"unsupported_grant_type",
"invalid_scope",
]
@dataclass(frozen=True)
class TokenError(Exception):
error: TokenErrorCode
error_description: str | None = None
# NOTE: FastMCP doesn't render any of these types in the user response, so it's
# OK to add fields to subclasses which should not be exposed externally.
AuthorizationCodeT = TypeVar("AuthorizationCodeT", bound=AuthorizationCode)
RefreshTokenT = TypeVar("RefreshTokenT", bound=RefreshToken)
AccessTokenT = TypeVar("AccessTokenT", bound=AccessToken)
class OAuthAuthorizationServerProvider(
Protocol, Generic[AuthorizationCodeT, RefreshTokenT, AccessTokenT]
):
async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
"""
Retrieves client information by client ID.
Implementors MAY raise NotImplementedError if dynamic client registration is
disabled in ClientRegistrationOptions.
Args:
client_id: The ID of the client to retrieve.
Returns:
The client information, or None if the client does not exist.
"""
...
async def register_client(self, client_info: OAuthClientInformationFull) -> None:
"""
Saves client information as part of registering it.
Implementors MAY raise NotImplementedError if dynamic client registration is
disabled in ClientRegistrationOptions.
Args:
client_info: The client metadata to register.
Raises:
RegistrationError: If the client metadata is invalid.
"""
...
async def authorize(
self, client: OAuthClientInformationFull, params: AuthorizationParams
) -> str:
"""
Called as part of the /authorize endpoint, and returns a URL that the client
will be redirected to.
Many MCP implementations will redirect to a third-party provider to perform
a second OAuth exchange with that provider. In this sort of setup, the client
has an OAuth connection with the MCP server, and the MCP server has an OAuth
connection with the 3rd-party provider. At the end of this flow, the client
should be redirected to the redirect_uri from params.redirect_uri.
+--------+ +------------+ +-------------------+
| | | | | |
| Client | --> | MCP Server | --> | 3rd Party OAuth |
| | | | | Server |
+--------+ +------------+ +-------------------+
| ^ |
+------------+ | | |
| | | | Redirect |
|redirect_uri|<-----+ +------------------+
| |
+------------+
Implementations will need to define another handler on the MCP server return
flow to perform the second redirect, and generate and store an authorization
code as part of completing the OAuth authorization step.
Implementations SHOULD generate an authorization code with at least 160 bits of
entropy,
and MUST generate an authorization code with at least 128 bits of entropy.
See https://datatracker.ietf.org/doc/html/rfc6749#section-10.10.
Args:
client: The client requesting authorization.
params: The parameters of the authorization request.
Returns:
A URL to redirect the client to for authorization.
Raises:
AuthorizeError: If the authorization request is invalid.
"""
...
async def load_authorization_code(
self, client: OAuthClientInformationFull, authorization_code: str
) -> AuthorizationCodeT | None:
"""
Loads an AuthorizationCode by its code.
Args:
client: The client that requested the authorization code.
authorization_code: The authorization code to get the challenge for.
Returns:
The AuthorizationCode, or None if not found
"""
...
async def exchange_authorization_code(
self, client: OAuthClientInformationFull, authorization_code: AuthorizationCodeT
) -> OAuthToken:
"""
Exchanges an authorization code for an access token and refresh token.
Args:
client: The client exchanging the authorization code.
authorization_code: The authorization code to exchange.
Returns:
The OAuth token, containing access and refresh tokens.
Raises:
TokenError: If the request is invalid
"""
...
async def load_refresh_token(
self, client: OAuthClientInformationFull, refresh_token: str
) -> RefreshTokenT | None:
"""
Loads a RefreshToken by its token string.
Args:
client: The client that is requesting to load the refresh token.
refresh_token: The refresh token string to load.
Returns:
The RefreshToken object if found, or None if not found.
"""
...
async def exchange_refresh_token(
self,
client: OAuthClientInformationFull,
refresh_token: RefreshTokenT,
scopes: list[str],
) -> OAuthToken:
"""
Exchanges a refresh token for an access token and refresh token.
Implementations SHOULD rotate both the access token and refresh token.
Args:
client: The client exchanging the refresh token.
refresh_token: The refresh token to exchange.
scopes: Optional scopes to request with the new access token.
Returns:
The OAuth token, containing access and refresh tokens.
Raises:
TokenError: If the request is invalid
"""
...
async def load_access_token(self, token: str) -> AccessTokenT | None:
"""
Loads an access token by its token.
Args:
token: The access token to verify.
Returns:
The AuthInfo, or None if the token is invalid.
"""
...
async def revoke_token(
self,
token: AccessTokenT | RefreshTokenT,
) -> None:
"""
Revokes an access or refresh token.
If the given token is invalid or already revoked, this method should do nothing.
Implementations SHOULD revoke both the access token and its corresponding
refresh token, regardless of which of the access token or refresh token is
provided.
Args:
token: the token to revoke
"""
...
def construct_redirect_uri(redirect_uri_base: str, **params: str | None) -> str:
parsed_uri = urlparse(redirect_uri_base)
query_params = [(k, v) for k, vs in parse_qs(parsed_uri.query) for v in vs]
for k, v in params.items():
if v is not None:
query_params.append((k, v))
redirect_uri = urlunparse(parsed_uri._replace(query=urlencode(query_params)))
return redirect_uri

View File

@@ -0,0 +1,207 @@
from collections.abc import Awaitable, Callable
from typing import Any
from pydantic import AnyHttpUrl
from starlette.middleware.cors import CORSMiddleware
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Route, request_response # type: ignore
from starlette.types import ASGIApp
from mcp.server.auth.handlers.authorize import AuthorizationHandler
from mcp.server.auth.handlers.metadata import MetadataHandler
from mcp.server.auth.handlers.register import RegistrationHandler
from mcp.server.auth.handlers.revoke import RevocationHandler
from mcp.server.auth.handlers.token import TokenHandler
from mcp.server.auth.middleware.client_auth import ClientAuthenticator
from mcp.server.auth.provider import OAuthAuthorizationServerProvider
from mcp.server.auth.settings import ClientRegistrationOptions, RevocationOptions
from mcp.shared.auth import OAuthMetadata
def validate_issuer_url(url: AnyHttpUrl):
"""
Validate that the issuer URL meets OAuth 2.0 requirements.
Args:
url: The issuer URL to validate
Raises:
ValueError: If the issuer URL is invalid
"""
# RFC 8414 requires HTTPS, but we allow localhost HTTP for testing
if (
url.scheme != "https"
and url.host != "localhost"
and not url.host.startswith("127.0.0.1")
):
raise ValueError("Issuer URL must be HTTPS")
# No fragments or query parameters allowed
if url.fragment:
raise ValueError("Issuer URL must not have a fragment")
if url.query:
raise ValueError("Issuer URL must not have a query string")
AUTHORIZATION_PATH = "/authorize"
TOKEN_PATH = "/token"
REGISTRATION_PATH = "/register"
REVOCATION_PATH = "/revoke"
def cors_middleware(
handler: Callable[[Request], Response | Awaitable[Response]],
allow_methods: list[str],
) -> ASGIApp:
cors_app = CORSMiddleware(
app=request_response(handler),
allow_origins="*",
allow_methods=allow_methods,
allow_headers=["mcp-protocol-version"],
)
return cors_app
def create_auth_routes(
provider: OAuthAuthorizationServerProvider[Any, Any, Any],
issuer_url: AnyHttpUrl,
service_documentation_url: AnyHttpUrl | None = None,
client_registration_options: ClientRegistrationOptions | None = None,
revocation_options: RevocationOptions | None = None,
) -> list[Route]:
validate_issuer_url(issuer_url)
client_registration_options = (
client_registration_options or ClientRegistrationOptions()
)
revocation_options = revocation_options or RevocationOptions()
metadata = build_metadata(
issuer_url,
service_documentation_url,
client_registration_options,
revocation_options,
)
client_authenticator = ClientAuthenticator(provider)
# Create routes
# Allow CORS requests for endpoints meant to be hit by the OAuth client
# (with the client secret). This is intended to support things like MCP Inspector,
# where the client runs in a web browser.
routes = [
Route(
"/.well-known/oauth-authorization-server",
endpoint=cors_middleware(
MetadataHandler(metadata).handle,
["GET", "OPTIONS"],
),
methods=["GET", "OPTIONS"],
),
Route(
AUTHORIZATION_PATH,
# do not allow CORS for authorization endpoint;
# clients should just redirect to this
endpoint=AuthorizationHandler(provider).handle,
methods=["GET", "POST"],
),
Route(
TOKEN_PATH,
endpoint=cors_middleware(
TokenHandler(provider, client_authenticator).handle,
["POST", "OPTIONS"],
),
methods=["POST", "OPTIONS"],
),
]
if client_registration_options.enabled:
registration_handler = RegistrationHandler(
provider,
options=client_registration_options,
)
routes.append(
Route(
REGISTRATION_PATH,
endpoint=cors_middleware(
registration_handler.handle,
["POST", "OPTIONS"],
),
methods=["POST", "OPTIONS"],
)
)
if revocation_options.enabled:
revocation_handler = RevocationHandler(provider, client_authenticator)
routes.append(
Route(
REVOCATION_PATH,
endpoint=cors_middleware(
revocation_handler.handle,
["POST", "OPTIONS"],
),
methods=["POST", "OPTIONS"],
)
)
return routes
def modify_url_path(url: AnyHttpUrl, path_mapper: Callable[[str], str]) -> AnyHttpUrl:
return AnyHttpUrl.build(
scheme=url.scheme,
username=url.username,
password=url.password,
host=url.host,
port=url.port,
path=path_mapper(url.path or ""),
query=url.query,
fragment=url.fragment,
)
def build_metadata(
issuer_url: AnyHttpUrl,
service_documentation_url: AnyHttpUrl | None,
client_registration_options: ClientRegistrationOptions,
revocation_options: RevocationOptions,
) -> OAuthMetadata:
authorization_url = modify_url_path(
issuer_url, lambda path: path.rstrip("/") + AUTHORIZATION_PATH.lstrip("/")
)
token_url = modify_url_path(
issuer_url, lambda path: path.rstrip("/") + TOKEN_PATH.lstrip("/")
)
# Create metadata
metadata = OAuthMetadata(
issuer=issuer_url,
authorization_endpoint=authorization_url,
token_endpoint=token_url,
scopes_supported=None,
response_types_supported=["code"],
response_modes_supported=None,
grant_types_supported=["authorization_code", "refresh_token"],
token_endpoint_auth_methods_supported=["client_secret_post"],
token_endpoint_auth_signing_alg_values_supported=None,
service_documentation=service_documentation_url,
ui_locales_supported=None,
op_policy_uri=None,
op_tos_uri=None,
introspection_endpoint=None,
code_challenge_methods_supported=["S256"],
)
# Add registration endpoint if supported
if client_registration_options.enabled:
metadata.registration_endpoint = modify_url_path(
issuer_url, lambda path: path.rstrip("/") + REGISTRATION_PATH.lstrip("/")
)
# Add revocation endpoint if supported
if revocation_options.enabled:
metadata.revocation_endpoint = modify_url_path(
issuer_url, lambda path: path.rstrip("/") + REVOCATION_PATH.lstrip("/")
)
metadata.revocation_endpoint_auth_methods_supported = ["client_secret_post"]
return metadata

View File

@@ -0,0 +1,24 @@
from pydantic import AnyHttpUrl, BaseModel, Field
class ClientRegistrationOptions(BaseModel):
enabled: bool = False
client_secret_expiry_seconds: int | None = None
valid_scopes: list[str] | None = None
default_scopes: list[str] | None = None
class RevocationOptions(BaseModel):
enabled: bool = False
class AuthSettings(BaseModel):
issuer_url: AnyHttpUrl = Field(
...,
description="URL advertised as OAuth issuer; this should be the URL the server "
"is reachable at",
)
service_documentation_url: AnyHttpUrl | None = None
client_registration_options: ClientRegistrationOptions | None = None
revocation_options: RevocationOptions | None = None
required_scopes: list[str] | None = None

View File

@@ -4,7 +4,7 @@ from __future__ import annotations as _annotations
import inspect import inspect
import re import re
from collections.abc import AsyncIterator, Callable, Iterable, Sequence from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence
from contextlib import ( from contextlib import (
AbstractAsyncContextManager, AbstractAsyncContextManager,
asynccontextmanager, asynccontextmanager,
@@ -18,9 +18,22 @@ from pydantic import BaseModel, Field
from pydantic.networks import AnyUrl from pydantic.networks import AnyUrl
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
from starlette.applications import Starlette from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Mount, Route from starlette.routing import Mount, Route
from starlette.types import Receive, Scope, Send
from mcp.server.auth.middleware.auth_context import AuthContextMiddleware
from mcp.server.auth.middleware.bearer_auth import (
BearerAuthBackend,
RequireAuthMiddleware,
)
from mcp.server.auth.provider import OAuthAuthorizationServerProvider
from mcp.server.auth.settings import (
AuthSettings,
)
from mcp.server.fastmcp.exceptions import ResourceError from mcp.server.fastmcp.exceptions import ResourceError
from mcp.server.fastmcp.prompts import Prompt, PromptManager from mcp.server.fastmcp.prompts import Prompt, PromptManager
from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager
@@ -62,6 +75,8 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
model_config = SettingsConfigDict( model_config = SettingsConfigDict(
env_prefix="FASTMCP_", env_prefix="FASTMCP_",
env_file=".env", env_file=".env",
env_nested_delimiter="__",
nested_model_default_partial_update=True,
extra="ignore", extra="ignore",
) )
@@ -93,6 +108,8 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]] | None Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]] | None
) = Field(None, description="Lifespan context manager") ) = Field(None, description="Lifespan context manager")
auth: AuthSettings | None = None
def lifespan_wrapper( def lifespan_wrapper(
app: FastMCP, app: FastMCP,
@@ -108,7 +125,12 @@ def lifespan_wrapper(
class FastMCP: class FastMCP:
def __init__( def __init__(
self, name: str | None = None, instructions: str | None = None, **settings: Any self,
name: str | None = None,
instructions: str | None = None,
auth_server_provider: OAuthAuthorizationServerProvider[Any, Any, Any]
| None = None,
**settings: Any,
): ):
self.settings = Settings(**settings) self.settings = Settings(**settings)
@@ -128,6 +150,18 @@ class FastMCP:
self._prompt_manager = PromptManager( self._prompt_manager = PromptManager(
warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts
) )
if (self.settings.auth is not None) != (auth_server_provider is not None):
# TODO: after we support separate authorization servers (see
# https://github.com/modelcontextprotocol/modelcontextprotocol/pull/284)
# we should validate that if auth is enabled, we have either an
# auth_server_provider to host our own authorization server,
# OR the URL of a 3rd party authorization server.
raise ValueError(
"settings.auth must be specified if and only if auth_server_provider "
"is specified"
)
self._auth_server_provider = auth_server_provider
self._custom_starlette_routes: list[Route] = []
self.dependencies = self.settings.dependencies self.dependencies = self.settings.dependencies
# Set up MCP protocol handlers # Set up MCP protocol handlers
@@ -465,6 +499,50 @@ class FastMCP:
return decorator return decorator
def custom_route(
self,
path: str,
methods: list[str],
name: str | None = None,
include_in_schema: bool = True,
):
"""
Decorator to register a custom HTTP route on the FastMCP server.
Allows adding arbitrary HTTP endpoints outside the standard MCP protocol,
which can be useful for OAuth callbacks, health checks, or admin APIs.
The handler function must be an async function that accepts a Starlette
Request and returns a Response.
Args:
path: URL path for the route (e.g., "/oauth/callback")
methods: List of HTTP methods to support (e.g., ["GET", "POST"])
name: Optional name for the route (to reference this route with
Starlette's reverse URL lookup feature)
include_in_schema: Whether to include in OpenAPI schema, defaults to True
Example:
@server.custom_route("/health", methods=["GET"])
async def health_check(request: Request) -> Response:
return JSONResponse({"status": "ok"})
"""
def decorator(
func: Callable[[Request], Awaitable[Response]],
) -> Callable[[Request], Awaitable[Response]]:
self._custom_starlette_routes.append(
Route(
path,
endpoint=func,
methods=methods,
name=name,
include_in_schema=include_in_schema,
)
)
return func
return decorator
async def run_stdio_async(self) -> None: async def run_stdio_async(self) -> None:
"""Run the server using stdio transport.""" """Run the server using stdio transport."""
async with stdio_server() as (read_stream, write_stream): async with stdio_server() as (read_stream, write_stream):
@@ -491,13 +569,20 @@ class FastMCP:
def sse_app(self) -> Starlette: def sse_app(self) -> Starlette:
"""Return an instance of the SSE server app.""" """Return an instance of the SSE server app."""
from starlette.middleware import Middleware
from starlette.routing import Mount, Route
# Set up auth context and dependencies
sse = SseServerTransport(self.settings.message_path) sse = SseServerTransport(self.settings.message_path)
async def handle_sse(request: Request) -> None: async def handle_sse(scope: Scope, receive: Receive, send: Send):
# Add client ID from auth context into request context if available
async with sse.connect_sse( async with sse.connect_sse(
request.scope, scope,
request.receive, receive,
request._send, # type: ignore[reportPrivateUsage] send,
) as streams: ) as streams:
await self._mcp_server.run( await self._mcp_server.run(
streams[0], streams[0],
@@ -505,12 +590,59 @@ class FastMCP:
self._mcp_server.create_initialization_options(), self._mcp_server.create_initialization_options(),
) )
# Create routes
routes: list[Route | Mount] = []
middleware: list[Middleware] = []
required_scopes = []
# Add auth endpoints if auth provider is configured
if self._auth_server_provider:
assert self.settings.auth
from mcp.server.auth.routes import create_auth_routes
required_scopes = self.settings.auth.required_scopes or []
middleware = [
# extract auth info from request (but do not require it)
Middleware(
AuthenticationMiddleware,
backend=BearerAuthBackend(
provider=self._auth_server_provider,
),
),
# Add the auth context middleware to store
# authenticated user in a contextvar
Middleware(AuthContextMiddleware),
]
routes.extend(
create_auth_routes(
provider=self._auth_server_provider,
issuer_url=self.settings.auth.issuer_url,
service_documentation_url=self.settings.auth.service_documentation_url,
client_registration_options=self.settings.auth.client_registration_options,
revocation_options=self.settings.auth.revocation_options,
)
)
routes.append(
Route(
self.settings.sse_path,
endpoint=RequireAuthMiddleware(handle_sse, required_scopes),
methods=["GET"],
)
)
routes.append(
Mount(
self.settings.message_path,
app=RequireAuthMiddleware(sse.handle_post_message, required_scopes),
)
)
# mount these routes last, so they have the lowest route matching precedence
routes.extend(self._custom_starlette_routes)
# Create Starlette app with routes and middleware
return Starlette( return Starlette(
debug=self.settings.debug, debug=self.settings.debug, routes=routes, middleware=middleware
routes=[
Route(self.settings.sse_path, endpoint=handle_sse),
Mount(self.settings.message_path, app=sse.handle_post_message),
],
) )
async def list_prompts(self) -> list[MCPPrompt]: async def list_prompts(self) -> list[MCPPrompt]:

View File

@@ -576,14 +576,12 @@ class Server(Generic[LifespanResultT]):
assert type(notify) in self.notification_handlers assert type(notify) in self.notification_handlers
handler = self.notification_handlers[type(notify)] handler = self.notification_handlers[type(notify)]
logger.debug( logger.debug(f"Dispatching notification of type {type(notify).__name__}")
f"Dispatching notification of type " f"{type(notify).__name__}"
)
try: try:
await handler(notify) await handler(notify)
except Exception as err: except Exception as err:
logger.error(f"Uncaught exception in notification handler: " f"{err}") logger.error(f"Uncaught exception in notification handler: {err}")
async def _ping_handler(request: types.PingRequest) -> types.ServerResult: async def _ping_handler(request: types.PingRequest) -> types.ServerResult:

View File

@@ -0,0 +1,213 @@
"""
A modified version of httpx.ASGITransport that supports streaming responses.
This transport runs the ASGI app as a separate anyio task, allowing it to
handle streaming responses like SSE where the app doesn't terminate until
the connection is closed.
This is only intended for writing tests for the SSE transport.
"""
import typing
from typing import Any, cast
import anyio
import anyio.abc
import anyio.streams.memory
from httpx._models import Request, Response
from httpx._transports.base import AsyncBaseTransport
from httpx._types import AsyncByteStream
from starlette.types import ASGIApp, Receive, Scope, Send
class StreamingASGITransport(AsyncBaseTransport):
"""
A custom AsyncTransport that handles sending requests directly to an ASGI app
and supports streaming responses like SSE.
Unlike the standard ASGITransport, this transport runs the ASGI app in a
separate anyio task, allowing it to handle responses from apps that don't
terminate immediately (like SSE endpoints).
Arguments:
* `app` - The ASGI application.
* `raise_app_exceptions` - Boolean indicating if exceptions in the application
should be raised. Default to `True`. Can be set to `False` for use cases
such as testing the content of a client 500 response.
* `root_path` - The root path on which the ASGI application should be mounted.
* `client` - A two-tuple indicating the client IP and port of incoming requests.
* `response_timeout` - Timeout in seconds to wait for the initial response.
Default is 10 seconds.
TODO: https://github.com/encode/httpx/pull/3059 is adding something similar to
upstream httpx. When that merges, we should delete this & switch back to the
upstream implementation.
"""
def __init__(
self,
app: ASGIApp,
task_group: anyio.abc.TaskGroup,
raise_app_exceptions: bool = True,
root_path: str = "",
client: tuple[str, int] = ("127.0.0.1", 123),
) -> None:
self.app = app
self.raise_app_exceptions = raise_app_exceptions
self.root_path = root_path
self.client = client
self.task_group = task_group
async def handle_async_request(
self,
request: Request,
) -> Response:
assert isinstance(request.stream, AsyncByteStream)
# ASGI scope.
scope = {
"type": "http",
"asgi": {"version": "3.0"},
"http_version": "1.1",
"method": request.method,
"headers": [(k.lower(), v) for (k, v) in request.headers.raw],
"scheme": request.url.scheme,
"path": request.url.path,
"raw_path": request.url.raw_path.split(b"?")[0],
"query_string": request.url.query,
"server": (request.url.host, request.url.port),
"client": self.client,
"root_path": self.root_path,
}
# Request body
request_body_chunks = request.stream.__aiter__()
request_complete = False
# Response state
status_code = 499
response_headers = None
response_started = False
response_complete = anyio.Event()
initial_response_ready = anyio.Event()
# Synchronization for streaming response
asgi_send_channel, asgi_receive_channel = anyio.create_memory_object_stream[
dict[str, Any]
](100)
content_send_channel, content_receive_channel = (
anyio.create_memory_object_stream[bytes](100)
)
# ASGI callables.
async def receive() -> dict[str, Any]:
nonlocal request_complete
if request_complete:
await response_complete.wait()
return {"type": "http.disconnect"}
try:
body = await request_body_chunks.__anext__()
except StopAsyncIteration:
request_complete = True
return {"type": "http.request", "body": b"", "more_body": False}
return {"type": "http.request", "body": body, "more_body": True}
async def send(message: dict[str, Any]) -> None:
nonlocal status_code, response_headers, response_started
await asgi_send_channel.send(message)
# Start the ASGI application in a separate task
async def run_app() -> None:
try:
# Cast the receive and send functions to the ASGI types
await self.app(
cast(Scope, scope), cast(Receive, receive), cast(Send, send)
)
except Exception:
if self.raise_app_exceptions:
raise
if not response_started:
await asgi_send_channel.send(
{"type": "http.response.start", "status": 500, "headers": []}
)
await asgi_send_channel.send(
{"type": "http.response.body", "body": b"", "more_body": False}
)
finally:
await asgi_send_channel.aclose()
# Process messages from the ASGI app
async def process_messages() -> None:
nonlocal status_code, response_headers, response_started
try:
async with asgi_receive_channel:
async for message in asgi_receive_channel:
if message["type"] == "http.response.start":
assert not response_started
status_code = message["status"]
response_headers = message.get("headers", [])
response_started = True
# As soon as we have headers, we can return a response
initial_response_ready.set()
elif message["type"] == "http.response.body":
body = message.get("body", b"")
more_body = message.get("more_body", False)
if body and request.method != "HEAD":
await content_send_channel.send(body)
if not more_body:
response_complete.set()
await content_send_channel.aclose()
break
finally:
# Ensure events are set even if there's an error
initial_response_ready.set()
response_complete.set()
await content_send_channel.aclose()
# Create tasks for running the app and processing messages
self.task_group.start_soon(run_app)
self.task_group.start_soon(process_messages)
# Wait for the initial response or timeout
await initial_response_ready.wait()
# Create a streaming response
return Response(
status_code,
headers=response_headers,
stream=StreamingASGIResponseStream(content_receive_channel),
)
class StreamingASGIResponseStream(AsyncByteStream):
"""
A modified ASGIResponseStream that supports streaming responses.
This class extends the standard ASGIResponseStream to handle cases where
the response body continues to be generated after the initial response
is returned.
"""
def __init__(
self,
receive_channel: anyio.streams.memory.MemoryObjectReceiveStream[bytes],
) -> None:
self.receive_channel = receive_channel
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
try:
async for chunk in self.receive_channel:
yield chunk
finally:
await self.receive_channel.aclose()

137
src/mcp/shared/auth.py Normal file
View File

@@ -0,0 +1,137 @@
from typing import Any, Literal
from pydantic import AnyHttpUrl, BaseModel, Field
class OAuthToken(BaseModel):
"""
See https://datatracker.ietf.org/doc/html/rfc6749#section-5.1
"""
access_token: str
token_type: Literal["bearer"] = "bearer"
expires_in: int | None = None
scope: str | None = None
refresh_token: str | None = None
class InvalidScopeError(Exception):
def __init__(self, message: str):
self.message = message
class InvalidRedirectUriError(Exception):
def __init__(self, message: str):
self.message = message
class OAuthClientMetadata(BaseModel):
"""
RFC 7591 OAuth 2.0 Dynamic Client Registration metadata.
See https://datatracker.ietf.org/doc/html/rfc7591#section-2
for the full specification.
"""
redirect_uris: list[AnyHttpUrl] = Field(..., min_length=1)
# token_endpoint_auth_method: this implementation only supports none &
# client_secret_post;
# ie: we do not support client_secret_basic
token_endpoint_auth_method: Literal["none", "client_secret_post"] = (
"client_secret_post"
)
# grant_types: this implementation only supports authorization_code & refresh_token
grant_types: list[Literal["authorization_code", "refresh_token"]] = [
"authorization_code",
"refresh_token",
]
# this implementation only supports code; ie: it does not support implicit grants
response_types: list[Literal["code"]] = ["code"]
scope: str | None = None
# these fields are currently unused, but we support & store them for potential
# future use
client_name: str | None = None
client_uri: AnyHttpUrl | None = None
logo_uri: AnyHttpUrl | None = None
contacts: list[str] | None = None
tos_uri: AnyHttpUrl | None = None
policy_uri: AnyHttpUrl | None = None
jwks_uri: AnyHttpUrl | None = None
jwks: Any | None = None
software_id: str | None = None
software_version: str | None = None
def validate_scope(self, requested_scope: str | None) -> list[str] | None:
if requested_scope is None:
return None
requested_scopes = requested_scope.split(" ")
allowed_scopes = [] if self.scope is None else self.scope.split(" ")
for scope in requested_scopes:
if scope not in allowed_scopes:
raise InvalidScopeError(f"Client was not registered with scope {scope}")
return requested_scopes
def validate_redirect_uri(self, redirect_uri: AnyHttpUrl | None) -> AnyHttpUrl:
if redirect_uri is not None:
# Validate redirect_uri against client's registered redirect URIs
if redirect_uri not in self.redirect_uris:
raise InvalidRedirectUriError(
f"Redirect URI '{redirect_uri}' not registered for client"
)
return redirect_uri
elif len(self.redirect_uris) == 1:
return self.redirect_uris[0]
else:
raise InvalidRedirectUriError(
"redirect_uri must be specified when client "
"has multiple registered URIs"
)
class OAuthClientInformationFull(OAuthClientMetadata):
"""
RFC 7591 OAuth 2.0 Dynamic Client Registration full response
(client information plus metadata).
"""
client_id: str
client_secret: str | None = None
client_id_issued_at: int | None = None
client_secret_expires_at: int | None = None
class OAuthMetadata(BaseModel):
"""
RFC 8414 OAuth 2.0 Authorization Server Metadata.
See https://datatracker.ietf.org/doc/html/rfc8414#section-2
"""
issuer: AnyHttpUrl
authorization_endpoint: AnyHttpUrl
token_endpoint: AnyHttpUrl
registration_endpoint: AnyHttpUrl | None = None
scopes_supported: list[str] | None = None
response_types_supported: list[Literal["code"]] = ["code"]
response_modes_supported: list[Literal["query", "fragment"]] | None = None
grant_types_supported: (
list[Literal["authorization_code", "refresh_token"]] | None
) = None
token_endpoint_auth_methods_supported: (
list[Literal["none", "client_secret_post"]] | None
) = None
token_endpoint_auth_signing_alg_values_supported: None = None
service_documentation: AnyHttpUrl | None = None
ui_locales_supported: list[str] | None = None
op_policy_uri: AnyHttpUrl | None = None
op_tos_uri: AnyHttpUrl | None = None
revocation_endpoint: AnyHttpUrl | None = None
revocation_endpoint_auth_methods_supported: (
list[Literal["client_secret_post"]] | None
) = None
revocation_endpoint_auth_signing_alg_values_supported: None = None
introspection_endpoint: AnyHttpUrl | None = None
introspection_endpoint_auth_methods_supported: (
list[Literal["client_secret_post"]] | None
) = None
introspection_endpoint_auth_signing_alg_values_supported: None = None
code_challenge_methods_supported: list[Literal["S256"]] | None = None

View File

@@ -0,0 +1,122 @@
"""
Tests for the AuthContext middleware components.
"""
import time
import pytest
from starlette.types import Message, Receive, Scope, Send
from mcp.server.auth.middleware.auth_context import (
AuthContextMiddleware,
auth_context_var,
get_access_token,
)
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser
from mcp.server.auth.provider import AccessToken
class MockApp:
"""Mock ASGI app for testing."""
def __init__(self):
self.called = False
self.scope: Scope | None = None
self.receive: Receive | None = None
self.send: Send | None = None
self.access_token_during_call: AccessToken | None = None
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
self.called = True
self.scope = scope
self.receive = receive
self.send = send
# Check the context during the call
self.access_token_during_call = get_access_token()
@pytest.fixture
def valid_access_token() -> AccessToken:
"""Create a valid access token."""
return AccessToken(
token="valid_token",
client_id="test_client",
scopes=["read", "write"],
expires_at=int(time.time()) + 3600, # 1 hour from now
)
@pytest.mark.anyio
class TestAuthContextMiddleware:
"""Tests for the AuthContextMiddleware class."""
async def test_with_authenticated_user(self, valid_access_token: AccessToken):
"""Test middleware with an authenticated user in scope."""
app = MockApp()
middleware = AuthContextMiddleware(app)
# Create an authenticated user
user = AuthenticatedUser(valid_access_token)
scope: Scope = {"type": "http", "user": user}
# Create dummy async functions for receive and send
async def receive() -> Message:
return {"type": "http.request"}
async def send(message: Message) -> None:
pass
# Verify context is empty before middleware
assert auth_context_var.get() is None
assert get_access_token() is None
# Run the middleware
await middleware(scope, receive, send)
# Verify the app was called
assert app.called
assert app.scope == scope
assert app.receive == receive
assert app.send == send
# Verify the access token was available during the call
assert app.access_token_during_call == valid_access_token
# Verify context is reset after middleware
assert auth_context_var.get() is None
assert get_access_token() is None
async def test_with_no_user(self):
"""Test middleware with no user in scope."""
app = MockApp()
middleware = AuthContextMiddleware(app)
scope: Scope = {"type": "http"} # No user
# Create dummy async functions for receive and send
async def receive() -> Message:
return {"type": "http.request"}
async def send(message: Message) -> None:
pass
# Verify context is empty before middleware
assert auth_context_var.get() is None
assert get_access_token() is None
# Run the middleware
await middleware(scope, receive, send)
# Verify the app was called
assert app.called
assert app.scope == scope
assert app.receive == receive
assert app.send == send
# Verify the access token was not available during the call
assert app.access_token_during_call is None
# Verify context is still empty after middleware
assert auth_context_var.get() is None
assert get_access_token() is None

View File

@@ -0,0 +1,391 @@
"""
Tests for the BearerAuth middleware components.
"""
import time
from typing import Any, cast
import pytest
from starlette.authentication import AuthCredentials
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.types import Message, Receive, Scope, Send
from mcp.server.auth.middleware.bearer_auth import (
AuthenticatedUser,
BearerAuthBackend,
RequireAuthMiddleware,
)
from mcp.server.auth.provider import (
AccessToken,
OAuthAuthorizationServerProvider,
)
class MockOAuthProvider:
"""Mock OAuth provider for testing.
This is a simplified version that only implements the methods needed for testing
the BearerAuthMiddleware components.
"""
def __init__(self):
self.tokens = {} # token -> AccessToken
def add_token(self, token: str, access_token: AccessToken) -> None:
"""Add a token to the provider."""
self.tokens[token] = access_token
async def load_access_token(self, token: str) -> AccessToken | None:
"""Load an access token."""
return self.tokens.get(token)
def add_token_to_provider(
provider: OAuthAuthorizationServerProvider[Any, Any, Any],
token: str,
access_token: AccessToken,
) -> None:
"""Helper function to add a token to a provider.
This is used to work around type checking issues with our mock provider.
"""
# We know this is actually a MockOAuthProvider
mock_provider = cast(MockOAuthProvider, provider)
mock_provider.add_token(token, access_token)
class MockApp:
"""Mock ASGI app for testing."""
def __init__(self):
self.called = False
self.scope: Scope | None = None
self.receive: Receive | None = None
self.send: Send | None = None
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
self.called = True
self.scope = scope
self.receive = receive
self.send = send
@pytest.fixture
def mock_oauth_provider() -> OAuthAuthorizationServerProvider[Any, Any, Any]:
"""Create a mock OAuth provider."""
# Use type casting to satisfy the type checker
return cast(OAuthAuthorizationServerProvider[Any, Any, Any], MockOAuthProvider())
@pytest.fixture
def valid_access_token() -> AccessToken:
"""Create a valid access token."""
return AccessToken(
token="valid_token",
client_id="test_client",
scopes=["read", "write"],
expires_at=int(time.time()) + 3600, # 1 hour from now
)
@pytest.fixture
def expired_access_token() -> AccessToken:
"""Create an expired access token."""
return AccessToken(
token="expired_token",
client_id="test_client",
scopes=["read"],
expires_at=int(time.time()) - 3600, # 1 hour ago
)
@pytest.fixture
def no_expiry_access_token() -> AccessToken:
"""Create an access token with no expiry."""
return AccessToken(
token="no_expiry_token",
client_id="test_client",
scopes=["read", "write"],
expires_at=None,
)
@pytest.mark.anyio
class TestBearerAuthBackend:
"""Tests for the BearerAuthBackend class."""
async def test_no_auth_header(
self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]
):
"""Test authentication with no Authorization header."""
backend = BearerAuthBackend(provider=mock_oauth_provider)
request = Request({"type": "http", "headers": []})
result = await backend.authenticate(request)
assert result is None
async def test_non_bearer_auth_header(
self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]
):
"""Test authentication with non-Bearer Authorization header."""
backend = BearerAuthBackend(provider=mock_oauth_provider)
request = Request(
{
"type": "http",
"headers": [(b"authorization", b"Basic dXNlcjpwYXNz")],
}
)
result = await backend.authenticate(request)
assert result is None
async def test_invalid_token(
self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]
):
"""Test authentication with invalid token."""
backend = BearerAuthBackend(provider=mock_oauth_provider)
request = Request(
{
"type": "http",
"headers": [(b"authorization", b"Bearer invalid_token")],
}
)
result = await backend.authenticate(request)
assert result is None
async def test_expired_token(
self,
mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any],
expired_access_token: AccessToken,
):
"""Test authentication with expired token."""
backend = BearerAuthBackend(provider=mock_oauth_provider)
add_token_to_provider(
mock_oauth_provider, "expired_token", expired_access_token
)
request = Request(
{
"type": "http",
"headers": [(b"authorization", b"Bearer expired_token")],
}
)
result = await backend.authenticate(request)
assert result is None
async def test_valid_token(
self,
mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any],
valid_access_token: AccessToken,
):
"""Test authentication with valid token."""
backend = BearerAuthBackend(provider=mock_oauth_provider)
add_token_to_provider(mock_oauth_provider, "valid_token", valid_access_token)
request = Request(
{
"type": "http",
"headers": [(b"authorization", b"Bearer valid_token")],
}
)
result = await backend.authenticate(request)
assert result is not None
credentials, user = result
assert isinstance(credentials, AuthCredentials)
assert isinstance(user, AuthenticatedUser)
assert credentials.scopes == ["read", "write"]
assert user.display_name == "test_client"
assert user.access_token == valid_access_token
assert user.scopes == ["read", "write"]
async def test_token_without_expiry(
self,
mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any],
no_expiry_access_token: AccessToken,
):
"""Test authentication with token that has no expiry."""
backend = BearerAuthBackend(provider=mock_oauth_provider)
add_token_to_provider(
mock_oauth_provider, "no_expiry_token", no_expiry_access_token
)
request = Request(
{
"type": "http",
"headers": [(b"authorization", b"Bearer no_expiry_token")],
}
)
result = await backend.authenticate(request)
assert result is not None
credentials, user = result
assert isinstance(credentials, AuthCredentials)
assert isinstance(user, AuthenticatedUser)
assert credentials.scopes == ["read", "write"]
assert user.display_name == "test_client"
assert user.access_token == no_expiry_access_token
assert user.scopes == ["read", "write"]
@pytest.mark.anyio
class TestRequireAuthMiddleware:
"""Tests for the RequireAuthMiddleware class."""
async def test_no_user(self):
"""Test middleware with no user in scope."""
app = MockApp()
middleware = RequireAuthMiddleware(app, required_scopes=["read"])
scope: Scope = {"type": "http"}
# Create dummy async functions for receive and send
async def receive() -> Message:
return {"type": "http.request"}
async def send(message: Message) -> None:
pass
with pytest.raises(HTTPException) as excinfo:
await middleware(scope, receive, send)
assert excinfo.value.status_code == 401
assert excinfo.value.detail == "Unauthorized"
assert not app.called
async def test_non_authenticated_user(self):
"""Test middleware with non-authenticated user in scope."""
app = MockApp()
middleware = RequireAuthMiddleware(app, required_scopes=["read"])
scope: Scope = {"type": "http", "user": object()}
# Create dummy async functions for receive and send
async def receive() -> Message:
return {"type": "http.request"}
async def send(message: Message) -> None:
pass
with pytest.raises(HTTPException) as excinfo:
await middleware(scope, receive, send)
assert excinfo.value.status_code == 401
assert excinfo.value.detail == "Unauthorized"
assert not app.called
async def test_missing_required_scope(self, valid_access_token: AccessToken):
"""Test middleware with user missing required scope."""
app = MockApp()
middleware = RequireAuthMiddleware(app, required_scopes=["admin"])
# Create a user with read/write scopes but not admin
user = AuthenticatedUser(valid_access_token)
auth = AuthCredentials(["read", "write"])
scope: Scope = {"type": "http", "user": user, "auth": auth}
# Create dummy async functions for receive and send
async def receive() -> Message:
return {"type": "http.request"}
async def send(message: Message) -> None:
pass
with pytest.raises(HTTPException) as excinfo:
await middleware(scope, receive, send)
assert excinfo.value.status_code == 403
assert excinfo.value.detail == "Insufficient scope"
assert not app.called
async def test_no_auth_credentials(self, valid_access_token: AccessToken):
"""Test middleware with no auth credentials in scope."""
app = MockApp()
middleware = RequireAuthMiddleware(app, required_scopes=["read"])
# Create a user with read/write scopes
user = AuthenticatedUser(valid_access_token)
scope: Scope = {"type": "http", "user": user} # No auth credentials
# Create dummy async functions for receive and send
async def receive() -> Message:
return {"type": "http.request"}
async def send(message: Message) -> None:
pass
with pytest.raises(HTTPException) as excinfo:
await middleware(scope, receive, send)
assert excinfo.value.status_code == 403
assert excinfo.value.detail == "Insufficient scope"
assert not app.called
async def test_has_required_scopes(self, valid_access_token: AccessToken):
"""Test middleware with user having all required scopes."""
app = MockApp()
middleware = RequireAuthMiddleware(app, required_scopes=["read"])
# Create a user with read/write scopes
user = AuthenticatedUser(valid_access_token)
auth = AuthCredentials(["read", "write"])
scope: Scope = {"type": "http", "user": user, "auth": auth}
# Create dummy async functions for receive and send
async def receive() -> Message:
return {"type": "http.request"}
async def send(message: Message) -> None:
pass
await middleware(scope, receive, send)
assert app.called
assert app.scope == scope
assert app.receive == receive
assert app.send == send
async def test_multiple_required_scopes(self, valid_access_token: AccessToken):
"""Test middleware with multiple required scopes."""
app = MockApp()
middleware = RequireAuthMiddleware(app, required_scopes=["read", "write"])
# Create a user with read/write scopes
user = AuthenticatedUser(valid_access_token)
auth = AuthCredentials(["read", "write"])
scope: Scope = {"type": "http", "user": user, "auth": auth}
# Create dummy async functions for receive and send
async def receive() -> Message:
return {"type": "http.request"}
async def send(message: Message) -> None:
pass
await middleware(scope, receive, send)
assert app.called
assert app.scope == scope
assert app.receive == receive
assert app.send == send
async def test_no_required_scopes(self, valid_access_token: AccessToken):
"""Test middleware with no required scopes."""
app = MockApp()
middleware = RequireAuthMiddleware(app, required_scopes=[])
# Create a user with read/write scopes
user = AuthenticatedUser(valid_access_token)
auth = AuthCredentials(["read", "write"])
scope: Scope = {"type": "http", "user": user, "auth": auth}
# Create dummy async functions for receive and send
async def receive() -> Message:
return {"type": "http.request"}
async def send(message: Message) -> None:
pass
await middleware(scope, receive, send)
assert app.called
assert app.scope == scope
assert app.receive == receive
assert app.send == send

View File

@@ -0,0 +1,294 @@
"""
Tests for OAuth error handling in the auth handlers.
"""
import unittest.mock
from urllib.parse import parse_qs, urlparse
import httpx
import pytest
from httpx import ASGITransport
from pydantic import AnyHttpUrl
from starlette.applications import Starlette
from mcp.server.auth.provider import (
AuthorizeError,
RegistrationError,
TokenError,
)
from mcp.server.auth.routes import create_auth_routes
from tests.server.fastmcp.auth.test_auth_integration import (
MockOAuthProvider,
)
@pytest.fixture
def oauth_provider():
"""Return a MockOAuthProvider instance that can be configured to raise errors."""
return MockOAuthProvider()
@pytest.fixture
def app(oauth_provider):
from mcp.server.auth.settings import ClientRegistrationOptions, RevocationOptions
# Enable client registration
client_registration_options = ClientRegistrationOptions(enabled=True)
revocation_options = RevocationOptions(enabled=True)
# Create auth routes
auth_routes = create_auth_routes(
oauth_provider,
issuer_url=AnyHttpUrl("http://localhost"),
client_registration_options=client_registration_options,
revocation_options=revocation_options,
)
# Create Starlette app with routes directly
return Starlette(routes=auth_routes)
@pytest.fixture
def client(app):
transport = ASGITransport(app=app)
# Use base_url without a path since routes are directly on the app
return httpx.AsyncClient(transport=transport, base_url="http://localhost")
@pytest.fixture
def pkce_challenge():
"""Create a PKCE challenge with code_verifier and code_challenge."""
import base64
import hashlib
import secrets
# Generate a code verifier
code_verifier = secrets.token_urlsafe(64)[:128]
# Create code challenge using S256 method
code_verifier_bytes = code_verifier.encode("ascii")
sha256 = hashlib.sha256(code_verifier_bytes).digest()
code_challenge = base64.urlsafe_b64encode(sha256).decode().rstrip("=")
return {"code_verifier": code_verifier, "code_challenge": code_challenge}
@pytest.fixture
async def registered_client(client):
"""Create and register a test client."""
# Default client metadata
client_metadata = {
"redirect_uris": ["https://client.example.com/callback"],
"token_endpoint_auth_method": "client_secret_post",
"grant_types": ["authorization_code", "refresh_token"],
"response_types": ["code"],
"client_name": "Test Client",
}
response = await client.post("/register", json=client_metadata)
assert response.status_code == 201, f"Failed to register client: {response.content}"
client_info = response.json()
return client_info
class TestRegistrationErrorHandling:
@pytest.mark.anyio
async def test_registration_error_handling(self, client, oauth_provider):
# Mock the register_client method to raise a registration error
with unittest.mock.patch.object(
oauth_provider,
"register_client",
side_effect=RegistrationError(
error="invalid_redirect_uri",
error_description="The redirect URI is invalid",
),
):
# Prepare a client registration request
client_data = {
"redirect_uris": ["https://client.example.com/callback"],
"token_endpoint_auth_method": "client_secret_post",
"grant_types": ["authorization_code", "refresh_token"],
"response_types": ["code"],
"client_name": "Test Client",
}
# Send the registration request
response = await client.post(
"/register",
json=client_data,
)
# Verify the response
assert response.status_code == 400, response.content
data = response.json()
assert data["error"] == "invalid_redirect_uri"
assert data["error_description"] == "The redirect URI is invalid"
class TestAuthorizeErrorHandling:
@pytest.mark.anyio
async def test_authorize_error_handling(
self, client, oauth_provider, registered_client, pkce_challenge
):
# Mock the authorize method to raise an authorize error
with unittest.mock.patch.object(
oauth_provider,
"authorize",
side_effect=AuthorizeError(
error="access_denied", error_description="The user denied the request"
),
):
# Register the client
client_id = registered_client["client_id"]
redirect_uri = registered_client["redirect_uris"][0]
# Prepare an authorization request
params = {
"client_id": client_id,
"redirect_uri": redirect_uri,
"response_type": "code",
"code_challenge": pkce_challenge["code_challenge"],
"code_challenge_method": "S256",
"state": "test_state",
}
# Send the authorization request
response = await client.get("/authorize", params=params)
# Verify the response is a redirect with error parameters
assert response.status_code == 302
redirect_url = response.headers["location"]
parsed_url = urlparse(redirect_url)
query_params = parse_qs(parsed_url.query)
assert query_params["error"][0] == "access_denied"
assert "error_description" in query_params
assert query_params["state"][0] == "test_state"
class TestTokenErrorHandling:
@pytest.mark.anyio
async def test_token_error_handling_auth_code(
self, client, oauth_provider, registered_client, pkce_challenge
):
# Register the client and get an auth code
client_id = registered_client["client_id"]
client_secret = registered_client["client_secret"]
redirect_uri = registered_client["redirect_uris"][0]
# First get an authorization code
auth_response = await client.get(
"/authorize",
params={
"client_id": client_id,
"redirect_uri": redirect_uri,
"response_type": "code",
"code_challenge": pkce_challenge["code_challenge"],
"code_challenge_method": "S256",
"state": "test_state",
},
)
redirect_url = auth_response.headers["location"]
parsed_url = urlparse(redirect_url)
query_params = parse_qs(parsed_url.query)
code = query_params["code"][0]
# Mock the exchange_authorization_code method to raise a token error
with unittest.mock.patch.object(
oauth_provider,
"exchange_authorization_code",
side_effect=TokenError(
error="invalid_grant",
error_description="The authorization code is invalid",
),
):
# Try to exchange the code for tokens
token_response = await client.post(
"/token",
data={
"grant_type": "authorization_code",
"code": code,
"redirect_uri": redirect_uri,
"client_id": client_id,
"client_secret": client_secret,
"code_verifier": pkce_challenge["code_verifier"],
},
)
# Verify the response
assert token_response.status_code == 400
data = token_response.json()
assert data["error"] == "invalid_grant"
assert data["error_description"] == "The authorization code is invalid"
@pytest.mark.anyio
async def test_token_error_handling_refresh_token(
self, client, oauth_provider, registered_client, pkce_challenge
):
# Register the client and get tokens
client_id = registered_client["client_id"]
client_secret = registered_client["client_secret"]
redirect_uri = registered_client["redirect_uris"][0]
# First get an authorization code
auth_response = await client.get(
"/authorize",
params={
"client_id": client_id,
"redirect_uri": redirect_uri,
"response_type": "code",
"code_challenge": pkce_challenge["code_challenge"],
"code_challenge_method": "S256",
"state": "test_state",
},
)
assert auth_response.status_code == 302, auth_response.content
redirect_url = auth_response.headers["location"]
parsed_url = urlparse(redirect_url)
query_params = parse_qs(parsed_url.query)
code = query_params["code"][0]
# Exchange the code for tokens
token_response = await client.post(
"/token",
data={
"grant_type": "authorization_code",
"code": code,
"redirect_uri": redirect_uri,
"client_id": client_id,
"client_secret": client_secret,
"code_verifier": pkce_challenge["code_verifier"],
},
)
tokens = token_response.json()
refresh_token = tokens["refresh_token"]
# Mock the exchange_refresh_token method to raise a token error
with unittest.mock.patch.object(
oauth_provider,
"exchange_refresh_token",
side_effect=TokenError(
error="invalid_scope",
error_description="The requested scope is invalid",
),
):
# Try to use the refresh token
refresh_response = await client.post(
"/token",
data={
"grant_type": "refresh_token",
"refresh_token": refresh_token,
"client_id": client_id,
"client_secret": client_secret,
},
)
# Verify the response
assert refresh_response.status_code == 400
data = refresh_response.json()
assert data["error"] == "invalid_scope"
assert data["error_description"] == "The requested scope is invalid"

View File

@@ -0,0 +1,3 @@
"""
Tests for the MCP server auth components.
"""

File diff suppressed because it is too large Load Diff

11
uv.lock generated
View File

@@ -494,6 +494,7 @@ dependencies = [
{ name = "httpx-sse" }, { name = "httpx-sse" },
{ name = "pydantic" }, { name = "pydantic" },
{ name = "pydantic-settings" }, { name = "pydantic-settings" },
{ name = "python-multipart" },
{ name = "sse-starlette" }, { name = "sse-starlette" },
{ name = "starlette" }, { name = "starlette" },
{ name = "uvicorn", marker = "sys_platform != 'emscripten'" }, { name = "uvicorn", marker = "sys_platform != 'emscripten'" },
@@ -537,6 +538,7 @@ requires-dist = [
{ name = "pydantic", specifier = ">=2.7.2,<3.0.0" }, { name = "pydantic", specifier = ">=2.7.2,<3.0.0" },
{ name = "pydantic-settings", specifier = ">=2.5.2" }, { name = "pydantic-settings", specifier = ">=2.5.2" },
{ name = "python-dotenv", marker = "extra == 'cli'", specifier = ">=1.0.0" }, { name = "python-dotenv", marker = "extra == 'cli'", specifier = ">=1.0.0" },
{ name = "python-multipart", specifier = ">=0.0.9" },
{ name = "rich", marker = "extra == 'rich'", specifier = ">=13.9.4" }, { name = "rich", marker = "extra == 'rich'", specifier = ">=13.9.4" },
{ name = "sse-starlette", specifier = ">=1.6.1" }, { name = "sse-starlette", specifier = ">=1.6.1" },
{ name = "starlette", specifier = ">=0.27" }, { name = "starlette", specifier = ">=0.27" },
@@ -1180,6 +1182,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/44/2f/62ea1c8b593f4e093cc1a7768f0d46112107e790c3e478532329e434f00b/python_dotenv-1.0.0-py3-none-any.whl", hash = "sha256:f5971a9226b701070a4bf2c38c89e5a3f0d64de8debda981d1db98583009122a", size = 19482 }, { url = "https://files.pythonhosted.org/packages/44/2f/62ea1c8b593f4e093cc1a7768f0d46112107e790c3e478532329e434f00b/python_dotenv-1.0.0-py3-none-any.whl", hash = "sha256:f5971a9226b701070a4bf2c38c89e5a3f0d64de8debda981d1db98583009122a", size = 19482 },
] ]
[[package]]
name = "python-multipart"
version = "0.0.9"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/5c/0f/9c55ac6c84c0336e22a26fa84ca6c51d58d7ac3a2d78b0dfa8748826c883/python_multipart-0.0.9.tar.gz", hash = "sha256:03f54688c663f1b7977105f021043b0793151e4cb1c1a9d4a11fc13d622c4026", size = 31516 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/3d/47/444768600d9e0ebc82f8e347775d24aef8f6348cf00e9fa0e81910814e6d/python_multipart-0.0.9-py3-none-any.whl", hash = "sha256:97ca7b8ea7b05f977dc3849c3ba99d51689822fab725c3703af7c866a0c2b215", size = 22299 },
]
[[package]] [[package]]
name = "pyyaml" name = "pyyaml"
version = "6.0.2" version = "6.0.2"