Files
mcp-python-sdk/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py

206 lines
7.1 KiB
Python

import contextlib
import logging
from http import HTTPStatus
from uuid import uuid4
import anyio
import click
import mcp.types as types
from mcp.server.lowlevel import Server
from mcp.server.streamableHttp import (
MCP_SESSION_ID_HEADER,
StreamableHTTPServerTransport,
)
from pydantic import AnyUrl
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Mount
# Module-level logger. Handlers and level are configured by
# ``logging.basicConfig`` in ``main``, not here.
logger = logging.getLogger(__name__)

# Process-wide anyio task group. ``lifespan`` assigns it on startup and resets
# it to None on shutdown; the HTTP handler spawns per-session MCP servers into it.
task_group = None
@contextlib.asynccontextmanager
async def lifespan(app):
    """Own the shared anyio task group for the lifetime of the Starlette app.

    On startup a new task group is created and published through the
    module-level ``task_group`` global so request handlers can spawn
    background tasks into it. On shutdown the group's cancel scope is
    cancelled and the global is cleared before the group itself exits.

    Args:
        app: The application instance passed in by Starlette; not used here.
    """
    global task_group

    async with anyio.create_task_group() as group:
        task_group = group
        logger.info("Application started, task group initialized!")
        try:
            yield
        finally:
            logger.info("Application shutting down, cleaning up resources...")
            # Only tear down if the global still refers to a live group.
            if task_group:
                group.cancel_scope.cancel()
                task_group = None
            logger.info("Resources cleaned up successfully.")
