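"""Example MCP server demonstrating the Streamable HTTP transport.

Runs a low-level MCP server behind Starlette/uvicorn, creating one
StreamableHTTPServerTransport per session and using an in-memory event
store so clients can resume dropped SSE streams.
"""
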
import contextlib
import logging
from http import HTTPStatus
from uuid import uuid4

import anyio
import click
import mcp.types as types
from mcp.server.lowlevel import Server
from mcp.server.streamable_http import (
    MCP_SESSION_ID_HEADER,
    StreamableHTTPServerTransport,
)
from pydantic import AnyUrl
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Mount

from .event_store import InMemoryEventStore

# Configure logging
logger = logging.getLogger(__name__)

# Global task group that will be initialized in the lifespan
task_group = None

# Event store for resumability
# The InMemoryEventStore enables resumability support for StreamableHTTP transport.
# It stores SSE events with unique IDs, allowing clients to:
#   1. Receive event IDs for each SSE message
#   2. Resume streams by sending Last-Event-ID in GET requests
#   3. Replay missed events after reconnection
# Note: This in-memory implementation is for demonstration ONLY.
# For production, use a persistent storage solution.
event_store = InMemoryEventStore()


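# The task group must outlive any single HTTP request: per-session MCP
# servers started in handle_streamable_http keep running on it until the
# application shuts down.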
@contextlib.asynccontextmanager
async def lifespan(app):
    """Application lifespan context manager for managing the task group."""
    global task_group

    async with anyio.create_task_group() as tg:
        task_group = tg
        logger.info("Application started, task group initialized!")
        try:
            yield
        finally:
            logger.info("Application shutting down, cleaning up resources...")
            if task_group:
                tg.cancel_scope.cancel()
                task_group = None
            logger.info("Resources cleaned up successfully.")


@click.command()
@click.option("--port", default=3000, help="Port to listen on for HTTP")
@click.option(
    "--log-level",
    default="INFO",
    help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)",
)
@click.option(
    "--json-response",
    is_flag=True,
    default=False,
    help="Enable JSON responses instead of SSE streams",
)
def main(
    port: int,
    log_level: str,
    json_response: bool,
) -> int:
    # Configure logging
    logging.basicConfig(
        level=getattr(logging, log_level.upper()),
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    app = Server("mcp-streamable-http-demo")
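
    # Low-level Server API: handlers registered via the decorators below are
    # dispatched by app.run() once a transport's streams are connected.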

    @app.call_tool()
    async def call_tool(
        name: str, arguments: dict
    ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
        ctx = app.request_context
        interval = arguments.get("interval", 1.0)
        count = arguments.get("count", 5)
        caller = arguments.get("caller", "unknown")
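        # Defensive defaults above: the input schema in list_tools() marks all
        # three arguments as required, so the fallbacks should rarely trigger.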

        # Send the specified number of notifications with the given interval
        for i in range(count):
            # Include a more detailed message for the resumability demonstration
            notification_msg = (
                f"[{i + 1}/{count}] Event from '{caller}' - "
                "Use Last-Event-ID to resume if disconnected"
            )
            await ctx.session.send_log_message(
                level="info",
                data=notification_msg,
                logger="notification_stream",
                # Associates this notification with the original request,
                # ensuring it is sent to the correct response stream.
                # Without this, notifications would either go to:
                #   - a standalone SSE stream (if a GET request is supported)
                #   - nowhere (if GET requests aren't supported)
                related_request_id=ctx.request_id,
            )
            logger.debug(f"Sent notification {i + 1}/{count} for caller: {caller}")
            if i < count - 1:  # Don't wait after the last notification
                await anyio.sleep(interval)

        # This will send a resource notification through the standalone SSE
        # stream established by the GET request
        await ctx.session.send_resource_updated(uri=AnyUrl("http:///test_resource"))
        return [
            types.TextContent(
                type="text",
                text=(
                    f"Sent {count} notifications with {interval}s interval"
                    f" for caller: {caller}"
                ),
            )
        ]

    @app.list_tools()
    async def list_tools() -> list[types.Tool]:
        return [
            types.Tool(
                name="start-notification-stream",
                description=(
                    "Sends a stream of notifications with configurable count"
                    " and interval"
                ),
                inputSchema={
                    "type": "object",
                    "required": ["interval", "count", "caller"],
                    "properties": {
                        "interval": {
                            "type": "number",
                            "description": "Interval between notifications in seconds",
                        },
                        "count": {
                            "type": "number",
                            "description": "Number of notifications to send",
                        },
                        "caller": {
                            "type": "string",
                            "description": (
                                "Identifier of the caller to include in notifications"
                            ),
                        },
                    },
                },
            )
        ]

    # We need to store the server instances between requests
    server_instances = {}
    # Lock to prevent race conditions when creating new sessions
    session_creation_lock = anyio.Lock()
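    # Note: this demo never evicts entries, so server_instances grows for the
    # lifetime of the process; a production server would expire idle sessions.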

    # ASGI handler for streamable HTTP connections
    async def handle_streamable_http(scope, receive, send):
        request = Request(scope, receive)
        request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER)
        if (
            request_mcp_session_id is not None
            and request_mcp_session_id in server_instances
        ):
            transport = server_instances[request_mcp_session_id]
            logger.debug("Session already exists, handling request directly")
            await transport.handle_request(scope, receive, send)
        elif request_mcp_session_id is None:
            # Try to establish a new session
            logger.debug("Creating new transport")
            # Use lock to prevent race conditions when creating new sessions
            async with session_creation_lock:
                new_session_id = uuid4().hex
                http_transport = StreamableHTTPServerTransport(
                    mcp_session_id=new_session_id,
                    is_json_response_enabled=json_response,
                    event_store=event_store,  # Enable resumability
                )
                server_instances[http_transport.mcp_session_id] = http_transport
                logger.info(f"Created new transport with session ID: {new_session_id}")
                async with http_transport.connect() as streams:
                    read_stream, write_stream = streams

                    async def run_server():
                        await app.run(
                            read_stream,
                            write_stream,
                            app.create_initialization_options(),
                        )

                    if not task_group:
                        raise RuntimeError("Task group is not initialized")

                    task_group.start_soon(run_server)
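                    # The server runs on the global task group so it outlives
                    # this request; awaiting run_server() here would block
                    # until the whole session ended.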

                    # Handle the HTTP request and return the response
                    await http_transport.handle_request(scope, receive, send)
        else:
            response = Response(
                "Bad Request: No valid session ID provided",
                status_code=HTTPStatus.BAD_REQUEST,
            )
            await response(scope, receive, send)
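
    # Typical session flow (illustrative; header names come from the SDK
    # constants imported above):
    #   1. The client POSTs an `initialize` JSON-RPC request to /mcp with no
    #      mcp-session-id header; a new transport is created and the response
    #      returns the session ID in that header.
    #   2. Subsequent requests carry the mcp-session-id header and are routed
    #      to the stored transport.
    #   3. If an SSE stream drops, the client reconnects with a Last-Event-ID
    #      header and the event store replays the missed events.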

    # Create an ASGI application using the transport
    starlette_app = Starlette(
        debug=True,
        routes=[
            Mount("/mcp", app=handle_streamable_http),
        ],
        lifespan=lifespan,
    )

    import uvicorn

    uvicorn.run(starlette_app, host="0.0.0.0", port=port)

    return 0
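

if __name__ == "__main__":
    # Convenience entry point (a sketch; a packaged example would typically
    # expose `main` as a console script instead). Because of the relative
    # import above, run this as a module, e.g. `python -m <package>.server`.
    main()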