mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-18 22:44:20 +01:00
Streamable HTTP - improve usability, fast mcp and auth (#641)
This commit is contained in:
@@ -1,37 +1,17 @@
|
||||
import contextlib
|
||||
import logging
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
import anyio
|
||||
import click
|
||||
import mcp.types as types
|
||||
from mcp.server.lowlevel import Server
|
||||
from mcp.server.streamableHttp import (
|
||||
StreamableHTTPServerTransport,
|
||||
)
|
||||
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
|
||||
from starlette.applications import Starlette
|
||||
from starlette.routing import Mount
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Global task group that will be initialized in the lifespan
|
||||
task_group = None
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def lifespan(app):
|
||||
"""Application lifespan context manager for managing task group."""
|
||||
global task_group
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
task_group = tg
|
||||
logger.info("Application started, task group initialized!")
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
logger.info("Application shutting down, cleaning up resources...")
|
||||
if task_group:
|
||||
tg.cancel_scope.cancel()
|
||||
task_group = None
|
||||
logger.info("Resources cleaned up successfully.")
|
||||
|
||||
|
||||
@click.command()
|
||||
@@ -122,35 +102,28 @@ def main(
|
||||
)
|
||||
]
|
||||
|
||||
# ASGI handler for stateless HTTP connections
|
||||
async def handle_streamable_http(scope, receive, send):
|
||||
logger.debug("Creating new transport")
|
||||
# Use lock to prevent race conditions when creating new sessions
|
||||
http_transport = StreamableHTTPServerTransport(
|
||||
mcp_session_id=None,
|
||||
is_json_response_enabled=json_response,
|
||||
)
|
||||
async with http_transport.connect() as streams:
|
||||
read_stream, write_stream = streams
|
||||
# Create the session manager with true stateless mode
|
||||
session_manager = StreamableHTTPSessionManager(
|
||||
app=app,
|
||||
event_store=None,
|
||||
json_response=json_response,
|
||||
stateless=True,
|
||||
)
|
||||
|
||||
if not task_group:
|
||||
raise RuntimeError("Task group is not initialized")
|
||||
async def handle_streamable_http(
|
||||
scope: Scope, receive: Receive, send: Send
|
||||
) -> None:
|
||||
await session_manager.handle_request(scope, receive, send)
|
||||
|
||||
async def run_server():
|
||||
await app.run(
|
||||
read_stream,
|
||||
write_stream,
|
||||
app.create_initialization_options(),
|
||||
# Runs in standalone mode for stateless deployments
|
||||
# where clients perform initialization with any node
|
||||
standalone_mode=True,
|
||||
)
|
||||
|
||||
# Start server task
|
||||
task_group.start_soon(run_server)
|
||||
|
||||
# Handle the HTTP request and return the response
|
||||
await http_transport.handle_request(scope, receive, send)
|
||||
@contextlib.asynccontextmanager
|
||||
async def lifespan(app: Starlette) -> AsyncIterator[None]:
|
||||
"""Context manager for session manager."""
|
||||
async with session_manager.run():
|
||||
logger.info("Application started with StreamableHTTP session manager!")
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
logger.info("Application shutting down...")
|
||||
|
||||
# Create an ASGI application using the transport
|
||||
starlette_app = Starlette(
|
||||
|
||||
@@ -1,58 +1,22 @@
|
||||
import contextlib
|
||||
import logging
|
||||
from http import HTTPStatus
|
||||
from uuid import uuid4
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
import anyio
|
||||
import click
|
||||
import mcp.types as types
|
||||
from mcp.server.lowlevel import Server
|
||||
from mcp.server.streamable_http import (
|
||||
MCP_SESSION_ID_HEADER,
|
||||
StreamableHTTPServerTransport,
|
||||
)
|
||||
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
|
||||
from pydantic import AnyUrl
|
||||
from starlette.applications import Starlette
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.routing import Mount
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
from .event_store import InMemoryEventStore
|
||||
|
||||
# Configure logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global task group that will be initialized in the lifespan
|
||||
task_group = None
|
||||
|
||||
# Event store for resumability
|
||||
# The InMemoryEventStore enables resumability support for StreamableHTTP transport.
|
||||
# It stores SSE events with unique IDs, allowing clients to:
|
||||
# 1. Receive event IDs for each SSE message
|
||||
# 2. Resume streams by sending Last-Event-ID in GET requests
|
||||
# 3. Replay missed events after reconnection
|
||||
# Note: This in-memory implementation is for demonstration ONLY.
|
||||
# For production, use a persistent storage solution.
|
||||
event_store = InMemoryEventStore()
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def lifespan(app):
|
||||
"""Application lifespan context manager for managing task group."""
|
||||
global task_group
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
task_group = tg
|
||||
logger.info("Application started, task group initialized!")
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
logger.info("Application shutting down, cleaning up resources...")
|
||||
if task_group:
|
||||
tg.cancel_scope.cancel()
|
||||
task_group = None
|
||||
logger.info("Resources cleaned up successfully.")
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option("--port", default=3000, help="Port to listen on for HTTP")
|
||||
@@ -156,60 +120,38 @@ def main(
|
||||
)
|
||||
]
|
||||
|
||||
# We need to store the server instances between requests
|
||||
server_instances = {}
|
||||
# Lock to prevent race conditions when creating new sessions
|
||||
session_creation_lock = anyio.Lock()
|
||||
# Create event store for resumability
|
||||
# The InMemoryEventStore enables resumability support for StreamableHTTP transport.
|
||||
# It stores SSE events with unique IDs, allowing clients to:
|
||||
# 1. Receive event IDs for each SSE message
|
||||
# 2. Resume streams by sending Last-Event-ID in GET requests
|
||||
# 3. Replay missed events after reconnection
|
||||
# Note: This in-memory implementation is for demonstration ONLY.
|
||||
# For production, use a persistent storage solution.
|
||||
event_store = InMemoryEventStore()
|
||||
|
||||
# Create the session manager with our app and event store
|
||||
session_manager = StreamableHTTPSessionManager(
|
||||
app=app,
|
||||
event_store=event_store, # Enable resumability
|
||||
json_response=json_response,
|
||||
)
|
||||
|
||||
# ASGI handler for streamable HTTP connections
|
||||
async def handle_streamable_http(scope, receive, send):
|
||||
request = Request(scope, receive)
|
||||
request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER)
|
||||
if (
|
||||
request_mcp_session_id is not None
|
||||
and request_mcp_session_id in server_instances
|
||||
):
|
||||
transport = server_instances[request_mcp_session_id]
|
||||
logger.debug("Session already exists, handling request directly")
|
||||
await transport.handle_request(scope, receive, send)
|
||||
elif request_mcp_session_id is None:
|
||||
# try to establish new session
|
||||
logger.debug("Creating new transport")
|
||||
# Use lock to prevent race conditions when creating new sessions
|
||||
async with session_creation_lock:
|
||||
new_session_id = uuid4().hex
|
||||
http_transport = StreamableHTTPServerTransport(
|
||||
mcp_session_id=new_session_id,
|
||||
is_json_response_enabled=json_response,
|
||||
event_store=event_store, # Enable resumability
|
||||
)
|
||||
server_instances[http_transport.mcp_session_id] = http_transport
|
||||
logger.info(f"Created new transport with session ID: {new_session_id}")
|
||||
async def handle_streamable_http(
|
||||
scope: Scope, receive: Receive, send: Send
|
||||
) -> None:
|
||||
await session_manager.handle_request(scope, receive, send)
|
||||
|
||||
async def run_server(task_status=None):
|
||||
async with http_transport.connect() as streams:
|
||||
read_stream, write_stream = streams
|
||||
if task_status:
|
||||
task_status.started()
|
||||
await app.run(
|
||||
read_stream,
|
||||
write_stream,
|
||||
app.create_initialization_options(),
|
||||
)
|
||||
|
||||
if not task_group:
|
||||
raise RuntimeError("Task group is not initialized")
|
||||
|
||||
await task_group.start(run_server)
|
||||
|
||||
# Handle the HTTP request and return the response
|
||||
await http_transport.handle_request(scope, receive, send)
|
||||
else:
|
||||
response = Response(
|
||||
"Bad Request: No valid session ID provided",
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
@contextlib.asynccontextmanager
|
||||
async def lifespan(app: Starlette) -> AsyncIterator[None]:
|
||||
"""Context manager for managing session manager lifecycle."""
|
||||
async with session_manager.run():
|
||||
logger.info("Application started with StreamableHTTP session manager!")
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
logger.info("Application shutting down...")
|
||||
|
||||
# Create an ASGI application using the transport
|
||||
starlette_app = Starlette(
|
||||
|
||||
@@ -47,6 +47,8 @@ from mcp.server.lowlevel.server import lifespan as default_lifespan
|
||||
from mcp.server.session import ServerSession, ServerSessionT
|
||||
from mcp.server.sse import SseServerTransport
|
||||
from mcp.server.stdio import stdio_server
|
||||
from mcp.server.streamable_http import EventStore
|
||||
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
|
||||
from mcp.shared.context import LifespanContextT, RequestContext
|
||||
from mcp.types import (
|
||||
AnyFunction,
|
||||
@@ -90,6 +92,13 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
|
||||
mount_path: str = "/" # Mount path (e.g. "/github", defaults to root path)
|
||||
sse_path: str = "/sse"
|
||||
message_path: str = "/messages/"
|
||||
streamable_http_path: str = "/mcp"
|
||||
|
||||
# StreamableHTTP settings
|
||||
json_response: bool = False
|
||||
stateless_http: bool = (
|
||||
False # If True, uses true stateless mode (new transport per request)
|
||||
)
|
||||
|
||||
# resource settings
|
||||
warn_on_duplicate_resources: bool = True
|
||||
@@ -131,6 +140,7 @@ class FastMCP:
|
||||
instructions: str | None = None,
|
||||
auth_server_provider: OAuthAuthorizationServerProvider[Any, Any, Any]
|
||||
| None = None,
|
||||
event_store: EventStore | None = None,
|
||||
**settings: Any,
|
||||
):
|
||||
self.settings = Settings(**settings)
|
||||
@@ -162,8 +172,10 @@ class FastMCP:
|
||||
"is specified"
|
||||
)
|
||||
self._auth_server_provider = auth_server_provider
|
||||
self._event_store = event_store
|
||||
self._custom_starlette_routes: list[Route] = []
|
||||
self.dependencies = self.settings.dependencies
|
||||
self._session_manager: StreamableHTTPSessionManager | None = None
|
||||
|
||||
# Set up MCP protocol handlers
|
||||
self._setup_handlers()
|
||||
@@ -179,25 +191,47 @@ class FastMCP:
|
||||
def instructions(self) -> str | None:
|
||||
return self._mcp_server.instructions
|
||||
|
||||
@property
|
||||
def session_manager(self) -> StreamableHTTPSessionManager:
|
||||
"""Get the StreamableHTTP session manager.
|
||||
|
||||
This is exposed to enable advanced use cases like mounting multiple
|
||||
FastMCP servers in a single FastAPI application.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If called before streamable_http_app() has been called.
|
||||
"""
|
||||
if self._session_manager is None:
|
||||
raise RuntimeError(
|
||||
"Session manager can only be accessed after"
|
||||
"calling streamable_http_app()."
|
||||
"The session manager is created lazily"
|
||||
"to avoid unnecessary initialization."
|
||||
)
|
||||
return self._session_manager
|
||||
|
||||
def run(
|
||||
self,
|
||||
transport: Literal["stdio", "sse"] = "stdio",
|
||||
transport: Literal["stdio", "sse", "streamable-http"] = "stdio",
|
||||
mount_path: str | None = None,
|
||||
) -> None:
|
||||
"""Run the FastMCP server. Note this is a synchronous function.
|
||||
|
||||
Args:
|
||||
transport: Transport protocol to use ("stdio" or "sse")
|
||||
transport: Transport protocol to use ("stdio", "sse", or "streamable-http")
|
||||
mount_path: Optional mount path for SSE transport
|
||||
"""
|
||||
TRANSPORTS = Literal["stdio", "sse"]
|
||||
TRANSPORTS = Literal["stdio", "sse", "streamable-http"]
|
||||
if transport not in TRANSPORTS.__args__: # type: ignore
|
||||
raise ValueError(f"Unknown transport: {transport}")
|
||||
|
||||
if transport == "stdio":
|
||||
anyio.run(self.run_stdio_async)
|
||||
else: # transport == "sse"
|
||||
anyio.run(lambda: self.run_sse_async(mount_path))
|
||||
match transport:
|
||||
case "stdio":
|
||||
anyio.run(self.run_stdio_async)
|
||||
case "sse":
|
||||
anyio.run(lambda: self.run_sse_async(mount_path))
|
||||
case "streamable-http":
|
||||
anyio.run(self.run_streamable_http_async)
|
||||
|
||||
def _setup_handlers(self) -> None:
|
||||
"""Set up core MCP protocol handlers."""
|
||||
@@ -573,6 +607,21 @@ class FastMCP:
|
||||
server = uvicorn.Server(config)
|
||||
await server.serve()
|
||||
|
||||
async def run_streamable_http_async(self) -> None:
|
||||
"""Run the server using StreamableHTTP transport."""
|
||||
import uvicorn
|
||||
|
||||
starlette_app = self.streamable_http_app()
|
||||
|
||||
config = uvicorn.Config(
|
||||
starlette_app,
|
||||
host=self.settings.host,
|
||||
port=self.settings.port,
|
||||
log_level=self.settings.log_level.lower(),
|
||||
)
|
||||
server = uvicorn.Server(config)
|
||||
await server.serve()
|
||||
|
||||
def _normalize_path(self, mount_path: str, endpoint: str) -> str:
|
||||
"""
|
||||
Combine mount path and endpoint to return a normalized path.
|
||||
@@ -687,9 +736,9 @@ class FastMCP:
|
||||
else:
|
||||
# Auth is disabled, no need for RequireAuthMiddleware
|
||||
# Since handle_sse is an ASGI app, we need to create a compatible endpoint
|
||||
async def sse_endpoint(request: Request) -> None:
|
||||
async def sse_endpoint(request: Request) -> Response:
|
||||
# Convert the Starlette request to ASGI parameters
|
||||
await handle_sse(request.scope, request.receive, request._send) # type: ignore[reportPrivateUsage]
|
||||
return await handle_sse(request.scope, request.receive, request._send) # type: ignore[reportPrivateUsage]
|
||||
|
||||
routes.append(
|
||||
Route(
|
||||
@@ -712,6 +761,80 @@ class FastMCP:
|
||||
debug=self.settings.debug, routes=routes, middleware=middleware
|
||||
)
|
||||
|
||||
def streamable_http_app(self) -> Starlette:
|
||||
"""Return an instance of the StreamableHTTP server app."""
|
||||
from starlette.middleware import Middleware
|
||||
from starlette.routing import Mount
|
||||
|
||||
# Create session manager on first call (lazy initialization)
|
||||
if self._session_manager is None:
|
||||
self._session_manager = StreamableHTTPSessionManager(
|
||||
app=self._mcp_server,
|
||||
event_store=self._event_store,
|
||||
json_response=self.settings.json_response,
|
||||
stateless=self.settings.stateless_http, # Use the stateless setting
|
||||
)
|
||||
|
||||
# Create the ASGI handler
|
||||
async def handle_streamable_http(
|
||||
scope: Scope, receive: Receive, send: Send
|
||||
) -> None:
|
||||
await self.session_manager.handle_request(scope, receive, send)
|
||||
|
||||
# Create routes
|
||||
routes: list[Route | Mount] = []
|
||||
middleware: list[Middleware] = []
|
||||
required_scopes = []
|
||||
|
||||
# Add auth endpoints if auth provider is configured
|
||||
if self._auth_server_provider:
|
||||
assert self.settings.auth
|
||||
from mcp.server.auth.routes import create_auth_routes
|
||||
|
||||
required_scopes = self.settings.auth.required_scopes or []
|
||||
|
||||
middleware = [
|
||||
Middleware(
|
||||
AuthenticationMiddleware,
|
||||
backend=BearerAuthBackend(
|
||||
provider=self._auth_server_provider,
|
||||
),
|
||||
),
|
||||
Middleware(AuthContextMiddleware),
|
||||
]
|
||||
routes.extend(
|
||||
create_auth_routes(
|
||||
provider=self._auth_server_provider,
|
||||
issuer_url=self.settings.auth.issuer_url,
|
||||
service_documentation_url=self.settings.auth.service_documentation_url,
|
||||
client_registration_options=self.settings.auth.client_registration_options,
|
||||
revocation_options=self.settings.auth.revocation_options,
|
||||
)
|
||||
)
|
||||
routes.append(
|
||||
Mount(
|
||||
self.settings.streamable_http_path,
|
||||
app=RequireAuthMiddleware(handle_streamable_http, required_scopes),
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Auth is disabled, no wrapper needed
|
||||
routes.append(
|
||||
Mount(
|
||||
self.settings.streamable_http_path,
|
||||
app=handle_streamable_http,
|
||||
)
|
||||
)
|
||||
|
||||
routes.extend(self._custom_starlette_routes)
|
||||
|
||||
return Starlette(
|
||||
debug=self.settings.debug,
|
||||
routes=routes,
|
||||
middleware=middleware,
|
||||
lifespan=lambda app: self.session_manager.run(),
|
||||
)
|
||||
|
||||
async def list_prompts(self) -> list[MCPPrompt]:
|
||||
"""List all available prompts."""
|
||||
prompts = self._prompt_manager.list_prompts()
|
||||
|
||||
258
src/mcp/server/streamable_http_manager.py
Normal file
258
src/mcp/server/streamable_http_manager.py
Normal file
@@ -0,0 +1,258 @@
|
||||
"""StreamableHTTP Session Manager for MCP servers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import threading
|
||||
from collections.abc import AsyncIterator
|
||||
from http import HTTPStatus
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import anyio
|
||||
from anyio.abc import TaskStatus
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
from mcp.server.lowlevel.server import Server as MCPServer
|
||||
from mcp.server.streamable_http import (
|
||||
MCP_SESSION_ID_HEADER,
|
||||
EventStore,
|
||||
StreamableHTTPServerTransport,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StreamableHTTPSessionManager:
|
||||
"""
|
||||
Manages StreamableHTTP sessions with optional resumability via event store.
|
||||
|
||||
This class abstracts away the complexity of session management, event storage,
|
||||
and request handling for StreamableHTTP transports. It handles:
|
||||
|
||||
1. Session tracking for clients
|
||||
2. Resumability via an optional event store
|
||||
3. Connection management and lifecycle
|
||||
4. Request handling and transport setup
|
||||
|
||||
Important: Only one StreamableHTTPSessionManager instance should be created
|
||||
per application. The instance cannot be reused after its run() context has
|
||||
completed. If you need to restart the manager, create a new instance.
|
||||
|
||||
Args:
|
||||
app: The MCP server instance
|
||||
event_store: Optional event store for resumability support.
|
||||
If provided, enables resumable connections where clients
|
||||
can reconnect and receive missed events.
|
||||
If None, sessions are still tracked but not resumable.
|
||||
json_response: Whether to use JSON responses instead of SSE streams
|
||||
stateless: If True, creates a completely fresh transport for each request
|
||||
with no session tracking or state persistence between requests.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: MCPServer[Any],
|
||||
event_store: EventStore | None = None,
|
||||
json_response: bool = False,
|
||||
stateless: bool = False,
|
||||
):
|
||||
self.app = app
|
||||
self.event_store = event_store
|
||||
self.json_response = json_response
|
||||
self.stateless = stateless
|
||||
|
||||
# Session tracking (only used if not stateless)
|
||||
self._session_creation_lock = anyio.Lock()
|
||||
self._server_instances: dict[str, StreamableHTTPServerTransport] = {}
|
||||
|
||||
# The task group will be set during lifespan
|
||||
self._task_group = None
|
||||
# Thread-safe tracking of run() calls
|
||||
self._run_lock = threading.Lock()
|
||||
self._has_started = False
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def run(self) -> AsyncIterator[None]:
|
||||
"""
|
||||
Run the session manager with proper lifecycle management.
|
||||
|
||||
This creates and manages the task group for all session operations.
|
||||
|
||||
Important: This method can only be called once per instance. The same
|
||||
StreamableHTTPSessionManager instance cannot be reused after this
|
||||
context manager exits. Create a new instance if you need to restart.
|
||||
|
||||
Use this in the lifespan context manager of your Starlette app:
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def lifespan(app: Starlette) -> AsyncIterator[None]:
|
||||
async with session_manager.run():
|
||||
yield
|
||||
"""
|
||||
# Thread-safe check to ensure run() is only called once
|
||||
with self._run_lock:
|
||||
if self._has_started:
|
||||
raise RuntimeError(
|
||||
"StreamableHTTPSessionManager .run() can only be called "
|
||||
"once per instance. Create a new instance if you need to run again."
|
||||
)
|
||||
self._has_started = True
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
# Store the task group for later use
|
||||
self._task_group = tg
|
||||
logger.info("StreamableHTTP session manager started")
|
||||
try:
|
||||
yield # Let the application run
|
||||
finally:
|
||||
logger.info("StreamableHTTP session manager shutting down")
|
||||
# Cancel task group to stop all spawned tasks
|
||||
tg.cancel_scope.cancel()
|
||||
self._task_group = None
|
||||
# Clear any remaining server instances
|
||||
self._server_instances.clear()
|
||||
|
||||
async def handle_request(
|
||||
self,
|
||||
scope: Scope,
|
||||
receive: Receive,
|
||||
send: Send,
|
||||
) -> None:
|
||||
"""
|
||||
Process ASGI request with proper session handling and transport setup.
|
||||
|
||||
Dispatches to the appropriate handler based on stateless mode.
|
||||
|
||||
Args:
|
||||
scope: ASGI scope
|
||||
receive: ASGI receive function
|
||||
send: ASGI send function
|
||||
"""
|
||||
if self._task_group is None:
|
||||
raise RuntimeError("Task group is not initialized. Make sure to use run().")
|
||||
|
||||
# Dispatch to the appropriate handler
|
||||
if self.stateless:
|
||||
await self._handle_stateless_request(scope, receive, send)
|
||||
else:
|
||||
await self._handle_stateful_request(scope, receive, send)
|
||||
|
||||
async def _handle_stateless_request(
|
||||
self,
|
||||
scope: Scope,
|
||||
receive: Receive,
|
||||
send: Send,
|
||||
) -> None:
|
||||
"""
|
||||
Process request in stateless mode - creating a new transport for each request.
|
||||
|
||||
Args:
|
||||
scope: ASGI scope
|
||||
receive: ASGI receive function
|
||||
send: ASGI send function
|
||||
"""
|
||||
logger.debug("Stateless mode: Creating new transport for this request")
|
||||
# No session ID needed in stateless mode
|
||||
http_transport = StreamableHTTPServerTransport(
|
||||
mcp_session_id=None, # No session tracking in stateless mode
|
||||
is_json_response_enabled=self.json_response,
|
||||
event_store=None, # No event store in stateless mode
|
||||
)
|
||||
|
||||
# Start server in a new task
|
||||
async def run_stateless_server(
|
||||
*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED
|
||||
):
|
||||
async with http_transport.connect() as streams:
|
||||
read_stream, write_stream = streams
|
||||
task_status.started()
|
||||
await self.app.run(
|
||||
read_stream,
|
||||
write_stream,
|
||||
self.app.create_initialization_options(),
|
||||
stateless=True,
|
||||
)
|
||||
|
||||
# Assert task group is not None for type checking
|
||||
assert self._task_group is not None
|
||||
# Start the server task
|
||||
await self._task_group.start(run_stateless_server)
|
||||
|
||||
# Handle the HTTP request and return the response
|
||||
await http_transport.handle_request(scope, receive, send)
|
||||
|
||||
async def _handle_stateful_request(
|
||||
self,
|
||||
scope: Scope,
|
||||
receive: Receive,
|
||||
send: Send,
|
||||
) -> None:
|
||||
"""
|
||||
Process request in stateful mode - maintaining session state between requests.
|
||||
|
||||
Args:
|
||||
scope: ASGI scope
|
||||
receive: ASGI receive function
|
||||
send: ASGI send function
|
||||
"""
|
||||
request = Request(scope, receive)
|
||||
request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER)
|
||||
|
||||
# Existing session case
|
||||
if (
|
||||
request_mcp_session_id is not None
|
||||
and request_mcp_session_id in self._server_instances
|
||||
):
|
||||
transport = self._server_instances[request_mcp_session_id]
|
||||
logger.debug("Session already exists, handling request directly")
|
||||
await transport.handle_request(scope, receive, send)
|
||||
return
|
||||
|
||||
if request_mcp_session_id is None:
|
||||
# New session case
|
||||
logger.debug("Creating new transport")
|
||||
async with self._session_creation_lock:
|
||||
new_session_id = uuid4().hex
|
||||
http_transport = StreamableHTTPServerTransport(
|
||||
mcp_session_id=new_session_id,
|
||||
is_json_response_enabled=self.json_response,
|
||||
event_store=self.event_store, # May be None (no resumability)
|
||||
)
|
||||
|
||||
assert http_transport.mcp_session_id is not None
|
||||
self._server_instances[http_transport.mcp_session_id] = http_transport
|
||||
logger.info(f"Created new transport with session ID: {new_session_id}")
|
||||
|
||||
# Define the server runner
|
||||
async def run_server(
|
||||
*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED
|
||||
) -> None:
|
||||
async with http_transport.connect() as streams:
|
||||
read_stream, write_stream = streams
|
||||
task_status.started()
|
||||
await self.app.run(
|
||||
read_stream,
|
||||
write_stream,
|
||||
self.app.create_initialization_options(),
|
||||
stateless=False, # Stateful mode
|
||||
)
|
||||
|
||||
# Assert task group is not None for type checking
|
||||
assert self._task_group is not None
|
||||
# Start the server task
|
||||
await self._task_group.start(run_server)
|
||||
|
||||
# Handle the HTTP request and return the response
|
||||
await http_transport.handle_request(scope, receive, send)
|
||||
else:
|
||||
# Invalid session ID
|
||||
response = Response(
|
||||
"Bad Request: No valid session ID provided",
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
@@ -15,6 +15,7 @@ import uvicorn
|
||||
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
from mcp.types import InitializeResult, TextContent
|
||||
|
||||
@@ -33,6 +34,34 @@ def server_url(server_port: int) -> str:
|
||||
return f"http://127.0.0.1:{server_port}"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def http_server_port() -> int:
|
||||
"""Get a free port for testing the StreamableHTTP server."""
|
||||
with socket.socket() as s:
|
||||
s.bind(("127.0.0.1", 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def http_server_url(http_server_port: int) -> str:
|
||||
"""Get the StreamableHTTP server URL for testing."""
|
||||
return f"http://127.0.0.1:{http_server_port}"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def stateless_http_server_port() -> int:
|
||||
"""Get a free port for testing the stateless StreamableHTTP server."""
|
||||
with socket.socket() as s:
|
||||
s.bind(("127.0.0.1", 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def stateless_http_server_url(stateless_http_server_port: int) -> str:
|
||||
"""Get the stateless StreamableHTTP server URL for testing."""
|
||||
return f"http://127.0.0.1:{stateless_http_server_port}"
|
||||
|
||||
|
||||
# Create a function to make the FastMCP server app
|
||||
def make_fastmcp_app():
|
||||
"""Create a FastMCP server without auth settings."""
|
||||
@@ -51,6 +80,40 @@ def make_fastmcp_app():
|
||||
return mcp, app
|
||||
|
||||
|
||||
def make_fastmcp_streamable_http_app():
|
||||
"""Create a FastMCP server with StreamableHTTP transport."""
|
||||
from starlette.applications import Starlette
|
||||
|
||||
mcp = FastMCP(name="NoAuthServer")
|
||||
|
||||
# Add a simple tool
|
||||
@mcp.tool(description="A simple echo tool")
|
||||
def echo(message: str) -> str:
|
||||
return f"Echo: {message}"
|
||||
|
||||
# Create the StreamableHTTP app
|
||||
app: Starlette = mcp.streamable_http_app()
|
||||
|
||||
return mcp, app
|
||||
|
||||
|
||||
def make_fastmcp_stateless_http_app():
|
||||
"""Create a FastMCP server with stateless StreamableHTTP transport."""
|
||||
from starlette.applications import Starlette
|
||||
|
||||
mcp = FastMCP(name="StatelessServer", stateless_http=True)
|
||||
|
||||
# Add a simple tool
|
||||
@mcp.tool(description="A simple echo tool")
|
||||
def echo(message: str) -> str:
|
||||
return f"Echo: {message}"
|
||||
|
||||
# Create the StreamableHTTP app
|
||||
app: Starlette = mcp.streamable_http_app()
|
||||
|
||||
return mcp, app
|
||||
|
||||
|
||||
def run_server(server_port: int) -> None:
|
||||
"""Run the server."""
|
||||
_, app = make_fastmcp_app()
|
||||
@@ -63,6 +126,30 @@ def run_server(server_port: int) -> None:
|
||||
server.run()
|
||||
|
||||
|
||||
def run_streamable_http_server(server_port: int) -> None:
|
||||
"""Run the StreamableHTTP server."""
|
||||
_, app = make_fastmcp_streamable_http_app()
|
||||
server = uvicorn.Server(
|
||||
config=uvicorn.Config(
|
||||
app=app, host="127.0.0.1", port=server_port, log_level="error"
|
||||
)
|
||||
)
|
||||
print(f"Starting StreamableHTTP server on port {server_port}")
|
||||
server.run()
|
||||
|
||||
|
||||
def run_stateless_http_server(server_port: int) -> None:
|
||||
"""Run the stateless StreamableHTTP server."""
|
||||
_, app = make_fastmcp_stateless_http_app()
|
||||
server = uvicorn.Server(
|
||||
config=uvicorn.Config(
|
||||
app=app, host="127.0.0.1", port=server_port, log_level="error"
|
||||
)
|
||||
)
|
||||
print(f"Starting stateless StreamableHTTP server on port {server_port}")
|
||||
server.run()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def server(server_port: int) -> Generator[None, None, None]:
|
||||
"""Start the server in a separate process and clean up after the test."""
|
||||
@@ -94,6 +181,80 @@ def server(server_port: int) -> Generator[None, None, None]:
|
||||
print("Server process failed to terminate")
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def streamable_http_server(http_server_port: int) -> Generator[None, None, None]:
|
||||
"""Start the StreamableHTTP server in a separate process."""
|
||||
proc = multiprocessing.Process(
|
||||
target=run_streamable_http_server, args=(http_server_port,), daemon=True
|
||||
)
|
||||
print("Starting StreamableHTTP server process")
|
||||
proc.start()
|
||||
|
||||
# Wait for server to be running
|
||||
max_attempts = 20
|
||||
attempt = 0
|
||||
print("Waiting for StreamableHTTP server to start")
|
||||
while attempt < max_attempts:
|
||||
try:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.connect(("127.0.0.1", http_server_port))
|
||||
break
|
||||
except ConnectionRefusedError:
|
||||
time.sleep(0.1)
|
||||
attempt += 1
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"StreamableHTTP server failed to start after {max_attempts} attempts"
|
||||
)
|
||||
|
||||
yield
|
||||
|
||||
print("Killing StreamableHTTP server")
|
||||
proc.kill()
|
||||
proc.join(timeout=2)
|
||||
if proc.is_alive():
|
||||
print("StreamableHTTP server process failed to terminate")
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def stateless_http_server(
|
||||
stateless_http_server_port: int,
|
||||
) -> Generator[None, None, None]:
|
||||
"""Start the stateless StreamableHTTP server in a separate process."""
|
||||
proc = multiprocessing.Process(
|
||||
target=run_stateless_http_server,
|
||||
args=(stateless_http_server_port,),
|
||||
daemon=True,
|
||||
)
|
||||
print("Starting stateless StreamableHTTP server process")
|
||||
proc.start()
|
||||
|
||||
# Wait for server to be running
|
||||
max_attempts = 20
|
||||
attempt = 0
|
||||
print("Waiting for stateless StreamableHTTP server to start")
|
||||
while attempt < max_attempts:
|
||||
try:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.connect(("127.0.0.1", stateless_http_server_port))
|
||||
break
|
||||
except ConnectionRefusedError:
|
||||
time.sleep(0.1)
|
||||
attempt += 1
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Stateless server failed to start after {max_attempts} attempts"
|
||||
)
|
||||
|
||||
yield
|
||||
|
||||
print("Killing stateless StreamableHTTP server")
|
||||
proc.kill()
|
||||
proc.join(timeout=2)
|
||||
if proc.is_alive():
|
||||
print("Stateless StreamableHTTP server process failed to terminate")
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_fastmcp_without_auth(server: None, server_url: str) -> None:
|
||||
"""Test that FastMCP works when auth settings are not provided."""
|
||||
@@ -110,3 +271,55 @@ async def test_fastmcp_without_auth(server: None, server_url: str) -> None:
|
||||
assert len(tool_result.content) == 1
|
||||
assert isinstance(tool_result.content[0], TextContent)
|
||||
assert tool_result.content[0].text == "Echo: hello"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_fastmcp_streamable_http(
|
||||
streamable_http_server: None, http_server_url: str
|
||||
) -> None:
|
||||
"""Test that FastMCP works with StreamableHTTP transport."""
|
||||
# Connect to the server using StreamableHTTP
|
||||
async with streamablehttp_client(http_server_url + "/mcp") as (
|
||||
read_stream,
|
||||
write_stream,
|
||||
_,
|
||||
):
|
||||
# Create a session using the client streams
|
||||
async with ClientSession(read_stream, write_stream) as session:
|
||||
# Test initialization
|
||||
result = await session.initialize()
|
||||
assert isinstance(result, InitializeResult)
|
||||
assert result.serverInfo.name == "NoAuthServer"
|
||||
|
||||
# Test that we can call tools without authentication
|
||||
tool_result = await session.call_tool("echo", {"message": "hello"})
|
||||
assert len(tool_result.content) == 1
|
||||
assert isinstance(tool_result.content[0], TextContent)
|
||||
assert tool_result.content[0].text == "Echo: hello"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_fastmcp_stateless_streamable_http(
|
||||
stateless_http_server: None, stateless_http_server_url: str
|
||||
) -> None:
|
||||
"""Test that FastMCP works with stateless StreamableHTTP transport."""
|
||||
# Connect to the server using StreamableHTTP
|
||||
async with streamablehttp_client(stateless_http_server_url + "/mcp") as (
|
||||
read_stream,
|
||||
write_stream,
|
||||
_,
|
||||
):
|
||||
async with ClientSession(read_stream, write_stream) as session:
|
||||
result = await session.initialize()
|
||||
assert isinstance(result, InitializeResult)
|
||||
assert result.serverInfo.name == "StatelessServer"
|
||||
tool_result = await session.call_tool("echo", {"message": "hello"})
|
||||
assert len(tool_result.content) == 1
|
||||
assert isinstance(tool_result.content[0], TextContent)
|
||||
assert tool_result.content[0].text == "Echo: hello"
|
||||
|
||||
for i in range(3):
|
||||
tool_result = await session.call_tool("echo", {"message": f"test_{i}"})
|
||||
assert len(tool_result.content) == 1
|
||||
assert isinstance(tool_result.content[0], TextContent)
|
||||
assert tool_result.content[0].text == f"Echo: test_{i}"
|
||||
|
||||
81
tests/server/test_streamable_http_manager.py
Normal file
81
tests/server/test_streamable_http_manager.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""Tests for StreamableHTTPSessionManager."""
|
||||
|
||||
import anyio
|
||||
import pytest
|
||||
|
||||
from mcp.server.lowlevel import Server
|
||||
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_run_can_only_be_called_once():
|
||||
"""Test that run() can only be called once per instance."""
|
||||
app = Server("test-server")
|
||||
manager = StreamableHTTPSessionManager(app=app)
|
||||
|
||||
# First call should succeed
|
||||
async with manager.run():
|
||||
pass
|
||||
|
||||
# Second call should raise RuntimeError
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
async with manager.run():
|
||||
pass
|
||||
|
||||
assert (
|
||||
"StreamableHTTPSessionManager .run() can only be called once per instance"
|
||||
in str(excinfo.value)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_run_prevents_concurrent_calls():
|
||||
"""Test that concurrent calls to run() are prevented."""
|
||||
app = Server("test-server")
|
||||
manager = StreamableHTTPSessionManager(app=app)
|
||||
|
||||
errors = []
|
||||
|
||||
async def try_run():
|
||||
try:
|
||||
async with manager.run():
|
||||
# Simulate some work
|
||||
await anyio.sleep(0.1)
|
||||
except RuntimeError as e:
|
||||
errors.append(e)
|
||||
|
||||
# Try to run concurrently
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(try_run)
|
||||
tg.start_soon(try_run)
|
||||
|
||||
# One should succeed, one should fail
|
||||
assert len(errors) == 1
|
||||
assert (
|
||||
"StreamableHTTPSessionManager .run() can only be called once per instance"
|
||||
in str(errors[0])
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_handle_request_without_run_raises_error():
|
||||
"""Test that handle_request raises error if run() hasn't been called."""
|
||||
app = Server("test-server")
|
||||
manager = StreamableHTTPSessionManager(app=app)
|
||||
|
||||
# Mock ASGI parameters
|
||||
scope = {"type": "http", "method": "POST", "path": "/test"}
|
||||
|
||||
async def receive():
|
||||
return {"type": "http.request", "body": b""}
|
||||
|
||||
async def send(message):
|
||||
pass
|
||||
|
||||
# Should raise error because run() hasn't been called
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
await manager.handle_request(scope, receive, send)
|
||||
|
||||
assert "Task group is not initialized. Make sure to use run()." in str(
|
||||
excinfo.value
|
||||
)
|
||||
@@ -4,13 +4,10 @@ Tests for the StreamableHTTP server and client transport.
|
||||
Contains tests for both server and client sides of the StreamableHTTP transport.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import multiprocessing
|
||||
import socket
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from http import HTTPStatus
|
||||
from uuid import uuid4
|
||||
|
||||
import anyio
|
||||
import httpx
|
||||
@@ -19,8 +16,6 @@ import requests
|
||||
import uvicorn
|
||||
from pydantic import AnyUrl
|
||||
from starlette.applications import Starlette
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.routing import Mount
|
||||
|
||||
import mcp.types as types
|
||||
@@ -37,6 +32,7 @@ from mcp.server.streamable_http import (
|
||||
StreamableHTTPServerTransport,
|
||||
StreamId,
|
||||
)
|
||||
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
|
||||
from mcp.shared.exceptions import McpError
|
||||
from mcp.shared.message import (
|
||||
ClientMessageMetadata,
|
||||
@@ -184,7 +180,7 @@ class ServerTest(Server):
|
||||
def create_app(
|
||||
is_json_response_enabled=False, event_store: EventStore | None = None
|
||||
) -> Starlette:
|
||||
"""Create a Starlette application for testing that matches the example server.
|
||||
"""Create a Starlette application for testing using the session manager.
|
||||
|
||||
Args:
|
||||
is_json_response_enabled: If True, use JSON responses instead of SSE streams.
|
||||
@@ -193,85 +189,20 @@ def create_app(
|
||||
# Create server instance
|
||||
server = ServerTest()
|
||||
|
||||
server_instances = {}
|
||||
# Lock to prevent race conditions when creating new sessions
|
||||
session_creation_lock = anyio.Lock()
|
||||
task_group = None
|
||||
# Create the session manager
|
||||
session_manager = StreamableHTTPSessionManager(
|
||||
app=server,
|
||||
event_store=event_store,
|
||||
json_response=is_json_response_enabled,
|
||||
)
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def lifespan(app):
|
||||
"""Application lifespan context manager for managing task group."""
|
||||
nonlocal task_group
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
task_group = tg
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if task_group:
|
||||
tg.cancel_scope.cancel()
|
||||
task_group = None
|
||||
|
||||
async def handle_streamable_http(scope, receive, send):
|
||||
request = Request(scope, receive)
|
||||
request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER)
|
||||
|
||||
# Use existing transport if session ID matches
|
||||
if (
|
||||
request_mcp_session_id is not None
|
||||
and request_mcp_session_id in server_instances
|
||||
):
|
||||
transport = server_instances[request_mcp_session_id]
|
||||
|
||||
await transport.handle_request(scope, receive, send)
|
||||
elif request_mcp_session_id is None:
|
||||
async with session_creation_lock:
|
||||
new_session_id = uuid4().hex
|
||||
|
||||
http_transport = StreamableHTTPServerTransport(
|
||||
mcp_session_id=new_session_id,
|
||||
is_json_response_enabled=is_json_response_enabled,
|
||||
event_store=event_store,
|
||||
)
|
||||
|
||||
async def run_server(task_status=None):
|
||||
async with http_transport.connect() as streams:
|
||||
read_stream, write_stream = streams
|
||||
if task_status:
|
||||
task_status.started()
|
||||
await server.run(
|
||||
read_stream,
|
||||
write_stream,
|
||||
server.create_initialization_options(),
|
||||
)
|
||||
|
||||
if task_group is None:
|
||||
response = Response(
|
||||
"Internal Server Error: Task group is not initialized",
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
# Store the instance before starting the task to prevent races
|
||||
server_instances[http_transport.mcp_session_id] = http_transport
|
||||
await task_group.start(run_server)
|
||||
|
||||
await http_transport.handle_request(scope, receive, send)
|
||||
else:
|
||||
response = Response(
|
||||
"Bad Request: No valid session ID provided",
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
|
||||
# Create an ASGI application
|
||||
# Create an ASGI application that uses the session manager
|
||||
app = Starlette(
|
||||
debug=True,
|
||||
routes=[
|
||||
Mount("/mcp", app=handle_streamable_http),
|
||||
Mount("/mcp", app=session_manager.handle_request),
|
||||
],
|
||||
lifespan=lifespan,
|
||||
lifespan=lambda app: session_manager.run(),
|
||||
)
|
||||
|
||||
return app
|
||||
|
||||
Reference in New Issue
Block a user