mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 23:04:25 +01:00
Add support for serverside oauth (#255)
Co-authored-by: David Soria Parra <davidsp@anthropic.com> Co-authored-by: Basil Hosmer <basil@anthropic.com> Co-authored-by: ihrpr <inna@anthropic.com>
This commit is contained in:
@@ -4,7 +4,7 @@ from __future__ import annotations as _annotations
|
||||
|
||||
import inspect
|
||||
import re
|
||||
from collections.abc import AsyncIterator, Callable, Iterable, Sequence
|
||||
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence
|
||||
from contextlib import (
|
||||
AbstractAsyncContextManager,
|
||||
asynccontextmanager,
|
||||
@@ -18,9 +18,22 @@ from pydantic import BaseModel, Field
|
||||
from pydantic.networks import AnyUrl
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from starlette.applications import Starlette
|
||||
from starlette.middleware import Middleware
|
||||
from starlette.middleware.authentication import AuthenticationMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.routing import Mount, Route
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
from mcp.server.auth.middleware.auth_context import AuthContextMiddleware
|
||||
from mcp.server.auth.middleware.bearer_auth import (
|
||||
BearerAuthBackend,
|
||||
RequireAuthMiddleware,
|
||||
)
|
||||
from mcp.server.auth.provider import OAuthAuthorizationServerProvider
|
||||
from mcp.server.auth.settings import (
|
||||
AuthSettings,
|
||||
)
|
||||
from mcp.server.fastmcp.exceptions import ResourceError
|
||||
from mcp.server.fastmcp.prompts import Prompt, PromptManager
|
||||
from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager
|
||||
@@ -62,6 +75,8 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
|
||||
model_config = SettingsConfigDict(
|
||||
env_prefix="FASTMCP_",
|
||||
env_file=".env",
|
||||
env_nested_delimiter="__",
|
||||
nested_model_default_partial_update=True,
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
@@ -93,6 +108,8 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
|
||||
Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]] | None
|
||||
) = Field(None, description="Lifespan context manager")
|
||||
|
||||
auth: AuthSettings | None = None
|
||||
|
||||
|
||||
def lifespan_wrapper(
|
||||
app: FastMCP,
|
||||
@@ -108,7 +125,12 @@ def lifespan_wrapper(
|
||||
|
||||
class FastMCP:
|
||||
def __init__(
|
||||
self, name: str | None = None, instructions: str | None = None, **settings: Any
|
||||
self,
|
||||
name: str | None = None,
|
||||
instructions: str | None = None,
|
||||
auth_server_provider: OAuthAuthorizationServerProvider[Any, Any, Any]
|
||||
| None = None,
|
||||
**settings: Any,
|
||||
):
|
||||
self.settings = Settings(**settings)
|
||||
|
||||
@@ -128,6 +150,18 @@ class FastMCP:
|
||||
self._prompt_manager = PromptManager(
|
||||
warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts
|
||||
)
|
||||
if (self.settings.auth is not None) != (auth_server_provider is not None):
|
||||
# TODO: after we support separate authorization servers (see
|
||||
# https://github.com/modelcontextprotocol/modelcontextprotocol/pull/284)
|
||||
# we should validate that if auth is enabled, we have either an
|
||||
# auth_server_provider to host our own authorization server,
|
||||
# OR the URL of a 3rd party authorization server.
|
||||
raise ValueError(
|
||||
"settings.auth must be specified if and only if auth_server_provider "
|
||||
"is specified"
|
||||
)
|
||||
self._auth_server_provider = auth_server_provider
|
||||
self._custom_starlette_routes: list[Route] = []
|
||||
self.dependencies = self.settings.dependencies
|
||||
|
||||
# Set up MCP protocol handlers
|
||||
@@ -465,6 +499,50 @@ class FastMCP:
|
||||
|
||||
return decorator
|
||||
|
||||
def custom_route(
|
||||
self,
|
||||
path: str,
|
||||
methods: list[str],
|
||||
name: str | None = None,
|
||||
include_in_schema: bool = True,
|
||||
):
|
||||
"""
|
||||
Decorator to register a custom HTTP route on the FastMCP server.
|
||||
|
||||
Allows adding arbitrary HTTP endpoints outside the standard MCP protocol,
|
||||
which can be useful for OAuth callbacks, health checks, or admin APIs.
|
||||
The handler function must be an async function that accepts a Starlette
|
||||
Request and returns a Response.
|
||||
|
||||
Args:
|
||||
path: URL path for the route (e.g., "/oauth/callback")
|
||||
methods: List of HTTP methods to support (e.g., ["GET", "POST"])
|
||||
name: Optional name for the route (to reference this route with
|
||||
Starlette's reverse URL lookup feature)
|
||||
include_in_schema: Whether to include in OpenAPI schema, defaults to True
|
||||
|
||||
Example:
|
||||
@server.custom_route("/health", methods=["GET"])
|
||||
async def health_check(request: Request) -> Response:
|
||||
return JSONResponse({"status": "ok"})
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: Callable[[Request], Awaitable[Response]],
|
||||
) -> Callable[[Request], Awaitable[Response]]:
|
||||
self._custom_starlette_routes.append(
|
||||
Route(
|
||||
path,
|
||||
endpoint=func,
|
||||
methods=methods,
|
||||
name=name,
|
||||
include_in_schema=include_in_schema,
|
||||
)
|
||||
)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
async def run_stdio_async(self) -> None:
|
||||
"""Run the server using stdio transport."""
|
||||
async with stdio_server() as (read_stream, write_stream):
|
||||
@@ -491,13 +569,20 @@ class FastMCP:
|
||||
|
||||
def sse_app(self) -> Starlette:
|
||||
"""Return an instance of the SSE server app."""
|
||||
from starlette.middleware import Middleware
|
||||
from starlette.routing import Mount, Route
|
||||
|
||||
# Set up auth context and dependencies
|
||||
|
||||
sse = SseServerTransport(self.settings.message_path)
|
||||
|
||||
async def handle_sse(request: Request) -> None:
|
||||
async def handle_sse(scope: Scope, receive: Receive, send: Send):
|
||||
# Add client ID from auth context into request context if available
|
||||
|
||||
async with sse.connect_sse(
|
||||
request.scope,
|
||||
request.receive,
|
||||
request._send, # type: ignore[reportPrivateUsage]
|
||||
scope,
|
||||
receive,
|
||||
send,
|
||||
) as streams:
|
||||
await self._mcp_server.run(
|
||||
streams[0],
|
||||
@@ -505,12 +590,59 @@ class FastMCP:
|
||||
self._mcp_server.create_initialization_options(),
|
||||
)
|
||||
|
||||
# Create routes
|
||||
routes: list[Route | Mount] = []
|
||||
middleware: list[Middleware] = []
|
||||
required_scopes = []
|
||||
|
||||
# Add auth endpoints if auth provider is configured
|
||||
if self._auth_server_provider:
|
||||
assert self.settings.auth
|
||||
from mcp.server.auth.routes import create_auth_routes
|
||||
|
||||
required_scopes = self.settings.auth.required_scopes or []
|
||||
|
||||
middleware = [
|
||||
# extract auth info from request (but do not require it)
|
||||
Middleware(
|
||||
AuthenticationMiddleware,
|
||||
backend=BearerAuthBackend(
|
||||
provider=self._auth_server_provider,
|
||||
),
|
||||
),
|
||||
# Add the auth context middleware to store
|
||||
# authenticated user in a contextvar
|
||||
Middleware(AuthContextMiddleware),
|
||||
]
|
||||
routes.extend(
|
||||
create_auth_routes(
|
||||
provider=self._auth_server_provider,
|
||||
issuer_url=self.settings.auth.issuer_url,
|
||||
service_documentation_url=self.settings.auth.service_documentation_url,
|
||||
client_registration_options=self.settings.auth.client_registration_options,
|
||||
revocation_options=self.settings.auth.revocation_options,
|
||||
)
|
||||
)
|
||||
|
||||
routes.append(
|
||||
Route(
|
||||
self.settings.sse_path,
|
||||
endpoint=RequireAuthMiddleware(handle_sse, required_scopes),
|
||||
methods=["GET"],
|
||||
)
|
||||
)
|
||||
routes.append(
|
||||
Mount(
|
||||
self.settings.message_path,
|
||||
app=RequireAuthMiddleware(sse.handle_post_message, required_scopes),
|
||||
)
|
||||
)
|
||||
# mount these routes last, so they have the lowest route matching precedence
|
||||
routes.extend(self._custom_starlette_routes)
|
||||
|
||||
# Create Starlette app with routes and middleware
|
||||
return Starlette(
|
||||
debug=self.settings.debug,
|
||||
routes=[
|
||||
Route(self.settings.sse_path, endpoint=handle_sse),
|
||||
Mount(self.settings.message_path, app=sse.handle_post_message),
|
||||
],
|
||||
debug=self.settings.debug, routes=routes, middleware=middleware
|
||||
)
|
||||
|
||||
async def list_prompts(self) -> list[MCPPrompt]:
|
||||
|
||||
Reference in New Issue
Block a user