mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
Add OAuth authentication client for HTTPX (#751)
Co-authored-by: Paul Carleton <paulc@anthropic.com>
This commit is contained in:
501
src/mcp/client/auth.py
Normal file
501
src/mcp/client/auth.py
Normal file
@@ -0,0 +1,501 @@
|
||||
"""
|
||||
OAuth2 Authentication implementation for HTTPX.
|
||||
|
||||
Implements authorization code flow with PKCE and automatic token refresh.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
import string
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from typing import Protocol
|
||||
from urllib.parse import urlencode, urljoin
|
||||
|
||||
import anyio
|
||||
import httpx
|
||||
|
||||
from mcp.shared.auth import (
|
||||
OAuthClientInformationFull,
|
||||
OAuthClientMetadata,
|
||||
OAuthMetadata,
|
||||
OAuthToken,
|
||||
)
|
||||
from mcp.types import LATEST_PROTOCOL_VERSION
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TokenStorage(Protocol):
|
||||
"""Protocol for token storage implementations."""
|
||||
|
||||
async def get_tokens(self) -> OAuthToken | None:
|
||||
"""Get stored tokens."""
|
||||
...
|
||||
|
||||
async def set_tokens(self, tokens: OAuthToken) -> None:
|
||||
"""Store tokens."""
|
||||
...
|
||||
|
||||
async def get_client_info(self) -> OAuthClientInformationFull | None:
|
||||
"""Get stored client information."""
|
||||
...
|
||||
|
||||
async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
|
||||
"""Store client information."""
|
||||
...
|
||||
|
||||
|
||||
class OAuthClientProvider(httpx.Auth):
|
||||
"""
|
||||
Authentication for httpx using anyio.
|
||||
Handles OAuth flow with automatic client registration and token storage.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_url: str,
|
||||
client_metadata: OAuthClientMetadata,
|
||||
storage: TokenStorage,
|
||||
redirect_handler: Callable[[str], Awaitable[None]],
|
||||
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]],
|
||||
timeout: float = 300.0,
|
||||
):
|
||||
"""
|
||||
Initialize OAuth2 authentication.
|
||||
|
||||
Args:
|
||||
server_url: Base URL of the OAuth server
|
||||
client_metadata: OAuth client metadata
|
||||
storage: Token storage implementation (defaults to in-memory)
|
||||
redirect_handler: Function to handle authorization URL like opening browser
|
||||
callback_handler: Function to wait for callback
|
||||
and return (auth_code, state)
|
||||
timeout: Timeout for OAuth flow in seconds
|
||||
"""
|
||||
self.server_url = server_url
|
||||
self.client_metadata = client_metadata
|
||||
self.storage = storage
|
||||
self.redirect_handler = redirect_handler
|
||||
self.callback_handler = callback_handler
|
||||
self.timeout = timeout
|
||||
|
||||
# Cached authentication state
|
||||
self._current_tokens: OAuthToken | None = None
|
||||
self._metadata: OAuthMetadata | None = None
|
||||
self._client_info: OAuthClientInformationFull | None = None
|
||||
self._token_expiry_time: float | None = None
|
||||
|
||||
# PKCE flow parameters
|
||||
self._code_verifier: str | None = None
|
||||
self._code_challenge: str | None = None
|
||||
|
||||
# State parameter for CSRF protection
|
||||
self._auth_state: str | None = None
|
||||
|
||||
# Thread safety lock
|
||||
self._token_lock = anyio.Lock()
|
||||
|
||||
def _generate_code_verifier(self) -> str:
|
||||
"""Generate a cryptographically random code verifier for PKCE."""
|
||||
return "".join(
|
||||
secrets.choice(string.ascii_letters + string.digits + "-._~")
|
||||
for _ in range(128)
|
||||
)
|
||||
|
||||
def _generate_code_challenge(self, code_verifier: str) -> str:
|
||||
"""Generate a code challenge from a code verifier using SHA256."""
|
||||
digest = hashlib.sha256(code_verifier.encode()).digest()
|
||||
return base64.urlsafe_b64encode(digest).decode().rstrip("=")
|
||||
|
||||
def _get_authorization_base_url(self, server_url: str) -> str:
|
||||
"""
|
||||
Extract base URL by removing path component.
|
||||
|
||||
Per MCP spec 2.3.2: https://api.example.com/v1/mcp -> https://api.example.com
|
||||
"""
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
|
||||
parsed = urlparse(server_url)
|
||||
# Remove path component
|
||||
return urlunparse((parsed.scheme, parsed.netloc, "", "", "", ""))
|
||||
|
||||
async def _discover_oauth_metadata(self, server_url: str) -> OAuthMetadata | None:
|
||||
"""
|
||||
Discover OAuth metadata from server's well-known endpoint.
|
||||
"""
|
||||
# Extract base URL per MCP spec
|
||||
auth_base_url = self._get_authorization_base_url(server_url)
|
||||
url = urljoin(auth_base_url, "/.well-known/oauth-authorization-server")
|
||||
headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
response = await client.get(url, headers=headers)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
response.raise_for_status()
|
||||
metadata_json = response.json()
|
||||
logger.debug(f"OAuth metadata discovered: {metadata_json}")
|
||||
return OAuthMetadata.model_validate(metadata_json)
|
||||
except Exception:
|
||||
# Retry without MCP header for CORS compatibility
|
||||
try:
|
||||
response = await client.get(url)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
response.raise_for_status()
|
||||
metadata_json = response.json()
|
||||
logger.debug(
|
||||
f"OAuth metadata discovered (no MCP header): {metadata_json}"
|
||||
)
|
||||
return OAuthMetadata.model_validate(metadata_json)
|
||||
except Exception:
|
||||
logger.exception("Failed to discover OAuth metadata")
|
||||
return None
|
||||
|
||||
async def _register_oauth_client(
|
||||
self,
|
||||
server_url: str,
|
||||
client_metadata: OAuthClientMetadata,
|
||||
metadata: OAuthMetadata | None = None,
|
||||
) -> OAuthClientInformationFull:
|
||||
"""
|
||||
Register OAuth client with server.
|
||||
"""
|
||||
if not metadata:
|
||||
metadata = await self._discover_oauth_metadata(server_url)
|
||||
|
||||
if metadata and metadata.registration_endpoint:
|
||||
registration_url = str(metadata.registration_endpoint)
|
||||
else:
|
||||
# Use fallback registration endpoint
|
||||
auth_base_url = self._get_authorization_base_url(server_url)
|
||||
registration_url = urljoin(auth_base_url, "/register")
|
||||
|
||||
# Handle default scope
|
||||
if (
|
||||
client_metadata.scope is None
|
||||
and metadata
|
||||
and metadata.scopes_supported is not None
|
||||
):
|
||||
client_metadata.scope = " ".join(metadata.scopes_supported)
|
||||
|
||||
# Serialize client metadata
|
||||
registration_data = client_metadata.model_dump(
|
||||
by_alias=True, mode="json", exclude_none=True
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
response = await client.post(
|
||||
registration_url,
|
||||
json=registration_data,
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
|
||||
if response.status_code not in (200, 201):
|
||||
raise httpx.HTTPStatusError(
|
||||
f"Registration failed: {response.status_code}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
response_data = response.json()
|
||||
logger.debug(f"Registration successful: {response_data}")
|
||||
return OAuthClientInformationFull.model_validate(response_data)
|
||||
|
||||
except httpx.HTTPStatusError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Registration error")
|
||||
raise
|
||||
|
||||
async def async_auth_flow(
|
||||
self, request: httpx.Request
|
||||
) -> AsyncGenerator[httpx.Request, httpx.Response]:
|
||||
"""
|
||||
HTTPX auth flow integration.
|
||||
"""
|
||||
|
||||
if not self._has_valid_token():
|
||||
await self.initialize()
|
||||
await self.ensure_token()
|
||||
# Add Bearer token if available
|
||||
if self._current_tokens and self._current_tokens.access_token:
|
||||
request.headers["Authorization"] = (
|
||||
f"Bearer {self._current_tokens.access_token}"
|
||||
)
|
||||
|
||||
response = yield request
|
||||
|
||||
# Clear token on 401 to trigger re-auth
|
||||
if response.status_code == 401:
|
||||
self._current_tokens = None
|
||||
|
||||
def _has_valid_token(self) -> bool:
|
||||
"""Check if current token is valid."""
|
||||
if not self._current_tokens or not self._current_tokens.access_token:
|
||||
return False
|
||||
|
||||
# Check expiry time
|
||||
if self._token_expiry_time and time.time() > self._token_expiry_time:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def _validate_token_scopes(self, token_response: OAuthToken) -> None:
|
||||
"""
|
||||
Validate returned scopes against requested scopes.
|
||||
|
||||
Per OAuth 2.1 Section 3.3: server may grant subset, not superset.
|
||||
"""
|
||||
if not token_response.scope:
|
||||
# No scope returned = validation passes
|
||||
return
|
||||
|
||||
# Check explicitly requested scopes only
|
||||
requested_scopes: set[str] = set()
|
||||
|
||||
if self.client_metadata.scope:
|
||||
# Validate against explicit scope request
|
||||
requested_scopes = set(self.client_metadata.scope.split())
|
||||
|
||||
# Check for unauthorized scopes
|
||||
returned_scopes = set(token_response.scope.split())
|
||||
unauthorized_scopes = returned_scopes - requested_scopes
|
||||
|
||||
if unauthorized_scopes:
|
||||
raise Exception(
|
||||
f"Server granted unauthorized scopes: {unauthorized_scopes}. "
|
||||
f"Requested: {requested_scopes}, Returned: {returned_scopes}"
|
||||
)
|
||||
else:
|
||||
# No explicit scopes requested - accept server defaults
|
||||
logger.debug(
|
||||
f"No explicit scopes requested, accepting server-granted "
|
||||
f"scopes: {set(token_response.scope.split())}"
|
||||
)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Load stored tokens and client info."""
|
||||
self._current_tokens = await self.storage.get_tokens()
|
||||
self._client_info = await self.storage.get_client_info()
|
||||
|
||||
async def _get_or_register_client(self) -> OAuthClientInformationFull:
|
||||
"""Get or register client with server."""
|
||||
if not self._client_info:
|
||||
try:
|
||||
self._client_info = await self._register_oauth_client(
|
||||
self.server_url, self.client_metadata, self._metadata
|
||||
)
|
||||
await self.storage.set_client_info(self._client_info)
|
||||
except Exception:
|
||||
logger.exception("Client registration failed")
|
||||
raise
|
||||
return self._client_info
|
||||
|
||||
async def ensure_token(self) -> None:
|
||||
"""Ensure valid access token, refreshing or re-authenticating as needed."""
|
||||
async with self._token_lock:
|
||||
# Return early if token is valid
|
||||
if self._has_valid_token():
|
||||
return
|
||||
|
||||
# Try refreshing existing token
|
||||
if (
|
||||
self._current_tokens
|
||||
and self._current_tokens.refresh_token
|
||||
and await self._refresh_access_token()
|
||||
):
|
||||
return
|
||||
|
||||
# Fall back to full OAuth flow
|
||||
await self._perform_oauth_flow()
|
||||
|
||||
async def _perform_oauth_flow(self) -> None:
|
||||
"""Execute OAuth2 authorization code flow with PKCE."""
|
||||
logger.debug("Starting authentication flow.")
|
||||
|
||||
# Discover OAuth metadata
|
||||
if not self._metadata:
|
||||
self._metadata = await self._discover_oauth_metadata(self.server_url)
|
||||
|
||||
# Ensure client registration
|
||||
client_info = await self._get_or_register_client()
|
||||
|
||||
# Generate PKCE challenge
|
||||
self._code_verifier = self._generate_code_verifier()
|
||||
self._code_challenge = self._generate_code_challenge(self._code_verifier)
|
||||
|
||||
# Get authorization endpoint
|
||||
if self._metadata and self._metadata.authorization_endpoint:
|
||||
auth_url_base = str(self._metadata.authorization_endpoint)
|
||||
else:
|
||||
# Use fallback authorization endpoint
|
||||
auth_base_url = self._get_authorization_base_url(self.server_url)
|
||||
auth_url_base = urljoin(auth_base_url, "/authorize")
|
||||
|
||||
# Build authorization URL
|
||||
self._auth_state = secrets.token_urlsafe(32)
|
||||
auth_params = {
|
||||
"response_type": "code",
|
||||
"client_id": client_info.client_id,
|
||||
"redirect_uri": str(self.client_metadata.redirect_uris[0]),
|
||||
"state": self._auth_state,
|
||||
"code_challenge": self._code_challenge,
|
||||
"code_challenge_method": "S256",
|
||||
}
|
||||
|
||||
# Include explicit scopes only
|
||||
if self.client_metadata.scope:
|
||||
auth_params["scope"] = self.client_metadata.scope
|
||||
|
||||
auth_url = f"{auth_url_base}?{urlencode(auth_params)}"
|
||||
|
||||
# Redirect user for authorization
|
||||
await self.redirect_handler(auth_url)
|
||||
|
||||
auth_code, returned_state = await self.callback_handler()
|
||||
|
||||
# Validate state parameter for CSRF protection
|
||||
if returned_state is None or not secrets.compare_digest(
|
||||
returned_state, self._auth_state
|
||||
):
|
||||
raise Exception(
|
||||
f"State parameter mismatch: {returned_state} != {self._auth_state}"
|
||||
)
|
||||
|
||||
# Clear state after validation
|
||||
self._auth_state = None
|
||||
|
||||
if not auth_code:
|
||||
raise Exception("No authorization code received")
|
||||
|
||||
# Exchange authorization code for tokens
|
||||
await self._exchange_code_for_token(auth_code, client_info)
|
||||
|
||||
async def _exchange_code_for_token(
|
||||
self, auth_code: str, client_info: OAuthClientInformationFull
|
||||
) -> None:
|
||||
"""Exchange authorization code for access token."""
|
||||
# Get token endpoint
|
||||
if self._metadata and self._metadata.token_endpoint:
|
||||
token_url = str(self._metadata.token_endpoint)
|
||||
else:
|
||||
# Use fallback token endpoint
|
||||
auth_base_url = self._get_authorization_base_url(self.server_url)
|
||||
token_url = urljoin(auth_base_url, "/token")
|
||||
|
||||
token_data = {
|
||||
"grant_type": "authorization_code",
|
||||
"code": auth_code,
|
||||
"redirect_uri": str(self.client_metadata.redirect_uris[0]),
|
||||
"client_id": client_info.client_id,
|
||||
"code_verifier": self._code_verifier,
|
||||
}
|
||||
|
||||
if client_info.client_secret:
|
||||
token_data["client_secret"] = client_info.client_secret
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
token_url,
|
||||
data=token_data,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
timeout=30.0,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
# Parse OAuth error response
|
||||
try:
|
||||
error_data = response.json()
|
||||
error_msg = error_data.get(
|
||||
"error_description", error_data.get("error", "Unknown error")
|
||||
)
|
||||
raise Exception(
|
||||
f"Token exchange failed: {error_msg} "
|
||||
f"(HTTP {response.status_code})"
|
||||
)
|
||||
except Exception:
|
||||
raise Exception(
|
||||
f"Token exchange failed: {response.status_code} {response.text}"
|
||||
)
|
||||
|
||||
# Parse token response
|
||||
token_response = OAuthToken.model_validate(response.json())
|
||||
|
||||
# Validate token scopes
|
||||
await self._validate_token_scopes(token_response)
|
||||
|
||||
# Calculate token expiry
|
||||
if token_response.expires_in:
|
||||
self._token_expiry_time = time.time() + token_response.expires_in
|
||||
else:
|
||||
self._token_expiry_time = None
|
||||
|
||||
# Store tokens
|
||||
await self.storage.set_tokens(token_response)
|
||||
self._current_tokens = token_response
|
||||
|
||||
async def _refresh_access_token(self) -> bool:
|
||||
"""Refresh access token using refresh token."""
|
||||
if not self._current_tokens or not self._current_tokens.refresh_token:
|
||||
return False
|
||||
|
||||
# Get client credentials
|
||||
client_info = await self._get_or_register_client()
|
||||
|
||||
# Get token endpoint
|
||||
if self._metadata and self._metadata.token_endpoint:
|
||||
token_url = str(self._metadata.token_endpoint)
|
||||
else:
|
||||
# Use fallback token endpoint
|
||||
auth_base_url = self._get_authorization_base_url(self.server_url)
|
||||
token_url = urljoin(auth_base_url, "/token")
|
||||
|
||||
refresh_data = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": self._current_tokens.refresh_token,
|
||||
"client_id": client_info.client_id,
|
||||
}
|
||||
|
||||
if client_info.client_secret:
|
||||
refresh_data["client_secret"] = client_info.client_secret
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
token_url,
|
||||
data=refresh_data,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
timeout=30.0,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"Token refresh failed: {response.status_code}")
|
||||
return False
|
||||
|
||||
# Parse refreshed tokens
|
||||
token_response = OAuthToken.model_validate(response.json())
|
||||
|
||||
# Validate token scopes
|
||||
await self._validate_token_scopes(token_response)
|
||||
|
||||
# Calculate token expiry
|
||||
if token_response.expires_in:
|
||||
self._token_expiry_time = time.time() + token_response.expires_in
|
||||
else:
|
||||
self._token_expiry_time = None
|
||||
|
||||
# Store refreshed tokens
|
||||
await self.storage.set_tokens(token_response)
|
||||
self._current_tokens = token_response
|
||||
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
logger.exception("Token refresh failed")
|
||||
return False
|
||||
@@ -83,6 +83,7 @@ class StreamableHTTPTransport:
|
||||
headers: dict[str, Any] | None = None,
|
||||
timeout: timedelta = timedelta(seconds=30),
|
||||
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
|
||||
auth: httpx.Auth | None = None,
|
||||
) -> None:
|
||||
"""Initialize the StreamableHTTP transport.
|
||||
|
||||
@@ -91,11 +92,13 @@ class StreamableHTTPTransport:
|
||||
headers: Optional headers to include in requests.
|
||||
timeout: HTTP timeout for regular operations.
|
||||
sse_read_timeout: Timeout for SSE read operations.
|
||||
auth: Optional HTTPX authentication handler.
|
||||
"""
|
||||
self.url = url
|
||||
self.headers = headers or {}
|
||||
self.timeout = timeout
|
||||
self.sse_read_timeout = sse_read_timeout
|
||||
self.auth = auth
|
||||
self.session_id: str | None = None
|
||||
self.request_headers = {
|
||||
ACCEPT: f"{JSON}, {SSE}",
|
||||
@@ -427,6 +430,7 @@ async def streamablehttp_client(
|
||||
timeout: timedelta = timedelta(seconds=30),
|
||||
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
|
||||
terminate_on_close: bool = True,
|
||||
auth: httpx.Auth | None = None,
|
||||
) -> AsyncGenerator[
|
||||
tuple[
|
||||
MemoryObjectReceiveStream[SessionMessage | Exception],
|
||||
@@ -447,7 +451,7 @@ async def streamablehttp_client(
|
||||
- write_stream: Stream for sending messages to the server
|
||||
- get_session_id_callback: Function to retrieve the current session ID
|
||||
"""
|
||||
transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout)
|
||||
transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout, auth)
|
||||
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream[
|
||||
SessionMessage | Exception
|
||||
@@ -465,6 +469,7 @@ async def streamablehttp_client(
|
||||
timeout=httpx.Timeout(
|
||||
transport.timeout.seconds, read=transport.sse_read_timeout.seconds
|
||||
),
|
||||
auth=transport.auth,
|
||||
) as client:
|
||||
# Define callbacks that need access to tg
|
||||
def start_get_stream() -> None:
|
||||
|
||||
@@ -10,6 +10,7 @@ __all__ = ["create_mcp_http_client"]
|
||||
def create_mcp_http_client(
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: httpx.Timeout | None = None,
|
||||
auth: httpx.Auth | None = None,
|
||||
) -> httpx.AsyncClient:
|
||||
"""Create a standardized httpx AsyncClient with MCP defaults.
|
||||
|
||||
@@ -21,6 +22,7 @@ def create_mcp_http_client(
|
||||
headers: Optional headers to include with all requests.
|
||||
timeout: Request timeout as httpx.Timeout object.
|
||||
Defaults to 30 seconds if not specified.
|
||||
auth: Optional authentication handler.
|
||||
|
||||
Returns:
|
||||
Configured httpx.AsyncClient instance with MCP defaults.
|
||||
@@ -43,6 +45,12 @@ def create_mcp_http_client(
|
||||
timeout = httpx.Timeout(60.0, read=300.0)
|
||||
async with create_mcp_http_client(headers, timeout) as client:
|
||||
response = await client.get("/long-request")
|
||||
|
||||
# With authentication
|
||||
from httpx import BasicAuth
|
||||
auth = BasicAuth(username="user", password="pass")
|
||||
async with create_mcp_http_client(headers, timeout, auth) as client:
|
||||
response = await client.get("/protected-endpoint")
|
||||
"""
|
||||
# Set MCP defaults
|
||||
kwargs: dict[str, Any] = {
|
||||
@@ -59,4 +67,8 @@ def create_mcp_http_client(
|
||||
if headers is not None:
|
||||
kwargs["headers"] = headers
|
||||
|
||||
# Handle authentication
|
||||
if auth is not None:
|
||||
kwargs["auth"] = auth
|
||||
|
||||
return httpx.AsyncClient(**kwargs)
|
||||
|
||||
Reference in New Issue
Block a user