@click.command()
@click.option("--port", default=3000, help="Port to listen on for HTTP")
@click.option(
    "--log-level",
    default="INFO",
    help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)",
)
@click.option(
    "--json-response",
    is_flag=True,
    default=False,
    help="Enable JSON responses instead of SSE streams",
)
def main(
    port: int,
    log_level: str,
    json_response: bool,
) -> int:
    """Run the demo MCP server over the streamable HTTP transport.

    Builds a low-level MCP ``Server`` exposing one tool
    (``start-notification-stream``), mounts an ASGI handler for it at
    ``/mcp`` on a Starlette app, and serves that app with uvicorn.

    Args:
        port: TCP port uvicorn listens on (bound to ``0.0.0.0``).
        log_level: Name of a ``logging`` level; matched case-insensitively.
        json_response: Passed to each transport as
            ``is_json_response_enabled`` (JSON responses instead of SSE).

    Returns:
        0 once ``uvicorn.run`` returns.
    """
    # Configure logging
    logging.basicConfig(
        level=getattr(logging, log_level.upper()),
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    app = Server("mcp-streamable-http-demo")

    @app.call_tool()
    async def call_tool(
        name: str, arguments: dict
    ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
        """Send ``count`` log notifications ``interval`` seconds apart.

        After the log notifications, a resource-updated notification is sent,
        and a text summary is returned.

        Raises:
            ValueError: If ``name`` is not a tool this server advertises.
        """
        if name != "start-notification-stream":
            raise ValueError(f"Unknown tool: {name}")

        ctx = app.request_context
        interval = arguments.get("interval", 1.0)
        # The input schema types ``count`` as a JSON "number", so clients may
        # legitimately send e.g. 5.0; ``range`` only accepts integers.
        count = int(arguments.get("count", 5))
        caller = arguments.get("caller", "unknown")

        # Send the specified number of notifications with the given interval
        for i in range(count):
            await ctx.session.send_log_message(
                level="info",
                data=f"Notification {i+1}/{count} from caller: {caller}",
                logger="notification_stream",
                # Associates this notification with the original request
                # Ensures notifications are sent to the correct response stream
                # Without this, notifications will either go to:
                # - a standalone SSE stream (if GET request is supported)
                # - nowhere (if GET request isn't supported)
                related_request_id=ctx.request_id,
            )
            if i < count - 1:  # Don't wait after the last notification
                await anyio.sleep(interval)

        # This will send a resource notification through the standalone SSE
        # stream established by a GET request
        await ctx.session.send_resource_updated(uri=AnyUrl("http:///test_resource"))
        return [
            types.TextContent(
                type="text",
                text=(
                    f"Sent {count} notifications with {interval}s interval"
                    f" for caller: {caller}"
                ),
            )
        ]

    @app.list_tools()
    async def list_tools() -> list[types.Tool]:
        """Advertise the single notification-stream tool and its input schema."""
        return [
            types.Tool(
                name="start-notification-stream",
                description=(
                    "Sends a stream of notifications with configurable count"
                    " and interval"
                ),
                inputSchema={
                    "type": "object",
                    "required": ["interval", "count", "caller"],
                    "properties": {
                        "interval": {
                            "type": "number",
                            "description": "Interval between notifications in seconds",
                        },
                        "count": {
                            "type": "number",
                            "description": "Number of notifications to send",
                        },
                        "caller": {
                            "type": "string",
                            "description": (
                                "Identifier of the caller to include in notifications"
                            ),
                        },
                    },
                },
            )
        ]

    # Live transports keyed by MCP session ID, shared across requests.
    server_instances: dict[str, StreamableHTTPServerTransport] = {}
    # Serializes creation, start-up and registration of new sessions.
    session_creation_lock = anyio.Lock()

    async def handle_streamable_http(scope, receive, send):
        """ASGI handler that routes each request to its session's transport.

        - Header ``MCP_SESSION_ID_HEADER`` names a known session: delegate to
          that session's transport.
        - No session header: create a transport, start its MCP server in the
          global task group, register it, then let it handle this request.
        - Header present but unknown: respond 400 Bad Request.

        Raises:
            RuntimeError: If a new session is requested while the lifespan
                task group is not available.
        """
        request = Request(scope, receive)
        request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER)

        if (
            request_mcp_session_id is not None
            and request_mcp_session_id in server_instances
        ):
            transport = server_instances[request_mcp_session_id]
            logger.debug("Session already exists, handling request directly")
            await transport.handle_request(scope, receive, send)
        elif request_mcp_session_id is None:
            # try to establish new session
            logger.debug("Creating new transport")
            # Use lock to prevent race conditions when creating new sessions
            async with session_creation_lock:
                # Check before creating/registering anything so a failure
                # cannot leave a dead transport in ``server_instances``.
                if not task_group:
                    raise RuntimeError("Task group is not initialized")

                new_session_id = uuid4().hex
                http_transport = StreamableHTTPServerTransport(
                    mcp_session_id=new_session_id,
                    is_json_response_enabled=json_response,
                )

                async def run_server(*, task_status=anyio.TASK_STATUS_IGNORED):
                    # The transport's streams must stay open for the whole
                    # session, so connect() is entered inside this long-lived
                    # background task instead of in the request handler, whose
                    # scope ends as soon as the first response is sent.
                    async with http_transport.connect() as streams:
                        read_stream, write_stream = streams
                        task_status.started()
                        await app.run(
                            read_stream,
                            write_stream,
                            app.create_initialization_options(),
                        )

                # start() returns only after task_status.started(), i.e. once
                # the streams are open and the transport can serve requests.
                await task_group.start(run_server)
                server_instances[http_transport.mcp_session_id] = http_transport

            # Handle the HTTP request and return the response. Done outside
            # the lock so a long-running first response (e.g. an SSE stream)
            # does not block creation of other sessions.
            await http_transport.handle_request(scope, receive, send)
        else:
            response = Response(
                "Bad Request: No valid session ID provided",
                status_code=HTTPStatus.BAD_REQUEST,
            )
            await response(scope, receive, send)

    # Create an ASGI application using the transport
    starlette_app = Starlette(
        debug=True,
        routes=[
            Mount("/mcp", app=handle_streamable_http),
        ],
        lifespan=lifespan,
    )

    import uvicorn

    uvicorn.run(starlette_app, host="0.0.0.0", port=port)
    return 0