StreamableHttp - Server transport with state management (#553)

This commit is contained in:
ihrpr
2025-05-02 11:58:54 +01:00
committed by GitHub
parent 2210c1be18
commit 78f0b11a09
14 changed files with 1570 additions and 18 deletions

View File

@@ -0,0 +1,37 @@
# MCP Simple StreamableHttp Server Example
A simple MCP server example demonstrating the StreamableHttp transport, which enables HTTP-based communication with MCP servers using streaming.
## Features
- Uses the StreamableHTTP transport for server-client communication
- Supports REST API operations (POST, GET, DELETE) for `/mcp` endpoint
- Task management with anyio task groups
- Ability to send multiple notifications over time to the client
- Proper resource cleanup and lifespan management
## Usage
Start the server on the default or custom port:
```bash
# Using custom port
uv run mcp-simple-streamablehttp --port 3000
# Custom logging level
uv run mcp-simple-streamablehttp --log-level DEBUG
# Enable JSON responses instead of SSE streams
uv run mcp-simple-streamablehttp --json-response
```
The server exposes a tool named "start-notification-stream" that accepts three arguments:
- `interval`: Time between notifications in seconds (e.g., 1.0)
- `count`: Number of notifications to send (e.g., 5)
- `caller`: Identifier string for the caller
## Client
You can connect to this server using an HTTP client, for now only Typescript SDK has streamable HTTP client examples or you can use (Inspector)[https://github.com/modelcontextprotocol/inspector]

View File

@@ -0,0 +1,4 @@
from .server import main
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,201 @@
import contextlib
import logging
from http import HTTPStatus
from uuid import uuid4
import anyio
import click
import mcp.types as types
from mcp.server.lowlevel import Server
from mcp.server.streamableHttp import (
MCP_SESSION_ID_HEADER,
StreamableHTTPServerTransport,
)
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Mount
# Configure logging
logger = logging.getLogger(__name__)
# Global task group that will be initialized in the lifespan
task_group = None
@contextlib.asynccontextmanager
async def lifespan(app):
"""Application lifespan context manager for managing task group."""
global task_group
async with anyio.create_task_group() as tg:
task_group = tg
logger.info("Application started, task group initialized!")
try:
yield
finally:
logger.info("Application shutting down, cleaning up resources...")
if task_group:
tg.cancel_scope.cancel()
task_group = None
logger.info("Resources cleaned up successfully.")
@click.command()
@click.option("--port", default=3000, help="Port to listen on for HTTP")
@click.option(
"--log-level",
default="INFO",
help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)",
)
@click.option(
"--json-response",
is_flag=True,
default=False,
help="Enable JSON responses instead of SSE streams",
)
def main(
port: int,
log_level: str,
json_response: bool,
) -> int:
# Configure logging
logging.basicConfig(
level=getattr(logging, log_level.upper()),
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
app = Server("mcp-streamable-http-demo")
@app.call_tool()
async def call_tool(
name: str, arguments: dict
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
ctx = app.request_context
interval = arguments.get("interval", 1.0)
count = arguments.get("count", 5)
caller = arguments.get("caller", "unknown")
# Send the specified number of notifications with the given interval
for i in range(count):
await ctx.session.send_log_message(
level="info",
data=f"Notification {i+1}/{count} from caller: {caller}",
logger="notification_stream",
# Associates this notification with the original request
# Ensures notifications are sent to the correct response stream
# Without this, notifications will either go to:
# - a standalone SSE stream (if GET request is supported)
# - nowhere (if GET request isn't supported)
related_request_id=ctx.request_id,
)
if i < count - 1: # Don't wait after the last notification
await anyio.sleep(interval)
return [
types.TextContent(
type="text",
text=(
f"Sent {count} notifications with {interval}s interval"
f" for caller: {caller}"
),
)
]
@app.list_tools()
async def list_tools() -> list[types.Tool]:
return [
types.Tool(
name="start-notification-stream",
description=(
"Sends a stream of notifications with configurable count"
" and interval"
),
inputSchema={
"type": "object",
"required": ["interval", "count", "caller"],
"properties": {
"interval": {
"type": "number",
"description": "Interval between notifications in seconds",
},
"count": {
"type": "number",
"description": "Number of notifications to send",
},
"caller": {
"type": "string",
"description": (
"Identifier of the caller to include in notifications"
),
},
},
},
)
]
# We need to store the server instances between requests
server_instances = {}
# Lock to prevent race conditions when creating new sessions
session_creation_lock = anyio.Lock()
# ASGI handler for streamable HTTP connections
async def handle_streamable_http(scope, receive, send):
request = Request(scope, receive)
request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER)
if (
request_mcp_session_id is not None
and request_mcp_session_id in server_instances
):
transport = server_instances[request_mcp_session_id]
logger.debug("Session already exists, handling request directly")
await transport.handle_request(scope, receive, send)
elif request_mcp_session_id is None:
# try to establish new session
logger.debug("Creating new transport")
# Use lock to prevent race conditions when creating new sessions
async with session_creation_lock:
new_session_id = uuid4().hex
http_transport = StreamableHTTPServerTransport(
mcp_session_id=new_session_id,
is_json_response_enabled=json_response,
)
server_instances[http_transport.mcp_session_id] = http_transport
async with http_transport.connect() as streams:
read_stream, write_stream = streams
async def run_server():
await app.run(
read_stream,
write_stream,
app.create_initialization_options(),
)
if not task_group:
raise RuntimeError("Task group is not initialized")
task_group.start_soon(run_server)
# Handle the HTTP request and return the response
await http_transport.handle_request(scope, receive, send)
else:
response = Response(
"Bad Request: No valid session ID provided",
status_code=HTTPStatus.BAD_REQUEST,
)
await response(scope, receive, send)
# Create an ASGI application using the transport
starlette_app = Starlette(
debug=True,
routes=[
Mount("/mcp", app=handle_streamable_http),
],
lifespan=lifespan,
)
import uvicorn
uvicorn.run(starlette_app, host="0.0.0.0", port=port)
return 0

View File

@@ -0,0 +1,36 @@
[project]
name = "mcp-simple-streamablehttp"
version = "0.1.0"
description = "A simple MCP server exposing a StreamableHttp transport for testing"
readme = "README.md"
requires-python = ">=3.10"
authors = [{ name = "Anthropic, PBC." }]
keywords = ["mcp", "llm", "automation", "web", "fetch", "http", "streamable"]
license = { text = "MIT" }
dependencies = ["anyio>=4.5", "click>=8.1.0", "httpx>=0.27", "mcp", "starlette", "uvicorn"]
[project.scripts]
mcp-simple-streamablehttp = "mcp_simple_streamablehttp.server:main"
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["mcp_simple_streamablehttp"]
[tool.pyright]
include = ["mcp_simple_streamablehttp"]
venvPath = "."
venv = ".venv"
[tool.ruff.lint]
select = ["E", "F", "I"]
ignore = []
[tool.ruff]
line-length = 88
target-version = "py310"
[tool.uv]
dev-dependencies = ["pyright>=1.1.378", "pytest>=8.3.3", "ruff>=0.6.9"]

View File

@@ -814,7 +814,10 @@ class Context(BaseModel, Generic[ServerSessionT, LifespanContextT]):
**extra: Additional structured data to include
"""
await self.request_context.session.send_log_message(
level=level, data=message, logger=logger_name
level=level,
data=message,
logger=logger_name,
related_request_id=self.request_id,
)
@property

View File

@@ -179,7 +179,11 @@ class ServerSession(
)
async def send_log_message(
self, level: types.LoggingLevel, data: Any, logger: str | None = None
self,
level: types.LoggingLevel,
data: Any,
logger: str | None = None,
related_request_id: types.RequestId | None = None,
) -> None:
"""Send a log message notification."""
await self.send_notification(
@@ -192,7 +196,8 @@ class ServerSession(
logger=logger,
),
)
)
),
related_request_id,
)
async def send_resource_updated(self, uri: AnyUrl) -> None:
@@ -261,7 +266,11 @@ class ServerSession(
)
async def send_progress_notification(
self, progress_token: str | int, progress: float, total: float | None = None
self,
progress_token: str | int,
progress: float,
total: float | None = None,
related_request_id: str | None = None,
) -> None:
"""Send a progress notification."""
await self.send_notification(
@@ -274,7 +283,8 @@ class ServerSession(
total=total,
),
)
)
),
related_request_id,
)
async def send_resource_list_changed(self) -> None:

View File

@@ -0,0 +1,644 @@
"""
StreamableHTTP Server Transport Module
This module implements an HTTP transport layer with Streamable HTTP.
The transport handles bidirectional communication using HTTP requests and
responses, with streaming support for long-running operations.
"""
import json
import logging
import re
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from http import HTTPStatus
from typing import Any
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import ValidationError
from sse_starlette import EventSourceResponse
from starlette.requests import Request
from starlette.responses import Response
from starlette.types import Receive, Scope, Send
from mcp.types import (
INTERNAL_ERROR,
INVALID_PARAMS,
INVALID_REQUEST,
PARSE_ERROR,
ErrorData,
JSONRPCError,
JSONRPCMessage,
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
RequestId,
)
logger = logging.getLogger(__name__)
# Maximum size for incoming messages
MAXIMUM_MESSAGE_SIZE = 4 * 1024 * 1024 # 4MB
# Header names
MCP_SESSION_ID_HEADER = "mcp-session-id"
LAST_EVENT_ID_HEADER = "last-event-id"
# Content types
CONTENT_TYPE_JSON = "application/json"
CONTENT_TYPE_SSE = "text/event-stream"
# Session ID validation pattern (visible ASCII characters ranging from 0x21 to 0x7E)
# Pattern ensures entire string contains only valid characters by using ^ and $ anchors
SESSION_ID_PATTERN = re.compile(r"^[\x21-\x7E]+$")
class StreamableHTTPServerTransport:
"""
HTTP server transport with event streaming support for MCP.
Handles JSON-RPC messages in HTTP POST requests with SSE streaming.
Supports optional JSON responses and session management.
"""
# Server notification streams for POST requests as well as standalone SSE stream
_read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] | None = (
None
)
_write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] | None = None
def __init__(
self,
mcp_session_id: str | None,
is_json_response_enabled: bool = False,
) -> None:
"""
Initialize a new StreamableHTTP server transport.
Args:
mcp_session_id: Optional session identifier for this connection.
Must contain only visible ASCII characters (0x21-0x7E).
is_json_response_enabled: If True, return JSON responses for requests
instead of SSE streams. Default is False.
Raises:
ValueError: If the session ID contains invalid characters.
"""
if mcp_session_id is not None and not SESSION_ID_PATTERN.fullmatch(
mcp_session_id
):
raise ValueError(
"Session ID must only contain visible ASCII characters (0x21-0x7E)"
)
self.mcp_session_id = mcp_session_id
self.is_json_response_enabled = is_json_response_enabled
self._request_streams: dict[
RequestId, MemoryObjectSendStream[JSONRPCMessage]
] = {}
self._terminated = False
def _create_error_response(
self,
error_message: str,
status_code: HTTPStatus,
error_code: int = INVALID_REQUEST,
headers: dict[str, str] | None = None,
) -> Response:
"""Create an error response with a simple string message."""
response_headers = {"Content-Type": CONTENT_TYPE_JSON}
if headers:
response_headers.update(headers)
if self.mcp_session_id:
response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id
# Return a properly formatted JSON error response
error_response = JSONRPCError(
jsonrpc="2.0",
id="server-error", # We don't have a request ID for general errors
error=ErrorData(
code=error_code,
message=error_message,
),
)
return Response(
error_response.model_dump_json(by_alias=True, exclude_none=True),
status_code=status_code,
headers=response_headers,
)
def _create_json_response(
self,
response_message: JSONRPCMessage | None,
status_code: HTTPStatus = HTTPStatus.OK,
headers: dict[str, str] | None = None,
) -> Response:
"""Create a JSON response from a JSONRPCMessage"""
response_headers = {"Content-Type": CONTENT_TYPE_JSON}
if headers:
response_headers.update(headers)
if self.mcp_session_id:
response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id
return Response(
response_message.model_dump_json(by_alias=True, exclude_none=True)
if response_message
else None,
status_code=status_code,
headers=response_headers,
)
def _get_session_id(self, request: Request) -> str | None:
"""Extract the session ID from request headers."""
return request.headers.get(MCP_SESSION_ID_HEADER)
async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Application entry point that handles all HTTP requests"""
request = Request(scope, receive)
if self._terminated:
# If the session has been terminated, return 404 Not Found
response = self._create_error_response(
"Not Found: Session has been terminated",
HTTPStatus.NOT_FOUND,
)
await response(scope, receive, send)
return
if request.method == "POST":
await self._handle_post_request(scope, request, receive, send)
elif request.method == "GET":
await self._handle_get_request(request, send)
elif request.method == "DELETE":
await self._handle_delete_request(request, send)
else:
await self._handle_unsupported_request(request, send)
def _check_accept_headers(self, request: Request) -> tuple[bool, bool]:
"""Check if the request accepts the required media types."""
accept_header = request.headers.get("accept", "")
accept_types = [media_type.strip() for media_type in accept_header.split(",")]
has_json = any(
media_type.startswith(CONTENT_TYPE_JSON) for media_type in accept_types
)
has_sse = any(
media_type.startswith(CONTENT_TYPE_SSE) for media_type in accept_types
)
return has_json, has_sse
def _check_content_type(self, request: Request) -> bool:
"""Check if the request has the correct Content-Type."""
content_type = request.headers.get("content-type", "")
content_type_parts = [
part.strip() for part in content_type.split(";")[0].split(",")
]
return any(part == CONTENT_TYPE_JSON for part in content_type_parts)
async def _handle_post_request(
self, scope: Scope, request: Request, receive: Receive, send: Send
) -> None:
"""Handle POST requests containing JSON-RPC messages."""
writer = self._read_stream_writer
if writer is None:
raise ValueError(
"No read stream writer available. Ensure connect() is called first."
)
try:
# Check Accept headers
has_json, has_sse = self._check_accept_headers(request)
if not (has_json and has_sse):
response = self._create_error_response(
(
"Not Acceptable: Client must accept both application/json and "
"text/event-stream"
),
HTTPStatus.NOT_ACCEPTABLE,
)
await response(scope, receive, send)
return
# Validate Content-Type
if not self._check_content_type(request):
response = self._create_error_response(
"Unsupported Media Type: Content-Type must be application/json",
HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
)
await response(scope, receive, send)
return
# Parse the body - only read it once
body = await request.body()
if len(body) > MAXIMUM_MESSAGE_SIZE:
response = self._create_error_response(
"Payload Too Large: Message exceeds maximum size",
HTTPStatus.REQUEST_ENTITY_TOO_LARGE,
)
await response(scope, receive, send)
return
try:
raw_message = json.loads(body)
except json.JSONDecodeError as e:
response = self._create_error_response(
f"Parse error: {str(e)}", HTTPStatus.BAD_REQUEST, PARSE_ERROR
)
await response(scope, receive, send)
return
try:
message = JSONRPCMessage.model_validate(raw_message)
except ValidationError as e:
response = self._create_error_response(
f"Validation error: {str(e)}",
HTTPStatus.BAD_REQUEST,
INVALID_PARAMS,
)
await response(scope, receive, send)
return
# Check if this is an initialization request
is_initialization_request = (
isinstance(message.root, JSONRPCRequest)
and message.root.method == "initialize"
)
if is_initialization_request:
# Check if the server already has an established session
if self.mcp_session_id:
# Check if request has a session ID
request_session_id = self._get_session_id(request)
# If request has a session ID but doesn't match, return 404
if request_session_id and request_session_id != self.mcp_session_id:
response = self._create_error_response(
"Not Found: Invalid or expired session ID",
HTTPStatus.NOT_FOUND,
)
await response(scope, receive, send)
return
# For non-initialization requests, validate the session
elif not await self._validate_session(request, send):
return
# For notifications and responses only, return 202 Accepted
if not isinstance(message.root, JSONRPCRequest):
# Create response object and send it
response = self._create_json_response(
None,
HTTPStatus.ACCEPTED,
)
await response(scope, receive, send)
# Process the message after sending the response
await writer.send(message)
return
# Extract the request ID outside the try block for proper scope
request_id = str(message.root.id)
# Create promise stream for getting response
request_stream_writer, request_stream_reader = (
anyio.create_memory_object_stream[JSONRPCMessage](0)
)
# Register this stream for the request ID
self._request_streams[request_id] = request_stream_writer
if self.is_json_response_enabled:
# Process the message
await writer.send(message)
try:
# Process messages from the request-specific stream
# We need to collect all messages until we get a response
response_message = None
# Use similar approach to SSE writer for consistency
async for received_message in request_stream_reader:
# If it's a response, this is what we're waiting for
if isinstance(
received_message.root, JSONRPCResponse | JSONRPCError
):
response_message = received_message
break
# For notifications and request, keep waiting
else:
logger.debug(f"received: {received_message.root.method}")
# At this point we should have a response
if response_message:
# Create JSON response
response = self._create_json_response(response_message)
await response(scope, receive, send)
else:
# This shouldn't happen in normal operation
logger.error(
"No response message received before stream closed"
)
response = self._create_error_response(
"Error processing request: No response received",
HTTPStatus.INTERNAL_SERVER_ERROR,
)
await response(scope, receive, send)
except Exception as e:
logger.exception(f"Error processing JSON response: {e}")
response = self._create_error_response(
f"Error processing request: {str(e)}",
HTTPStatus.INTERNAL_SERVER_ERROR,
INTERNAL_ERROR,
)
await response(scope, receive, send)
finally:
# Clean up the request stream
if request_id in self._request_streams:
self._request_streams.pop(request_id, None)
await request_stream_reader.aclose()
await request_stream_writer.aclose()
else:
# Create SSE stream
sse_stream_writer, sse_stream_reader = (
anyio.create_memory_object_stream[dict[str, Any]](0)
)
async def sse_writer():
# Get the request ID from the incoming request message
try:
async with sse_stream_writer, request_stream_reader:
# Process messages from the request-specific stream
async for received_message in request_stream_reader:
# Build the event data
event_data = {
"event": "message",
"data": received_message.model_dump_json(
by_alias=True, exclude_none=True
),
}
await sse_stream_writer.send(event_data)
# If response, remove from pending streams and close
if isinstance(
received_message.root,
JSONRPCResponse | JSONRPCError,
):
if request_id:
self._request_streams.pop(request_id, None)
break
except Exception as e:
logger.exception(f"Error in SSE writer: {e}")
finally:
logger.debug("Closing SSE writer")
# Clean up the request-specific streams
if request_id and request_id in self._request_streams:
self._request_streams.pop(request_id, None)
# Create and start EventSourceResponse
# SSE stream mode (original behavior)
# Set up headers
headers = {
"Cache-Control": "no-cache, no-transform",
"Connection": "keep-alive",
"Content-Type": CONTENT_TYPE_SSE,
**(
{MCP_SESSION_ID_HEADER: self.mcp_session_id}
if self.mcp_session_id
else {}
),
}
response = EventSourceResponse(
content=sse_stream_reader,
data_sender_callable=sse_writer,
headers=headers,
)
# Start the SSE response (this will send headers immediately)
try:
# First send the response to establish the SSE connection
async with anyio.create_task_group() as tg:
tg.start_soon(response, scope, receive, send)
# Then send the message to be processed by the server
await writer.send(message)
except Exception:
logger.exception("SSE response error")
# Clean up the request stream if something goes wrong
if request_id and request_id in self._request_streams:
self._request_streams.pop(request_id, None)
except Exception as err:
logger.exception("Error handling POST request")
response = self._create_error_response(
f"Error handling POST request: {err}",
HTTPStatus.INTERNAL_SERVER_ERROR,
INTERNAL_ERROR,
)
await response(scope, receive, send)
if writer:
await writer.send(err)
return
async def _handle_get_request(self, request: Request, send: Send) -> None:
"""Handle GET requests for SSE stream establishment."""
# Validate session ID if server has one
if not await self._validate_session(request, send):
return
# Validate Accept header - must include text/event-stream
_, has_sse = self._check_accept_headers(request)
if not has_sse:
response = self._create_error_response(
"Not Acceptable: Client must accept text/event-stream",
HTTPStatus.NOT_ACCEPTABLE,
)
await response(request.scope, request.receive, send)
return
# TODO: Implement SSE stream for GET requests
# For now, return 405 Method Not Allowed
response = self._create_error_response(
"SSE stream from GET request not implemented yet",
HTTPStatus.METHOD_NOT_ALLOWED,
)
await response(request.scope, request.receive, send)
async def _handle_delete_request(self, request: Request, send: Send) -> None:
"""Handle DELETE requests for explicit session termination."""
# Validate session ID
if not self.mcp_session_id:
# If no session ID set, return Method Not Allowed
response = self._create_error_response(
"Method Not Allowed: Session termination not supported",
HTTPStatus.METHOD_NOT_ALLOWED,
)
await response(request.scope, request.receive, send)
return
if not await self._validate_session(request, send):
return
self._terminate_session()
response = self._create_json_response(
None,
HTTPStatus.OK,
)
await response(request.scope, request.receive, send)
def _terminate_session(self) -> None:
"""Terminate the current session, closing all streams.
Once terminated, all requests with this session ID will receive 404 Not Found.
"""
self._terminated = True
logger.info(f"Terminating session: {self.mcp_session_id}")
# We need a copy of the keys to avoid modification during iteration
request_stream_keys = list(self._request_streams.keys())
# Close all request streams (synchronously)
for key in request_stream_keys:
try:
# Get the stream
stream = self._request_streams.get(key)
if stream:
# We must use close() here, not aclose() since this is a sync method
stream.close()
except Exception as e:
logger.debug(f"Error closing stream {key} during termination: {e}")
# Clear the request streams dictionary immediately
self._request_streams.clear()
async def _handle_unsupported_request(self, request: Request, send: Send) -> None:
"""Handle unsupported HTTP methods."""
headers = {
"Content-Type": CONTENT_TYPE_JSON,
"Allow": "GET, POST, DELETE",
}
if self.mcp_session_id:
headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id
response = self._create_error_response(
"Method Not Allowed",
HTTPStatus.METHOD_NOT_ALLOWED,
headers=headers,
)
await response(request.scope, request.receive, send)
async def _validate_session(self, request: Request, send: Send) -> bool:
"""Validate the session ID in the request."""
if not self.mcp_session_id:
# If we're not using session IDs, return True
return True
# Get the session ID from the request headers
request_session_id = self._get_session_id(request)
# If no session ID provided but required, return error
if not request_session_id:
response = self._create_error_response(
"Bad Request: Missing session ID",
HTTPStatus.BAD_REQUEST,
)
await response(request.scope, request.receive, send)
return False
# If session ID doesn't match, return error
if request_session_id != self.mcp_session_id:
response = self._create_error_response(
"Not Found: Invalid or expired session ID",
HTTPStatus.NOT_FOUND,
)
await response(request.scope, request.receive, send)
return False
return True
@asynccontextmanager
async def connect(
self,
) -> AsyncGenerator[
tuple[
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
MemoryObjectSendStream[JSONRPCMessage],
],
None,
]:
"""Context manager that provides read and write streams for a connection.
Yields:
Tuple of (read_stream, write_stream) for bidirectional communication
"""
# Create the memory streams for this connection
read_stream_writer, read_stream = anyio.create_memory_object_stream[
JSONRPCMessage | Exception
](0)
write_stream, write_stream_reader = anyio.create_memory_object_stream[
JSONRPCMessage
](0)
# Store the streams
self._read_stream_writer = read_stream_writer
self._write_stream_reader = write_stream_reader
# Start a task group for message routing
async with anyio.create_task_group() as tg:
# Create a message router that distributes messages to request streams
async def message_router():
try:
async for message in write_stream_reader:
# Determine which request stream(s) should receive this message
target_request_id = None
if isinstance(
message.root, JSONRPCNotification | JSONRPCRequest
):
# Extract related_request_id from meta if it exists
if (
(params := getattr(message.root, "params", None))
and (meta := params.get("_meta"))
and (related_id := meta.get("related_request_id"))
is not None
):
target_request_id = str(related_id)
else:
target_request_id = str(message.root.id)
# Send to the specific request stream if available
if (
target_request_id
and target_request_id in self._request_streams
):
try:
await self._request_streams[target_request_id].send(
message
)
except (
anyio.BrokenResourceError,
anyio.ClosedResourceError,
):
# Stream might be closed, remove from registry
self._request_streams.pop(target_request_id, None)
except Exception as e:
logger.exception(f"Error in message router: {e}")
# Start the message router
tg.start_soon(message_router)
try:
# Yield the streams for the caller to use
yield read_stream, write_stream
finally:
for stream in list(self._request_streams.values()):
try:
await stream.aclose()
except Exception:
pass
self._request_streams.clear()

View File

@@ -6,7 +6,6 @@ from types import TracebackType
from typing import Any, Generic, TypeVar
import anyio
import anyio.lowlevel
import httpx
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import BaseModel
@@ -24,6 +23,7 @@ from mcp.types import (
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
NotificationParams,
RequestParams,
ServerNotification,
ServerRequest,
@@ -274,16 +274,32 @@ class BaseSession(
await response_stream.aclose()
await response_stream_reader.aclose()
async def send_notification(self, notification: SendNotificationT) -> None:
async def send_notification(
self,
notification: SendNotificationT,
related_request_id: RequestId | None = None,
) -> None:
"""
Emits a notification, which is a one-way message that does not expect
a response.
"""
# Some transport implementations may need to set the related_request_id
# to attribute to the notifications to the request that triggered them.
if related_request_id is not None and notification.root.params is not None:
# Create meta if it doesn't exist
if notification.root.params.meta is None:
meta_dict = {"related_request_id": related_request_id}
else:
meta_dict = notification.root.params.meta.model_dump(
by_alias=True, mode="json", exclude_none=True
)
meta_dict["related_request_id"] = related_request_id
notification.root.params.meta = NotificationParams.Meta(**meta_dict)
jsonrpc_notification = JSONRPCNotification(
jsonrpc="2.0",
**notification.model_dump(by_alias=True, mode="json", exclude_none=True),
)
await self._write_stream.send(JSONRPCMessage(jsonrpc_notification))
async def _send_response(

View File

@@ -9,6 +9,7 @@ from mcp.shared.memory import (
from mcp.shared.session import RequestResponder
from mcp.types import (
LoggingMessageNotificationParams,
NotificationParams,
TextContent,
)
@@ -78,6 +79,11 @@ async def test_logging_callback():
)
assert log_result.isError is False
assert len(logging_collector.log_messages) == 1
assert logging_collector.log_messages[0] == LoggingMessageNotificationParams(
level="info", logger="test_logger", data="Test log message"
)
# Create meta object with related_request_id added dynamically
meta = NotificationParams.Meta()
setattr(meta, "related_request_id", "2")
log = logging_collector.log_messages[0]
assert log.level == "info"
assert log.logger == "test_logger"
assert log.data == "Test log message"
assert log.meta == meta

View File

@@ -35,7 +35,7 @@ async def test_messages_are_executed_concurrently():
end_time = anyio.current_time()
duration = end_time - start_time
assert duration < 3 * _sleep_time_seconds
assert duration < 6 * _sleep_time_seconds
print(duration)

View File

@@ -544,14 +544,28 @@ class TestContextInjection:
assert mock_log.call_count == 4
mock_log.assert_any_call(
level="debug", data="Debug message", logger=None
)
mock_log.assert_any_call(level="info", data="Info message", logger=None)
mock_log.assert_any_call(
level="warning", data="Warning message", logger=None
level="debug",
data="Debug message",
logger=None,
related_request_id="1",
)
mock_log.assert_any_call(
level="error", data="Error message", logger=None
level="info",
data="Info message",
logger=None,
related_request_id="1",
)
mock_log.assert_any_call(
level="warning",
data="Warning message",
logger=None,
related_request_id="1",
)
mock_log.assert_any_call(
level="error",
data="Error message",
logger=None,
related_request_id="1",
)
@pytest.mark.anyio

View File

@@ -0,0 +1,543 @@
"""
Tests for the StreamableHTTP server transport validation.
This file contains tests for request validation in the StreamableHTTP transport.
"""
import contextlib
import multiprocessing
import socket
import time
from collections.abc import Generator
from http import HTTPStatus
from uuid import uuid4
import anyio
import pytest
import requests
import uvicorn
from pydantic import AnyUrl
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Mount
from mcp.server import Server
from mcp.server.streamableHttp import (
MCP_SESSION_ID_HEADER,
SESSION_ID_PATTERN,
StreamableHTTPServerTransport,
)
from mcp.shared.exceptions import McpError
from mcp.types import (
ErrorData,
TextContent,
Tool,
)
# Test constants
SERVER_NAME = "test_streamable_http_server"
TEST_SESSION_ID = "test-session-id-12345"
INIT_REQUEST = {
"jsonrpc": "2.0",
"method": "initialize",
"params": {
"clientInfo": {"name": "test-client", "version": "1.0"},
"protocolVersion": "2025-03-26",
"capabilities": {},
},
"id": "init-1",
}
# Test server implementation that follows MCP protocol
class ServerTest(Server):
def __init__(self):
super().__init__(SERVER_NAME)
@self.read_resource()
async def handle_read_resource(uri: AnyUrl) -> str | bytes:
if uri.scheme == "foobar":
return f"Read {uri.host}"
elif uri.scheme == "slow":
# Simulate a slow resource
await anyio.sleep(2.0)
return f"Slow response from {uri.host}"
raise McpError(
error=ErrorData(
code=404, message="OOPS! no resource with that URI was found"
)
)
@self.list_tools()
async def handle_list_tools() -> list[Tool]:
return [
Tool(
name="test_tool",
description="A test tool",
inputSchema={"type": "object", "properties": {}},
)
]
@self.call_tool()
async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
return [TextContent(type="text", text=f"Called {name}")]
def create_app(is_json_response_enabled=False) -> Starlette:
"""Create a Starlette application for testing that matches the example server.
Args:
is_json_response_enabled: If True, use JSON responses instead of SSE streams.
"""
# Create server instance
server = ServerTest()
server_instances = {}
# Lock to prevent race conditions when creating new sessions
session_creation_lock = anyio.Lock()
task_group = None
@contextlib.asynccontextmanager
async def lifespan(app):
"""Application lifespan context manager for managing task group."""
nonlocal task_group
async with anyio.create_task_group() as tg:
task_group = tg
print("Application started, task group initialized!")
try:
yield
finally:
print("Application shutting down, cleaning up resources...")
if task_group:
tg.cancel_scope.cancel()
task_group = None
print("Resources cleaned up successfully.")
async def handle_streamable_http(scope, receive, send):
request = Request(scope, receive)
request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER)
# Use existing transport if session ID matches
if (
request_mcp_session_id is not None
and request_mcp_session_id in server_instances
):
transport = server_instances[request_mcp_session_id]
await transport.handle_request(scope, receive, send)
elif request_mcp_session_id is None:
async with session_creation_lock:
new_session_id = uuid4().hex
http_transport = StreamableHTTPServerTransport(
mcp_session_id=new_session_id,
is_json_response_enabled=is_json_response_enabled,
)
async with http_transport.connect() as streams:
read_stream, write_stream = streams
async def run_server():
try:
await server.run(
read_stream,
write_stream,
server.create_initialization_options(),
)
except Exception as e:
print(f"Server exception: {e}")
if task_group is None:
response = Response(
"Internal Server Error: Task group is not initialized",
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
)
await response(scope, receive, send)
return
# Store the instance before starting the task to prevent races
server_instances[http_transport.mcp_session_id] = http_transport
task_group.start_soon(run_server)
await http_transport.handle_request(scope, receive, send)
else:
response = Response(
"Bad Request: No valid session ID provided",
status_code=HTTPStatus.BAD_REQUEST,
)
await response(scope, receive, send)
# Create an ASGI application
app = Starlette(
debug=True,
routes=[
Mount("/mcp", app=handle_streamable_http),
],
lifespan=lifespan,
)
return app
def run_server(port: int, is_json_response_enabled=False) -> None:
"""Run the test server.
Args:
port: Port to listen on.
is_json_response_enabled: If True, use JSON responses instead of SSE streams.
"""
print(
f"Starting test server on port {port} with "
f"json_enabled={is_json_response_enabled}"
)
app = create_app(is_json_response_enabled)
# Configure server
config = uvicorn.Config(
app=app,
host="127.0.0.1",
port=port,
log_level="info",
limit_concurrency=10,
timeout_keep_alive=5,
access_log=False,
)
# Start the server
server = uvicorn.Server(config=config)
# This is important to catch exceptions and prevent test hangs
try:
print("Server starting...")
server.run()
except Exception as e:
print(f"ERROR: Server failed to run: {e}")
import traceback
traceback.print_exc()
print("Server shutdown")
# Test fixtures - using same approach as SSE tests
@pytest.fixture
def basic_server_port() -> int:
"""Find an available port for the basic server."""
with socket.socket() as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
@pytest.fixture
def json_server_port() -> int:
"""Find an available port for the JSON response server."""
with socket.socket() as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
@pytest.fixture
def basic_server(basic_server_port: int) -> Generator[None, None, None]:
"""Start a basic server."""
proc = multiprocessing.Process(
target=run_server, kwargs={"port": basic_server_port}, daemon=True
)
proc.start()
# Wait for server to be running
max_attempts = 20
attempt = 0
while attempt < max_attempts:
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.connect(("127.0.0.1", basic_server_port))
break
except ConnectionRefusedError:
time.sleep(0.1)
attempt += 1
else:
raise RuntimeError(f"Server failed to start after {max_attempts} attempts")
yield
# Clean up
proc.kill()
proc.join(timeout=2)
if proc.is_alive():
print("server process failed to terminate")
@pytest.fixture
def json_response_server(json_server_port: int) -> Generator[None, None, None]:
"""Start a server with JSON response enabled."""
proc = multiprocessing.Process(
target=run_server,
kwargs={"port": json_server_port, "is_json_response_enabled": True},
daemon=True,
)
proc.start()
# Wait for server to be running
max_attempts = 20
attempt = 0
while attempt < max_attempts:
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.connect(("127.0.0.1", json_server_port))
break
except ConnectionRefusedError:
time.sleep(0.1)
attempt += 1
else:
raise RuntimeError(f"Server failed to start after {max_attempts} attempts")
yield
# Clean up
proc.kill()
proc.join(timeout=2)
if proc.is_alive():
print("server process failed to terminate")
@pytest.fixture
def basic_server_url(basic_server_port: int) -> str:
"""Get the URL for the basic test server."""
return f"http://127.0.0.1:{basic_server_port}"
@pytest.fixture
def json_server_url(json_server_port: int) -> str:
"""Get the URL for the JSON response test server."""
return f"http://127.0.0.1:{json_server_port}"
# Basic request validation tests
def test_accept_header_validation(basic_server, basic_server_url):
"""Test that Accept header is properly validated."""
# Test without Accept header
response = requests.post(
f"{basic_server_url}/mcp",
headers={"Content-Type": "application/json"},
json={"jsonrpc": "2.0", "method": "initialize", "id": 1},
)
assert response.status_code == 406
assert "Not Acceptable" in response.text
def test_content_type_validation(basic_server, basic_server_url):
"""Test that Content-Type header is properly validated."""
# Test with incorrect Content-Type
response = requests.post(
f"{basic_server_url}/mcp",
headers={
"Accept": "application/json, text/event-stream",
"Content-Type": "text/plain",
},
data="This is not JSON",
)
assert response.status_code == 415
assert "Unsupported Media Type" in response.text
def test_json_validation(basic_server, basic_server_url):
"""Test that JSON content is properly validated."""
# Test with invalid JSON
response = requests.post(
f"{basic_server_url}/mcp",
headers={
"Accept": "application/json, text/event-stream",
"Content-Type": "application/json",
},
data="this is not valid json",
)
assert response.status_code == 400
assert "Parse error" in response.text
def test_json_parsing(basic_server, basic_server_url):
"""Test that JSON content is properly parse."""
# Test with valid JSON but invalid JSON-RPC
response = requests.post(
f"{basic_server_url}/mcp",
headers={
"Accept": "application/json, text/event-stream",
"Content-Type": "application/json",
},
json={"foo": "bar"},
)
assert response.status_code == 400
assert "Validation error" in response.text
def test_method_not_allowed(basic_server, basic_server_url):
"""Test that unsupported HTTP methods are rejected."""
# Test with unsupported method (PUT)
response = requests.put(
f"{basic_server_url}/mcp",
headers={
"Accept": "application/json, text/event-stream",
"Content-Type": "application/json",
},
json={"jsonrpc": "2.0", "method": "initialize", "id": 1},
)
assert response.status_code == 405
assert "Method Not Allowed" in response.text
def test_session_validation(basic_server, basic_server_url):
"""Test session ID validation."""
# session_id not used directly in this test
# Test without session ID
response = requests.post(
f"{basic_server_url}/mcp",
headers={
"Accept": "application/json, text/event-stream",
"Content-Type": "application/json",
},
json={"jsonrpc": "2.0", "method": "list_tools", "id": 1},
)
assert response.status_code == 400
assert "Missing session ID" in response.text
def test_session_id_pattern():
"""Test that SESSION_ID_PATTERN correctly validates session IDs."""
# Valid session IDs (visible ASCII characters from 0x21 to 0x7E)
valid_session_ids = [
"test-session-id",
"1234567890",
"session!@#$%^&*()_+-=[]{}|;:,.<>?/",
"~`",
]
for session_id in valid_session_ids:
assert SESSION_ID_PATTERN.match(session_id) is not None
# Ensure fullmatch matches too (whole string)
assert SESSION_ID_PATTERN.fullmatch(session_id) is not None
# Invalid session IDs
invalid_session_ids = [
"", # Empty string
" test", # Space (0x20)
"test\t", # Tab
"test\n", # Newline
"test\r", # Carriage return
"test" + chr(0x7F), # DEL character
"test" + chr(0x80), # Extended ASCII
"test" + chr(0x00), # Null character
"test" + chr(0x20), # Space (0x20)
]
for session_id in invalid_session_ids:
# For invalid IDs, either match will fail or fullmatch will fail
if SESSION_ID_PATTERN.match(session_id) is not None:
# If match succeeds, fullmatch should fail (partial match case)
assert SESSION_ID_PATTERN.fullmatch(session_id) is None
def test_streamable_http_transport_init_validation():
"""Test that StreamableHTTPServerTransport validates session ID on init."""
# Valid session ID should initialize without errors
valid_transport = StreamableHTTPServerTransport(mcp_session_id="valid-id")
assert valid_transport.mcp_session_id == "valid-id"
# None should be accepted
none_transport = StreamableHTTPServerTransport(mcp_session_id=None)
assert none_transport.mcp_session_id is None
# Invalid session ID should raise ValueError
with pytest.raises(ValueError) as excinfo:
StreamableHTTPServerTransport(mcp_session_id="invalid id with space")
assert "Session ID must only contain visible ASCII characters" in str(excinfo.value)
# Test with control characters
with pytest.raises(ValueError):
StreamableHTTPServerTransport(mcp_session_id="test\nid")
with pytest.raises(ValueError):
StreamableHTTPServerTransport(mcp_session_id="test\n")
def test_session_termination(basic_server, basic_server_url):
"""Test session termination via DELETE and subsequent request handling."""
response = requests.post(
f"{basic_server_url}/mcp",
headers={
"Accept": "application/json, text/event-stream",
"Content-Type": "application/json",
},
json=INIT_REQUEST,
)
assert response.status_code == 200
# Now terminate the session
session_id = response.headers.get(MCP_SESSION_ID_HEADER)
response = requests.delete(
f"{basic_server_url}/mcp",
headers={MCP_SESSION_ID_HEADER: session_id},
)
assert response.status_code == 200
# Try to use the terminated session
response = requests.post(
f"{basic_server_url}/mcp",
headers={
"Accept": "application/json, text/event-stream",
"Content-Type": "application/json",
MCP_SESSION_ID_HEADER: session_id,
},
json={"jsonrpc": "2.0", "method": "ping", "id": 2},
)
assert response.status_code == 404
assert "Session has been terminated" in response.text
def test_response(basic_server, basic_server_url):
"""Test response handling for a valid request."""
mcp_url = f"{basic_server_url}/mcp"
response = requests.post(
mcp_url,
headers={
"Accept": "application/json, text/event-stream",
"Content-Type": "application/json",
},
json=INIT_REQUEST,
)
assert response.status_code == 200
# Now terminate the session
session_id = response.headers.get(MCP_SESSION_ID_HEADER)
# Try to use the terminated session
tools_response = requests.post(
mcp_url,
headers={
"Accept": "application/json, text/event-stream",
"Content-Type": "application/json",
MCP_SESSION_ID_HEADER: session_id, # Use the session ID we got earlier
},
json={"jsonrpc": "2.0", "method": "tools/list", "id": "tools-1"},
stream=True,
)
assert tools_response.status_code == 200
assert tools_response.headers.get("Content-Type") == "text/event-stream"
def test_json_response(json_response_server, json_server_url):
"""Test response handling when is_json_response_enabled is True."""
mcp_url = f"{json_server_url}/mcp"
response = requests.post(
mcp_url,
headers={
"Accept": "application/json, text/event-stream",
"Content-Type": "application/json",
},
json=INIT_REQUEST,
)
assert response.status_code == 200
assert response.headers.get("Content-Type") == "application/json"

38
uv.lock generated
View File

@@ -10,6 +10,7 @@ members = [
"mcp",
"mcp-simple-prompt",
"mcp-simple-resource",
"mcp-simple-streamablehttp",
"mcp-simple-tool",
]
@@ -632,6 +633,43 @@ dev = [
{ name = "ruff", specifier = ">=0.6.9" },
]
[[package]]
name = "mcp-simple-streamablehttp"
version = "0.1.0"
source = { editable = "examples/servers/simple-streamablehttp" }
dependencies = [
{ name = "anyio" },
{ name = "click" },
{ name = "httpx" },
{ name = "mcp" },
{ name = "starlette" },
{ name = "uvicorn" },
]
[package.dev-dependencies]
dev = [
{ name = "pyright" },
{ name = "pytest" },
{ name = "ruff" },
]
[package.metadata]
requires-dist = [
{ name = "anyio", specifier = ">=4.5" },
{ name = "click", specifier = ">=8.1.0" },
{ name = "httpx", specifier = ">=0.27" },
{ name = "mcp", editable = "." },
{ name = "starlette" },
{ name = "uvicorn" },
]
[package.metadata.requires-dev]
dev = [
{ name = "pyright", specifier = ">=1.1.378" },
{ name = "pytest", specifier = ">=8.3.3" },
{ name = "ruff", specifier = ">=0.6.9" },
]
[[package]]
name = "mcp-simple-tool"
version = "0.1.0"