import contextlib
import logging
from http import HTTPStatus
from uuid import uuid4

import anyio
import click
import mcp.types as types
from mcp.server.lowlevel import Server
from mcp.server.streamable_http import (
    MCP_SESSION_ID_HEADER,
    StreamableHTTPServerTransport,
)
from pydantic import AnyUrl
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Mount

from .event_store import InMemoryEventStore

# Configure logging
logger = logging.getLogger(__name__)

# Global task group that will be initialized in the lifespan
task_group = None

# Event store for resumability
# The InMemoryEventStore enables resumability support for the StreamableHTTP transport.
# It stores SSE events with unique IDs, allowing clients to:
#   1. Receive event IDs for each SSE message
#   2. Resume streams by sending Last-Event-ID in GET requests
#   3. Replay missed events after reconnection
# Note: This in-memory implementation is for demonstration ONLY.
# For production, use a persistent storage solution.
event_store = InMemoryEventStore()


@contextlib.asynccontextmanager
async def lifespan(app):
    """Application lifespan context manager for managing the task group."""
    global task_group

    async with anyio.create_task_group() as tg:
        task_group = tg
        logger.info("Application started, task group initialized!")
        try:
            yield
        finally:
            logger.info("Application shutting down, cleaning up resources...")
            if task_group:
                tg.cancel_scope.cancel()
                task_group = None
            logger.info("Resources cleaned up successfully.")


@click.command()
@click.option("--port", default=3000, help="Port to listen on for HTTP")
@click.option(
    "--log-level",
    default="INFO",
    help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)",
)
@click.option(
    "--json-response",
    is_flag=True,
    default=False,
    help="Enable JSON responses instead of SSE streams",
)
def main(
    port: int,
    log_level: str,
    json_response: bool,
) -> int:
    # Configure logging
    logging.basicConfig(
        level=getattr(logging, log_level.upper()),
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    app = Server("mcp-streamable-http-demo")

    @app.call_tool()
    async def call_tool(
        name: str, arguments: dict
    ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
        ctx = app.request_context
        interval = arguments.get("interval", 1.0)
        count = arguments.get("count", 5)
        caller = arguments.get("caller", "unknown")

        # Send the specified number of notifications with the given interval
        for i in range(count):
            # Include a more detailed message for the resumability demonstration
            notification_msg = (
                f"[{i + 1}/{count}] Event from '{caller}' - "
                f"Use Last-Event-ID to resume if disconnected"
            )
            await ctx.session.send_log_message(
                level="info",
                data=notification_msg,
                logger="notification_stream",
                # Associates this notification with the original request.
                # Ensures notifications are sent to the correct response stream.
                # Without this, notifications would either go to:
                #   - a standalone SSE stream (if a GET request is supported)
                #   - nowhere (if a GET request isn't supported)
                related_request_id=ctx.request_id,
            )
            logger.debug(f"Sent notification {i + 1}/{count} for caller: {caller}")
            if i < count - 1:  # Don't wait after the last notification
                await anyio.sleep(interval)

        # This sends a resource update notification through the standalone SSE stream
        # established by a GET request
        await ctx.session.send_resource_updated(uri=AnyUrl("http:///test_resource"))
        return [
            types.TextContent(
                type="text",
                text=(
                    f"Sent {count} notifications with {interval}s interval"
                    f" for caller: {caller}"
                ),
            )
        ]

    @app.list_tools()
    async def list_tools() -> list[types.Tool]:
        return [
            types.Tool(
                name="start-notification-stream",
                description=(
                    "Sends a stream of notifications with configurable count"
                    " and interval"
                ),
                inputSchema={
                    "type": "object",
                    "required": ["interval", "count", "caller"],
                    "properties": {
                        "interval": {
                            "type": "number",
                            "description": "Interval between notifications in seconds",
                        },
                        "count": {
                            "type": "number",
                            "description": "Number of notifications to send",
                        },
                        "caller": {
                            "type": "string",
                            "description": (
                                "Identifier of the caller to include in notifications"
                            ),
                        },
                    },
                },
            )
        ]

    # We need to store the server instances between requests
    server_instances = {}
    # Lock to prevent race conditions when creating new sessions
    session_creation_lock = anyio.Lock()

    # ASGI handler for streamable HTTP connections
    async def handle_streamable_http(scope, receive, send):
        request = Request(scope, receive)
        request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER)
        if (
            request_mcp_session_id is not None
            and request_mcp_session_id in server_instances
        ):
            transport = server_instances[request_mcp_session_id]
            logger.debug("Session already exists, handling request directly")
            await transport.handle_request(scope, receive, send)
        elif request_mcp_session_id is None:
            # Try to establish a new session
            logger.debug("Creating new transport")
            # Use lock to prevent race conditions when creating new sessions
            async with session_creation_lock:
                new_session_id = uuid4().hex
                http_transport = StreamableHTTPServerTransport(
                    mcp_session_id=new_session_id,
                    is_json_response_enabled=json_response,
                    event_store=event_store,  # Enable resumability
                )
                server_instances[http_transport.mcp_session_id] = http_transport
                logger.info(f"Created new transport with session ID: {new_session_id}")

                async def run_server(task_status=None):
                    async with http_transport.connect() as streams:
                        read_stream, write_stream = streams
                        if task_status:
                            task_status.started()
                        await app.run(
                            read_stream,
                            write_stream,
                            app.create_initialization_options(),
                        )

                if not task_group:
                    raise RuntimeError("Task group is not initialized")

                await task_group.start(run_server)

            # Handle the HTTP request and return the response
            await http_transport.handle_request(scope, receive, send)
        else:
            response = Response(
                "Bad Request: No valid session ID provided",
                status_code=HTTPStatus.BAD_REQUEST,
            )
            await response(scope, receive, send)

    # Create an ASGI application using the transport
    starlette_app = Starlette(
        debug=True,
        routes=[
            Mount("/mcp", app=handle_streamable_http),
        ],
        lifespan=lifespan,
    )

    import uvicorn

    uvicorn.run(starlette_app, host="0.0.0.0", port=port)

    return 0