mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-18 14:34:27 +01:00
StreamableHttp - Server transport with state management (#553)
This commit is contained in:
37
examples/servers/simple-streamablehttp/README.md
Normal file
37
examples/servers/simple-streamablehttp/README.md
Normal file
@@ -0,0 +1,37 @@
|
||||
# MCP Simple StreamableHttp Server Example
|
||||
|
||||
A simple MCP server example demonstrating the StreamableHttp transport, which enables HTTP-based communication with MCP servers using streaming.
|
||||
|
||||
## Features
|
||||
|
||||
- Uses the StreamableHTTP transport for server-client communication
|
||||
- Supports REST API operations (POST, GET, DELETE) for `/mcp` endpoint
|
||||
- Task management with anyio task groups
|
||||
- Ability to send multiple notifications over time to the client
|
||||
- Proper resource cleanup and lifespan management
|
||||
|
||||
## Usage
|
||||
|
||||
Start the server on the default or custom port:
|
||||
|
||||
```bash
|
||||
|
||||
# Using custom port
|
||||
uv run mcp-simple-streamablehttp --port 3000
|
||||
|
||||
# Custom logging level
|
||||
uv run mcp-simple-streamablehttp --log-level DEBUG
|
||||
|
||||
# Enable JSON responses instead of SSE streams
|
||||
uv run mcp-simple-streamablehttp --json-response
|
||||
```
|
||||
|
||||
The server exposes a tool named "start-notification-stream" that accepts three arguments:
|
||||
|
||||
- `interval`: Time between notifications in seconds (e.g., 1.0)
|
||||
- `count`: Number of notifications to send (e.g., 5)
|
||||
- `caller`: Identifier string for the caller
|
||||
|
||||
## Client
|
||||
|
||||
You can connect to this server using an HTTP client, for now only Typescript SDK has streamable HTTP client examples or you can use (Inspector)[https://github.com/modelcontextprotocol/inspector]
|
||||
@@ -0,0 +1,4 @@
|
||||
from .server import main
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,201 @@
|
||||
import contextlib
|
||||
import logging
|
||||
from http import HTTPStatus
|
||||
from uuid import uuid4
|
||||
|
||||
import anyio
|
||||
import click
|
||||
import mcp.types as types
|
||||
from mcp.server.lowlevel import Server
|
||||
from mcp.server.streamableHttp import (
|
||||
MCP_SESSION_ID_HEADER,
|
||||
StreamableHTTPServerTransport,
|
||||
)
|
||||
from starlette.applications import Starlette
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.routing import Mount
|
||||
|
||||
# Configure logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global task group that will be initialized in the lifespan
|
||||
task_group = None
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def lifespan(app):
|
||||
"""Application lifespan context manager for managing task group."""
|
||||
global task_group
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
task_group = tg
|
||||
logger.info("Application started, task group initialized!")
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
logger.info("Application shutting down, cleaning up resources...")
|
||||
if task_group:
|
||||
tg.cancel_scope.cancel()
|
||||
task_group = None
|
||||
logger.info("Resources cleaned up successfully.")
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option("--port", default=3000, help="Port to listen on for HTTP")
|
||||
@click.option(
|
||||
"--log-level",
|
||||
default="INFO",
|
||||
help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)",
|
||||
)
|
||||
@click.option(
|
||||
"--json-response",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Enable JSON responses instead of SSE streams",
|
||||
)
|
||||
def main(
|
||||
port: int,
|
||||
log_level: str,
|
||||
json_response: bool,
|
||||
) -> int:
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, log_level.upper()),
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
|
||||
app = Server("mcp-streamable-http-demo")
|
||||
|
||||
@app.call_tool()
|
||||
async def call_tool(
|
||||
name: str, arguments: dict
|
||||
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
|
||||
ctx = app.request_context
|
||||
interval = arguments.get("interval", 1.0)
|
||||
count = arguments.get("count", 5)
|
||||
caller = arguments.get("caller", "unknown")
|
||||
|
||||
# Send the specified number of notifications with the given interval
|
||||
for i in range(count):
|
||||
await ctx.session.send_log_message(
|
||||
level="info",
|
||||
data=f"Notification {i+1}/{count} from caller: {caller}",
|
||||
logger="notification_stream",
|
||||
# Associates this notification with the original request
|
||||
# Ensures notifications are sent to the correct response stream
|
||||
# Without this, notifications will either go to:
|
||||
# - a standalone SSE stream (if GET request is supported)
|
||||
# - nowhere (if GET request isn't supported)
|
||||
related_request_id=ctx.request_id,
|
||||
)
|
||||
if i < count - 1: # Don't wait after the last notification
|
||||
await anyio.sleep(interval)
|
||||
|
||||
return [
|
||||
types.TextContent(
|
||||
type="text",
|
||||
text=(
|
||||
f"Sent {count} notifications with {interval}s interval"
|
||||
f" for caller: {caller}"
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
@app.list_tools()
|
||||
async def list_tools() -> list[types.Tool]:
|
||||
return [
|
||||
types.Tool(
|
||||
name="start-notification-stream",
|
||||
description=(
|
||||
"Sends a stream of notifications with configurable count"
|
||||
" and interval"
|
||||
),
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"required": ["interval", "count", "caller"],
|
||||
"properties": {
|
||||
"interval": {
|
||||
"type": "number",
|
||||
"description": "Interval between notifications in seconds",
|
||||
},
|
||||
"count": {
|
||||
"type": "number",
|
||||
"description": "Number of notifications to send",
|
||||
},
|
||||
"caller": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Identifier of the caller to include in notifications"
|
||||
),
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
# We need to store the server instances between requests
|
||||
server_instances = {}
|
||||
# Lock to prevent race conditions when creating new sessions
|
||||
session_creation_lock = anyio.Lock()
|
||||
|
||||
# ASGI handler for streamable HTTP connections
|
||||
async def handle_streamable_http(scope, receive, send):
|
||||
request = Request(scope, receive)
|
||||
request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER)
|
||||
if (
|
||||
request_mcp_session_id is not None
|
||||
and request_mcp_session_id in server_instances
|
||||
):
|
||||
transport = server_instances[request_mcp_session_id]
|
||||
logger.debug("Session already exists, handling request directly")
|
||||
await transport.handle_request(scope, receive, send)
|
||||
elif request_mcp_session_id is None:
|
||||
# try to establish new session
|
||||
logger.debug("Creating new transport")
|
||||
# Use lock to prevent race conditions when creating new sessions
|
||||
async with session_creation_lock:
|
||||
new_session_id = uuid4().hex
|
||||
http_transport = StreamableHTTPServerTransport(
|
||||
mcp_session_id=new_session_id,
|
||||
is_json_response_enabled=json_response,
|
||||
)
|
||||
server_instances[http_transport.mcp_session_id] = http_transport
|
||||
async with http_transport.connect() as streams:
|
||||
read_stream, write_stream = streams
|
||||
|
||||
async def run_server():
|
||||
await app.run(
|
||||
read_stream,
|
||||
write_stream,
|
||||
app.create_initialization_options(),
|
||||
)
|
||||
|
||||
if not task_group:
|
||||
raise RuntimeError("Task group is not initialized")
|
||||
|
||||
task_group.start_soon(run_server)
|
||||
|
||||
# Handle the HTTP request and return the response
|
||||
await http_transport.handle_request(scope, receive, send)
|
||||
else:
|
||||
response = Response(
|
||||
"Bad Request: No valid session ID provided",
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
|
||||
# Create an ASGI application using the transport
|
||||
starlette_app = Starlette(
|
||||
debug=True,
|
||||
routes=[
|
||||
Mount("/mcp", app=handle_streamable_http),
|
||||
],
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(starlette_app, host="0.0.0.0", port=port)
|
||||
|
||||
return 0
|
||||
36
examples/servers/simple-streamablehttp/pyproject.toml
Normal file
36
examples/servers/simple-streamablehttp/pyproject.toml
Normal file
@@ -0,0 +1,36 @@
|
||||
[project]
|
||||
name = "mcp-simple-streamablehttp"
|
||||
version = "0.1.0"
|
||||
description = "A simple MCP server exposing a StreamableHttp transport for testing"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
authors = [{ name = "Anthropic, PBC." }]
|
||||
keywords = ["mcp", "llm", "automation", "web", "fetch", "http", "streamable"]
|
||||
license = { text = "MIT" }
|
||||
dependencies = ["anyio>=4.5", "click>=8.1.0", "httpx>=0.27", "mcp", "starlette", "uvicorn"]
|
||||
|
||||
[project.scripts]
|
||||
mcp-simple-streamablehttp = "mcp_simple_streamablehttp.server:main"
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["mcp_simple_streamablehttp"]
|
||||
|
||||
[tool.pyright]
|
||||
include = ["mcp_simple_streamablehttp"]
|
||||
venvPath = "."
|
||||
venv = ".venv"
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "I"]
|
||||
ignore = []
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 88
|
||||
target-version = "py310"
|
||||
|
||||
[tool.uv]
|
||||
dev-dependencies = ["pyright>=1.1.378", "pytest>=8.3.3", "ruff>=0.6.9"]
|
||||
@@ -814,7 +814,10 @@ class Context(BaseModel, Generic[ServerSessionT, LifespanContextT]):
|
||||
**extra: Additional structured data to include
|
||||
"""
|
||||
await self.request_context.session.send_log_message(
|
||||
level=level, data=message, logger=logger_name
|
||||
level=level,
|
||||
data=message,
|
||||
logger=logger_name,
|
||||
related_request_id=self.request_id,
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
@@ -179,7 +179,11 @@ class ServerSession(
|
||||
)
|
||||
|
||||
async def send_log_message(
|
||||
self, level: types.LoggingLevel, data: Any, logger: str | None = None
|
||||
self,
|
||||
level: types.LoggingLevel,
|
||||
data: Any,
|
||||
logger: str | None = None,
|
||||
related_request_id: types.RequestId | None = None,
|
||||
) -> None:
|
||||
"""Send a log message notification."""
|
||||
await self.send_notification(
|
||||
@@ -192,7 +196,8 @@ class ServerSession(
|
||||
logger=logger,
|
||||
),
|
||||
)
|
||||
)
|
||||
),
|
||||
related_request_id,
|
||||
)
|
||||
|
||||
async def send_resource_updated(self, uri: AnyUrl) -> None:
|
||||
@@ -261,7 +266,11 @@ class ServerSession(
|
||||
)
|
||||
|
||||
async def send_progress_notification(
|
||||
self, progress_token: str | int, progress: float, total: float | None = None
|
||||
self,
|
||||
progress_token: str | int,
|
||||
progress: float,
|
||||
total: float | None = None,
|
||||
related_request_id: str | None = None,
|
||||
) -> None:
|
||||
"""Send a progress notification."""
|
||||
await self.send_notification(
|
||||
@@ -274,7 +283,8 @@ class ServerSession(
|
||||
total=total,
|
||||
),
|
||||
)
|
||||
)
|
||||
),
|
||||
related_request_id,
|
||||
)
|
||||
|
||||
async def send_resource_list_changed(self) -> None:
|
||||
|
||||
644
src/mcp/server/streamableHttp.py
Normal file
644
src/mcp/server/streamableHttp.py
Normal file
@@ -0,0 +1,644 @@
|
||||
"""
|
||||
StreamableHTTP Server Transport Module
|
||||
|
||||
This module implements an HTTP transport layer with Streamable HTTP.
|
||||
|
||||
The transport handles bidirectional communication using HTTP requests and
|
||||
responses, with streaming support for long-running operations.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from http import HTTPStatus
|
||||
from typing import Any
|
||||
|
||||
import anyio
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import ValidationError
|
||||
from sse_starlette import EventSourceResponse
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
from mcp.types import (
|
||||
INTERNAL_ERROR,
|
||||
INVALID_PARAMS,
|
||||
INVALID_REQUEST,
|
||||
PARSE_ERROR,
|
||||
ErrorData,
|
||||
JSONRPCError,
|
||||
JSONRPCMessage,
|
||||
JSONRPCNotification,
|
||||
JSONRPCRequest,
|
||||
JSONRPCResponse,
|
||||
RequestId,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maximum size for incoming messages
|
||||
MAXIMUM_MESSAGE_SIZE = 4 * 1024 * 1024 # 4MB
|
||||
|
||||
# Header names
|
||||
MCP_SESSION_ID_HEADER = "mcp-session-id"
|
||||
LAST_EVENT_ID_HEADER = "last-event-id"
|
||||
|
||||
# Content types
|
||||
CONTENT_TYPE_JSON = "application/json"
|
||||
CONTENT_TYPE_SSE = "text/event-stream"
|
||||
|
||||
# Session ID validation pattern (visible ASCII characters ranging from 0x21 to 0x7E)
|
||||
# Pattern ensures entire string contains only valid characters by using ^ and $ anchors
|
||||
SESSION_ID_PATTERN = re.compile(r"^[\x21-\x7E]+$")
|
||||
|
||||
|
||||
class StreamableHTTPServerTransport:
|
||||
"""
|
||||
HTTP server transport with event streaming support for MCP.
|
||||
|
||||
Handles JSON-RPC messages in HTTP POST requests with SSE streaming.
|
||||
Supports optional JSON responses and session management.
|
||||
"""
|
||||
|
||||
# Server notification streams for POST requests as well as standalone SSE stream
|
||||
_read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] | None = (
|
||||
None
|
||||
)
|
||||
_write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mcp_session_id: str | None,
|
||||
is_json_response_enabled: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize a new StreamableHTTP server transport.
|
||||
|
||||
Args:
|
||||
mcp_session_id: Optional session identifier for this connection.
|
||||
Must contain only visible ASCII characters (0x21-0x7E).
|
||||
is_json_response_enabled: If True, return JSON responses for requests
|
||||
instead of SSE streams. Default is False.
|
||||
|
||||
Raises:
|
||||
ValueError: If the session ID contains invalid characters.
|
||||
"""
|
||||
if mcp_session_id is not None and not SESSION_ID_PATTERN.fullmatch(
|
||||
mcp_session_id
|
||||
):
|
||||
raise ValueError(
|
||||
"Session ID must only contain visible ASCII characters (0x21-0x7E)"
|
||||
)
|
||||
|
||||
self.mcp_session_id = mcp_session_id
|
||||
self.is_json_response_enabled = is_json_response_enabled
|
||||
self._request_streams: dict[
|
||||
RequestId, MemoryObjectSendStream[JSONRPCMessage]
|
||||
] = {}
|
||||
self._terminated = False
|
||||
|
||||
def _create_error_response(
|
||||
self,
|
||||
error_message: str,
|
||||
status_code: HTTPStatus,
|
||||
error_code: int = INVALID_REQUEST,
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> Response:
|
||||
"""Create an error response with a simple string message."""
|
||||
response_headers = {"Content-Type": CONTENT_TYPE_JSON}
|
||||
if headers:
|
||||
response_headers.update(headers)
|
||||
|
||||
if self.mcp_session_id:
|
||||
response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id
|
||||
|
||||
# Return a properly formatted JSON error response
|
||||
error_response = JSONRPCError(
|
||||
jsonrpc="2.0",
|
||||
id="server-error", # We don't have a request ID for general errors
|
||||
error=ErrorData(
|
||||
code=error_code,
|
||||
message=error_message,
|
||||
),
|
||||
)
|
||||
|
||||
return Response(
|
||||
error_response.model_dump_json(by_alias=True, exclude_none=True),
|
||||
status_code=status_code,
|
||||
headers=response_headers,
|
||||
)
|
||||
|
||||
def _create_json_response(
|
||||
self,
|
||||
response_message: JSONRPCMessage | None,
|
||||
status_code: HTTPStatus = HTTPStatus.OK,
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> Response:
|
||||
"""Create a JSON response from a JSONRPCMessage"""
|
||||
response_headers = {"Content-Type": CONTENT_TYPE_JSON}
|
||||
if headers:
|
||||
response_headers.update(headers)
|
||||
|
||||
if self.mcp_session_id:
|
||||
response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id
|
||||
|
||||
return Response(
|
||||
response_message.model_dump_json(by_alias=True, exclude_none=True)
|
||||
if response_message
|
||||
else None,
|
||||
status_code=status_code,
|
||||
headers=response_headers,
|
||||
)
|
||||
|
||||
def _get_session_id(self, request: Request) -> str | None:
|
||||
"""Extract the session ID from request headers."""
|
||||
return request.headers.get(MCP_SESSION_ID_HEADER)
|
||||
|
||||
async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
"""Application entry point that handles all HTTP requests"""
|
||||
request = Request(scope, receive)
|
||||
if self._terminated:
|
||||
# If the session has been terminated, return 404 Not Found
|
||||
response = self._create_error_response(
|
||||
"Not Found: Session has been terminated",
|
||||
HTTPStatus.NOT_FOUND,
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
if request.method == "POST":
|
||||
await self._handle_post_request(scope, request, receive, send)
|
||||
elif request.method == "GET":
|
||||
await self._handle_get_request(request, send)
|
||||
elif request.method == "DELETE":
|
||||
await self._handle_delete_request(request, send)
|
||||
else:
|
||||
await self._handle_unsupported_request(request, send)
|
||||
|
||||
def _check_accept_headers(self, request: Request) -> tuple[bool, bool]:
|
||||
"""Check if the request accepts the required media types."""
|
||||
accept_header = request.headers.get("accept", "")
|
||||
accept_types = [media_type.strip() for media_type in accept_header.split(",")]
|
||||
|
||||
has_json = any(
|
||||
media_type.startswith(CONTENT_TYPE_JSON) for media_type in accept_types
|
||||
)
|
||||
has_sse = any(
|
||||
media_type.startswith(CONTENT_TYPE_SSE) for media_type in accept_types
|
||||
)
|
||||
|
||||
return has_json, has_sse
|
||||
|
||||
def _check_content_type(self, request: Request) -> bool:
|
||||
"""Check if the request has the correct Content-Type."""
|
||||
content_type = request.headers.get("content-type", "")
|
||||
content_type_parts = [
|
||||
part.strip() for part in content_type.split(";")[0].split(",")
|
||||
]
|
||||
|
||||
return any(part == CONTENT_TYPE_JSON for part in content_type_parts)
|
||||
|
||||
async def _handle_post_request(
|
||||
self, scope: Scope, request: Request, receive: Receive, send: Send
|
||||
) -> None:
|
||||
"""Handle POST requests containing JSON-RPC messages."""
|
||||
writer = self._read_stream_writer
|
||||
if writer is None:
|
||||
raise ValueError(
|
||||
"No read stream writer available. Ensure connect() is called first."
|
||||
)
|
||||
try:
|
||||
# Check Accept headers
|
||||
has_json, has_sse = self._check_accept_headers(request)
|
||||
if not (has_json and has_sse):
|
||||
response = self._create_error_response(
|
||||
(
|
||||
"Not Acceptable: Client must accept both application/json and "
|
||||
"text/event-stream"
|
||||
),
|
||||
HTTPStatus.NOT_ACCEPTABLE,
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
# Validate Content-Type
|
||||
if not self._check_content_type(request):
|
||||
response = self._create_error_response(
|
||||
"Unsupported Media Type: Content-Type must be application/json",
|
||||
HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
# Parse the body - only read it once
|
||||
body = await request.body()
|
||||
if len(body) > MAXIMUM_MESSAGE_SIZE:
|
||||
response = self._create_error_response(
|
||||
"Payload Too Large: Message exceeds maximum size",
|
||||
HTTPStatus.REQUEST_ENTITY_TOO_LARGE,
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
try:
|
||||
raw_message = json.loads(body)
|
||||
except json.JSONDecodeError as e:
|
||||
response = self._create_error_response(
|
||||
f"Parse error: {str(e)}", HTTPStatus.BAD_REQUEST, PARSE_ERROR
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
try:
|
||||
message = JSONRPCMessage.model_validate(raw_message)
|
||||
except ValidationError as e:
|
||||
response = self._create_error_response(
|
||||
f"Validation error: {str(e)}",
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
INVALID_PARAMS,
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
# Check if this is an initialization request
|
||||
is_initialization_request = (
|
||||
isinstance(message.root, JSONRPCRequest)
|
||||
and message.root.method == "initialize"
|
||||
)
|
||||
|
||||
if is_initialization_request:
|
||||
# Check if the server already has an established session
|
||||
if self.mcp_session_id:
|
||||
# Check if request has a session ID
|
||||
request_session_id = self._get_session_id(request)
|
||||
|
||||
# If request has a session ID but doesn't match, return 404
|
||||
if request_session_id and request_session_id != self.mcp_session_id:
|
||||
response = self._create_error_response(
|
||||
"Not Found: Invalid or expired session ID",
|
||||
HTTPStatus.NOT_FOUND,
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
# For non-initialization requests, validate the session
|
||||
elif not await self._validate_session(request, send):
|
||||
return
|
||||
|
||||
# For notifications and responses only, return 202 Accepted
|
||||
if not isinstance(message.root, JSONRPCRequest):
|
||||
# Create response object and send it
|
||||
response = self._create_json_response(
|
||||
None,
|
||||
HTTPStatus.ACCEPTED,
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
|
||||
# Process the message after sending the response
|
||||
await writer.send(message)
|
||||
|
||||
return
|
||||
|
||||
# Extract the request ID outside the try block for proper scope
|
||||
request_id = str(message.root.id)
|
||||
# Create promise stream for getting response
|
||||
request_stream_writer, request_stream_reader = (
|
||||
anyio.create_memory_object_stream[JSONRPCMessage](0)
|
||||
)
|
||||
|
||||
# Register this stream for the request ID
|
||||
self._request_streams[request_id] = request_stream_writer
|
||||
|
||||
if self.is_json_response_enabled:
|
||||
# Process the message
|
||||
await writer.send(message)
|
||||
try:
|
||||
# Process messages from the request-specific stream
|
||||
# We need to collect all messages until we get a response
|
||||
response_message = None
|
||||
|
||||
# Use similar approach to SSE writer for consistency
|
||||
async for received_message in request_stream_reader:
|
||||
# If it's a response, this is what we're waiting for
|
||||
if isinstance(
|
||||
received_message.root, JSONRPCResponse | JSONRPCError
|
||||
):
|
||||
response_message = received_message
|
||||
break
|
||||
# For notifications and request, keep waiting
|
||||
else:
|
||||
logger.debug(f"received: {received_message.root.method}")
|
||||
|
||||
# At this point we should have a response
|
||||
if response_message:
|
||||
# Create JSON response
|
||||
response = self._create_json_response(response_message)
|
||||
await response(scope, receive, send)
|
||||
else:
|
||||
# This shouldn't happen in normal operation
|
||||
logger.error(
|
||||
"No response message received before stream closed"
|
||||
)
|
||||
response = self._create_error_response(
|
||||
"Error processing request: No response received",
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error processing JSON response: {e}")
|
||||
response = self._create_error_response(
|
||||
f"Error processing request: {str(e)}",
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||
INTERNAL_ERROR,
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
finally:
|
||||
# Clean up the request stream
|
||||
if request_id in self._request_streams:
|
||||
self._request_streams.pop(request_id, None)
|
||||
await request_stream_reader.aclose()
|
||||
await request_stream_writer.aclose()
|
||||
else:
|
||||
# Create SSE stream
|
||||
sse_stream_writer, sse_stream_reader = (
|
||||
anyio.create_memory_object_stream[dict[str, Any]](0)
|
||||
)
|
||||
|
||||
async def sse_writer():
|
||||
# Get the request ID from the incoming request message
|
||||
try:
|
||||
async with sse_stream_writer, request_stream_reader:
|
||||
# Process messages from the request-specific stream
|
||||
async for received_message in request_stream_reader:
|
||||
# Build the event data
|
||||
event_data = {
|
||||
"event": "message",
|
||||
"data": received_message.model_dump_json(
|
||||
by_alias=True, exclude_none=True
|
||||
),
|
||||
}
|
||||
|
||||
await sse_stream_writer.send(event_data)
|
||||
|
||||
# If response, remove from pending streams and close
|
||||
if isinstance(
|
||||
received_message.root,
|
||||
JSONRPCResponse | JSONRPCError,
|
||||
):
|
||||
if request_id:
|
||||
self._request_streams.pop(request_id, None)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.exception(f"Error in SSE writer: {e}")
|
||||
finally:
|
||||
logger.debug("Closing SSE writer")
|
||||
# Clean up the request-specific streams
|
||||
if request_id and request_id in self._request_streams:
|
||||
self._request_streams.pop(request_id, None)
|
||||
|
||||
# Create and start EventSourceResponse
|
||||
# SSE stream mode (original behavior)
|
||||
# Set up headers
|
||||
headers = {
|
||||
"Cache-Control": "no-cache, no-transform",
|
||||
"Connection": "keep-alive",
|
||||
"Content-Type": CONTENT_TYPE_SSE,
|
||||
**(
|
||||
{MCP_SESSION_ID_HEADER: self.mcp_session_id}
|
||||
if self.mcp_session_id
|
||||
else {}
|
||||
),
|
||||
}
|
||||
response = EventSourceResponse(
|
||||
content=sse_stream_reader,
|
||||
data_sender_callable=sse_writer,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
# Start the SSE response (this will send headers immediately)
|
||||
try:
|
||||
# First send the response to establish the SSE connection
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(response, scope, receive, send)
|
||||
# Then send the message to be processed by the server
|
||||
await writer.send(message)
|
||||
except Exception:
|
||||
logger.exception("SSE response error")
|
||||
# Clean up the request stream if something goes wrong
|
||||
if request_id and request_id in self._request_streams:
|
||||
self._request_streams.pop(request_id, None)
|
||||
|
||||
except Exception as err:
|
||||
logger.exception("Error handling POST request")
|
||||
response = self._create_error_response(
|
||||
f"Error handling POST request: {err}",
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||
INTERNAL_ERROR,
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
if writer:
|
||||
await writer.send(err)
|
||||
return
|
||||
|
||||
async def _handle_get_request(self, request: Request, send: Send) -> None:
|
||||
"""Handle GET requests for SSE stream establishment."""
|
||||
# Validate session ID if server has one
|
||||
if not await self._validate_session(request, send):
|
||||
return
|
||||
# Validate Accept header - must include text/event-stream
|
||||
_, has_sse = self._check_accept_headers(request)
|
||||
|
||||
if not has_sse:
|
||||
response = self._create_error_response(
|
||||
"Not Acceptable: Client must accept text/event-stream",
|
||||
HTTPStatus.NOT_ACCEPTABLE,
|
||||
)
|
||||
await response(request.scope, request.receive, send)
|
||||
return
|
||||
|
||||
# TODO: Implement SSE stream for GET requests
|
||||
# For now, return 405 Method Not Allowed
|
||||
response = self._create_error_response(
|
||||
"SSE stream from GET request not implemented yet",
|
||||
HTTPStatus.METHOD_NOT_ALLOWED,
|
||||
)
|
||||
await response(request.scope, request.receive, send)
|
||||
|
||||
async def _handle_delete_request(self, request: Request, send: Send) -> None:
|
||||
"""Handle DELETE requests for explicit session termination."""
|
||||
# Validate session ID
|
||||
if not self.mcp_session_id:
|
||||
# If no session ID set, return Method Not Allowed
|
||||
response = self._create_error_response(
|
||||
"Method Not Allowed: Session termination not supported",
|
||||
HTTPStatus.METHOD_NOT_ALLOWED,
|
||||
)
|
||||
await response(request.scope, request.receive, send)
|
||||
return
|
||||
|
||||
if not await self._validate_session(request, send):
|
||||
return
|
||||
|
||||
self._terminate_session()
|
||||
|
||||
response = self._create_json_response(
|
||||
None,
|
||||
HTTPStatus.OK,
|
||||
)
|
||||
await response(request.scope, request.receive, send)
|
||||
|
||||
def _terminate_session(self) -> None:
|
||||
"""Terminate the current session, closing all streams.
|
||||
|
||||
Once terminated, all requests with this session ID will receive 404 Not Found.
|
||||
"""
|
||||
|
||||
self._terminated = True
|
||||
logger.info(f"Terminating session: {self.mcp_session_id}")
|
||||
|
||||
# We need a copy of the keys to avoid modification during iteration
|
||||
request_stream_keys = list(self._request_streams.keys())
|
||||
|
||||
# Close all request streams (synchronously)
|
||||
for key in request_stream_keys:
|
||||
try:
|
||||
# Get the stream
|
||||
stream = self._request_streams.get(key)
|
||||
if stream:
|
||||
# We must use close() here, not aclose() since this is a sync method
|
||||
stream.close()
|
||||
except Exception as e:
|
||||
logger.debug(f"Error closing stream {key} during termination: {e}")
|
||||
|
||||
# Clear the request streams dictionary immediately
|
||||
self._request_streams.clear()
|
||||
|
||||
async def _handle_unsupported_request(self, request: Request, send: Send) -> None:
|
||||
"""Handle unsupported HTTP methods."""
|
||||
headers = {
|
||||
"Content-Type": CONTENT_TYPE_JSON,
|
||||
"Allow": "GET, POST, DELETE",
|
||||
}
|
||||
if self.mcp_session_id:
|
||||
headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id
|
||||
|
||||
response = self._create_error_response(
|
||||
"Method Not Allowed",
|
||||
HTTPStatus.METHOD_NOT_ALLOWED,
|
||||
headers=headers,
|
||||
)
|
||||
await response(request.scope, request.receive, send)
|
||||
|
||||
async def _validate_session(self, request: Request, send: Send) -> bool:
|
||||
"""Validate the session ID in the request."""
|
||||
if not self.mcp_session_id:
|
||||
# If we're not using session IDs, return True
|
||||
return True
|
||||
|
||||
# Get the session ID from the request headers
|
||||
request_session_id = self._get_session_id(request)
|
||||
|
||||
# If no session ID provided but required, return error
|
||||
if not request_session_id:
|
||||
response = self._create_error_response(
|
||||
"Bad Request: Missing session ID",
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
await response(request.scope, request.receive, send)
|
||||
return False
|
||||
|
||||
# If session ID doesn't match, return error
|
||||
if request_session_id != self.mcp_session_id:
|
||||
response = self._create_error_response(
|
||||
"Not Found: Invalid or expired session ID",
|
||||
HTTPStatus.NOT_FOUND,
|
||||
)
|
||||
await response(request.scope, request.receive, send)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@asynccontextmanager
|
||||
async def connect(
|
||||
self,
|
||||
) -> AsyncGenerator[
|
||||
tuple[
|
||||
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||
MemoryObjectSendStream[JSONRPCMessage],
|
||||
],
|
||||
None,
|
||||
]:
|
||||
"""Context manager that provides read and write streams for a connection.
|
||||
|
||||
Yields:
|
||||
Tuple of (read_stream, write_stream) for bidirectional communication
|
||||
"""
|
||||
|
||||
# Create the memory streams for this connection
|
||||
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream[
|
||||
JSONRPCMessage | Exception
|
||||
](0)
|
||||
write_stream, write_stream_reader = anyio.create_memory_object_stream[
|
||||
JSONRPCMessage
|
||||
](0)
|
||||
|
||||
# Store the streams
|
||||
self._read_stream_writer = read_stream_writer
|
||||
self._write_stream_reader = write_stream_reader
|
||||
|
||||
# Start a task group for message routing
|
||||
async with anyio.create_task_group() as tg:
|
||||
# Create a message router that distributes messages to request streams
|
||||
async def message_router():
|
||||
try:
|
||||
async for message in write_stream_reader:
|
||||
# Determine which request stream(s) should receive this message
|
||||
target_request_id = None
|
||||
if isinstance(
|
||||
message.root, JSONRPCNotification | JSONRPCRequest
|
||||
):
|
||||
# Extract related_request_id from meta if it exists
|
||||
if (
|
||||
(params := getattr(message.root, "params", None))
|
||||
and (meta := params.get("_meta"))
|
||||
and (related_id := meta.get("related_request_id"))
|
||||
is not None
|
||||
):
|
||||
target_request_id = str(related_id)
|
||||
else:
|
||||
target_request_id = str(message.root.id)
|
||||
|
||||
# Send to the specific request stream if available
|
||||
if (
|
||||
target_request_id
|
||||
and target_request_id in self._request_streams
|
||||
):
|
||||
try:
|
||||
await self._request_streams[target_request_id].send(
|
||||
message
|
||||
)
|
||||
except (
|
||||
anyio.BrokenResourceError,
|
||||
anyio.ClosedResourceError,
|
||||
):
|
||||
# Stream might be closed, remove from registry
|
||||
self._request_streams.pop(target_request_id, None)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error in message router: {e}")
|
||||
|
||||
# Start the message router
|
||||
tg.start_soon(message_router)
|
||||
|
||||
try:
|
||||
# Yield the streams for the caller to use
|
||||
yield read_stream, write_stream
|
||||
finally:
|
||||
for stream in list(self._request_streams.values()):
|
||||
try:
|
||||
await stream.aclose()
|
||||
except Exception:
|
||||
pass
|
||||
self._request_streams.clear()
|
||||
@@ -6,7 +6,6 @@ from types import TracebackType
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
import anyio
|
||||
import anyio.lowlevel
|
||||
import httpx
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import BaseModel
|
||||
@@ -24,6 +23,7 @@ from mcp.types import (
|
||||
JSONRPCNotification,
|
||||
JSONRPCRequest,
|
||||
JSONRPCResponse,
|
||||
NotificationParams,
|
||||
RequestParams,
|
||||
ServerNotification,
|
||||
ServerRequest,
|
||||
@@ -274,16 +274,32 @@ class BaseSession(
|
||||
await response_stream.aclose()
|
||||
await response_stream_reader.aclose()
|
||||
|
||||
async def send_notification(self, notification: SendNotificationT) -> None:
|
||||
async def send_notification(
|
||||
self,
|
||||
notification: SendNotificationT,
|
||||
related_request_id: RequestId | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Emits a notification, which is a one-way message that does not expect
|
||||
a response.
|
||||
"""
|
||||
# Some transport implementations may need to set the related_request_id
|
||||
# to attribute to the notifications to the request that triggered them.
|
||||
if related_request_id is not None and notification.root.params is not None:
|
||||
# Create meta if it doesn't exist
|
||||
if notification.root.params.meta is None:
|
||||
meta_dict = {"related_request_id": related_request_id}
|
||||
|
||||
else:
|
||||
meta_dict = notification.root.params.meta.model_dump(
|
||||
by_alias=True, mode="json", exclude_none=True
|
||||
)
|
||||
meta_dict["related_request_id"] = related_request_id
|
||||
notification.root.params.meta = NotificationParams.Meta(**meta_dict)
|
||||
jsonrpc_notification = JSONRPCNotification(
|
||||
jsonrpc="2.0",
|
||||
**notification.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
|
||||
await self._write_stream.send(JSONRPCMessage(jsonrpc_notification))
|
||||
|
||||
async def _send_response(
|
||||
|
||||
@@ -9,6 +9,7 @@ from mcp.shared.memory import (
|
||||
from mcp.shared.session import RequestResponder
|
||||
from mcp.types import (
|
||||
LoggingMessageNotificationParams,
|
||||
NotificationParams,
|
||||
TextContent,
|
||||
)
|
||||
|
||||
@@ -78,6 +79,11 @@ async def test_logging_callback():
|
||||
)
|
||||
assert log_result.isError is False
|
||||
assert len(logging_collector.log_messages) == 1
|
||||
assert logging_collector.log_messages[0] == LoggingMessageNotificationParams(
|
||||
level="info", logger="test_logger", data="Test log message"
|
||||
)
|
||||
# Create meta object with related_request_id added dynamically
|
||||
meta = NotificationParams.Meta()
|
||||
setattr(meta, "related_request_id", "2")
|
||||
log = logging_collector.log_messages[0]
|
||||
assert log.level == "info"
|
||||
assert log.logger == "test_logger"
|
||||
assert log.data == "Test log message"
|
||||
assert log.meta == meta
|
||||
|
||||
@@ -35,7 +35,7 @@ async def test_messages_are_executed_concurrently():
|
||||
end_time = anyio.current_time()
|
||||
|
||||
duration = end_time - start_time
|
||||
assert duration < 3 * _sleep_time_seconds
|
||||
assert duration < 6 * _sleep_time_seconds
|
||||
print(duration)
|
||||
|
||||
|
||||
|
||||
@@ -544,14 +544,28 @@ class TestContextInjection:
|
||||
|
||||
assert mock_log.call_count == 4
|
||||
mock_log.assert_any_call(
|
||||
level="debug", data="Debug message", logger=None
|
||||
)
|
||||
mock_log.assert_any_call(level="info", data="Info message", logger=None)
|
||||
mock_log.assert_any_call(
|
||||
level="warning", data="Warning message", logger=None
|
||||
level="debug",
|
||||
data="Debug message",
|
||||
logger=None,
|
||||
related_request_id="1",
|
||||
)
|
||||
mock_log.assert_any_call(
|
||||
level="error", data="Error message", logger=None
|
||||
level="info",
|
||||
data="Info message",
|
||||
logger=None,
|
||||
related_request_id="1",
|
||||
)
|
||||
mock_log.assert_any_call(
|
||||
level="warning",
|
||||
data="Warning message",
|
||||
logger=None,
|
||||
related_request_id="1",
|
||||
)
|
||||
mock_log.assert_any_call(
|
||||
level="error",
|
||||
data="Error message",
|
||||
logger=None,
|
||||
related_request_id="1",
|
||||
)
|
||||
|
||||
@pytest.mark.anyio
|
||||
|
||||
543
tests/server/test_streamableHttp.py
Normal file
543
tests/server/test_streamableHttp.py
Normal file
@@ -0,0 +1,543 @@
|
||||
"""
|
||||
Tests for the StreamableHTTP server transport validation.
|
||||
|
||||
This file contains tests for request validation in the StreamableHTTP transport.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import multiprocessing
|
||||
import socket
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from http import HTTPStatus
|
||||
from uuid import uuid4
|
||||
|
||||
import anyio
|
||||
import pytest
|
||||
import requests
|
||||
import uvicorn
|
||||
from pydantic import AnyUrl
|
||||
from starlette.applications import Starlette
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.routing import Mount
|
||||
|
||||
from mcp.server import Server
|
||||
from mcp.server.streamableHttp import (
|
||||
MCP_SESSION_ID_HEADER,
|
||||
SESSION_ID_PATTERN,
|
||||
StreamableHTTPServerTransport,
|
||||
)
|
||||
from mcp.shared.exceptions import McpError
|
||||
from mcp.types import (
|
||||
ErrorData,
|
||||
TextContent,
|
||||
Tool,
|
||||
)
|
||||
|
||||
# Test constants
|
||||
SERVER_NAME = "test_streamable_http_server"
|
||||
TEST_SESSION_ID = "test-session-id-12345"
|
||||
INIT_REQUEST = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"clientInfo": {"name": "test-client", "version": "1.0"},
|
||||
"protocolVersion": "2025-03-26",
|
||||
"capabilities": {},
|
||||
},
|
||||
"id": "init-1",
|
||||
}
|
||||
|
||||
|
||||
# Test server implementation that follows MCP protocol
|
||||
class ServerTest(Server):
|
||||
def __init__(self):
|
||||
super().__init__(SERVER_NAME)
|
||||
|
||||
@self.read_resource()
|
||||
async def handle_read_resource(uri: AnyUrl) -> str | bytes:
|
||||
if uri.scheme == "foobar":
|
||||
return f"Read {uri.host}"
|
||||
elif uri.scheme == "slow":
|
||||
# Simulate a slow resource
|
||||
await anyio.sleep(2.0)
|
||||
return f"Slow response from {uri.host}"
|
||||
|
||||
raise McpError(
|
||||
error=ErrorData(
|
||||
code=404, message="OOPS! no resource with that URI was found"
|
||||
)
|
||||
)
|
||||
|
||||
@self.list_tools()
|
||||
async def handle_list_tools() -> list[Tool]:
|
||||
return [
|
||||
Tool(
|
||||
name="test_tool",
|
||||
description="A test tool",
|
||||
inputSchema={"type": "object", "properties": {}},
|
||||
)
|
||||
]
|
||||
|
||||
@self.call_tool()
|
||||
async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
|
||||
return [TextContent(type="text", text=f"Called {name}")]
|
||||
|
||||
|
||||
def create_app(is_json_response_enabled=False) -> Starlette:
|
||||
"""Create a Starlette application for testing that matches the example server.
|
||||
|
||||
Args:
|
||||
is_json_response_enabled: If True, use JSON responses instead of SSE streams.
|
||||
"""
|
||||
# Create server instance
|
||||
server = ServerTest()
|
||||
|
||||
server_instances = {}
|
||||
# Lock to prevent race conditions when creating new sessions
|
||||
session_creation_lock = anyio.Lock()
|
||||
task_group = None
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def lifespan(app):
|
||||
"""Application lifespan context manager for managing task group."""
|
||||
nonlocal task_group
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
task_group = tg
|
||||
print("Application started, task group initialized!")
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
print("Application shutting down, cleaning up resources...")
|
||||
if task_group:
|
||||
tg.cancel_scope.cancel()
|
||||
task_group = None
|
||||
print("Resources cleaned up successfully.")
|
||||
|
||||
async def handle_streamable_http(scope, receive, send):
|
||||
request = Request(scope, receive)
|
||||
request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER)
|
||||
|
||||
# Use existing transport if session ID matches
|
||||
if (
|
||||
request_mcp_session_id is not None
|
||||
and request_mcp_session_id in server_instances
|
||||
):
|
||||
transport = server_instances[request_mcp_session_id]
|
||||
|
||||
await transport.handle_request(scope, receive, send)
|
||||
elif request_mcp_session_id is None:
|
||||
async with session_creation_lock:
|
||||
new_session_id = uuid4().hex
|
||||
|
||||
http_transport = StreamableHTTPServerTransport(
|
||||
mcp_session_id=new_session_id,
|
||||
is_json_response_enabled=is_json_response_enabled,
|
||||
)
|
||||
|
||||
async with http_transport.connect() as streams:
|
||||
read_stream, write_stream = streams
|
||||
|
||||
async def run_server():
|
||||
try:
|
||||
await server.run(
|
||||
read_stream,
|
||||
write_stream,
|
||||
server.create_initialization_options(),
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Server exception: {e}")
|
||||
|
||||
if task_group is None:
|
||||
response = Response(
|
||||
"Internal Server Error: Task group is not initialized",
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
# Store the instance before starting the task to prevent races
|
||||
server_instances[http_transport.mcp_session_id] = http_transport
|
||||
task_group.start_soon(run_server)
|
||||
|
||||
await http_transport.handle_request(scope, receive, send)
|
||||
else:
|
||||
response = Response(
|
||||
"Bad Request: No valid session ID provided",
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
|
||||
# Create an ASGI application
|
||||
app = Starlette(
|
||||
debug=True,
|
||||
routes=[
|
||||
Mount("/mcp", app=handle_streamable_http),
|
||||
],
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def run_server(port: int, is_json_response_enabled=False) -> None:
|
||||
"""Run the test server.
|
||||
|
||||
Args:
|
||||
port: Port to listen on.
|
||||
is_json_response_enabled: If True, use JSON responses instead of SSE streams.
|
||||
"""
|
||||
print(
|
||||
f"Starting test server on port {port} with "
|
||||
f"json_enabled={is_json_response_enabled}"
|
||||
)
|
||||
|
||||
app = create_app(is_json_response_enabled)
|
||||
# Configure server
|
||||
config = uvicorn.Config(
|
||||
app=app,
|
||||
host="127.0.0.1",
|
||||
port=port,
|
||||
log_level="info",
|
||||
limit_concurrency=10,
|
||||
timeout_keep_alive=5,
|
||||
access_log=False,
|
||||
)
|
||||
|
||||
# Start the server
|
||||
server = uvicorn.Server(config=config)
|
||||
|
||||
# This is important to catch exceptions and prevent test hangs
|
||||
try:
|
||||
print("Server starting...")
|
||||
server.run()
|
||||
except Exception as e:
|
||||
print(f"ERROR: Server failed to run: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
print("Server shutdown")
|
||||
|
||||
|
||||
# Test fixtures - using same approach as SSE tests
|
||||
@pytest.fixture
|
||||
def basic_server_port() -> int:
|
||||
"""Find an available port for the basic server."""
|
||||
with socket.socket() as s:
|
||||
s.bind(("127.0.0.1", 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def json_server_port() -> int:
|
||||
"""Find an available port for the JSON response server."""
|
||||
with socket.socket() as s:
|
||||
s.bind(("127.0.0.1", 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def basic_server(basic_server_port: int) -> Generator[None, None, None]:
|
||||
"""Start a basic server."""
|
||||
proc = multiprocessing.Process(
|
||||
target=run_server, kwargs={"port": basic_server_port}, daemon=True
|
||||
)
|
||||
proc.start()
|
||||
|
||||
# Wait for server to be running
|
||||
max_attempts = 20
|
||||
attempt = 0
|
||||
while attempt < max_attempts:
|
||||
try:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.connect(("127.0.0.1", basic_server_port))
|
||||
break
|
||||
except ConnectionRefusedError:
|
||||
time.sleep(0.1)
|
||||
attempt += 1
|
||||
else:
|
||||
raise RuntimeError(f"Server failed to start after {max_attempts} attempts")
|
||||
|
||||
yield
|
||||
|
||||
# Clean up
|
||||
proc.kill()
|
||||
proc.join(timeout=2)
|
||||
if proc.is_alive():
|
||||
print("server process failed to terminate")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def json_response_server(json_server_port: int) -> Generator[None, None, None]:
|
||||
"""Start a server with JSON response enabled."""
|
||||
proc = multiprocessing.Process(
|
||||
target=run_server,
|
||||
kwargs={"port": json_server_port, "is_json_response_enabled": True},
|
||||
daemon=True,
|
||||
)
|
||||
proc.start()
|
||||
|
||||
# Wait for server to be running
|
||||
max_attempts = 20
|
||||
attempt = 0
|
||||
while attempt < max_attempts:
|
||||
try:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.connect(("127.0.0.1", json_server_port))
|
||||
break
|
||||
except ConnectionRefusedError:
|
||||
time.sleep(0.1)
|
||||
attempt += 1
|
||||
else:
|
||||
raise RuntimeError(f"Server failed to start after {max_attempts} attempts")
|
||||
|
||||
yield
|
||||
|
||||
# Clean up
|
||||
proc.kill()
|
||||
proc.join(timeout=2)
|
||||
if proc.is_alive():
|
||||
print("server process failed to terminate")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def basic_server_url(basic_server_port: int) -> str:
|
||||
"""Get the URL for the basic test server."""
|
||||
return f"http://127.0.0.1:{basic_server_port}"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def json_server_url(json_server_port: int) -> str:
|
||||
"""Get the URL for the JSON response test server."""
|
||||
return f"http://127.0.0.1:{json_server_port}"
|
||||
|
||||
|
||||
# Basic request validation tests
|
||||
def test_accept_header_validation(basic_server, basic_server_url):
|
||||
"""Test that Accept header is properly validated."""
|
||||
# Test without Accept header
|
||||
response = requests.post(
|
||||
f"{basic_server_url}/mcp",
|
||||
headers={"Content-Type": "application/json"},
|
||||
json={"jsonrpc": "2.0", "method": "initialize", "id": 1},
|
||||
)
|
||||
assert response.status_code == 406
|
||||
assert "Not Acceptable" in response.text
|
||||
|
||||
|
||||
def test_content_type_validation(basic_server, basic_server_url):
|
||||
"""Test that Content-Type header is properly validated."""
|
||||
# Test with incorrect Content-Type
|
||||
response = requests.post(
|
||||
f"{basic_server_url}/mcp",
|
||||
headers={
|
||||
"Accept": "application/json, text/event-stream",
|
||||
"Content-Type": "text/plain",
|
||||
},
|
||||
data="This is not JSON",
|
||||
)
|
||||
assert response.status_code == 415
|
||||
assert "Unsupported Media Type" in response.text
|
||||
|
||||
|
||||
def test_json_validation(basic_server, basic_server_url):
|
||||
"""Test that JSON content is properly validated."""
|
||||
# Test with invalid JSON
|
||||
response = requests.post(
|
||||
f"{basic_server_url}/mcp",
|
||||
headers={
|
||||
"Accept": "application/json, text/event-stream",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
data="this is not valid json",
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert "Parse error" in response.text
|
||||
|
||||
|
||||
def test_json_parsing(basic_server, basic_server_url):
|
||||
"""Test that JSON content is properly parse."""
|
||||
# Test with valid JSON but invalid JSON-RPC
|
||||
response = requests.post(
|
||||
f"{basic_server_url}/mcp",
|
||||
headers={
|
||||
"Accept": "application/json, text/event-stream",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={"foo": "bar"},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert "Validation error" in response.text
|
||||
|
||||
|
||||
def test_method_not_allowed(basic_server, basic_server_url):
|
||||
"""Test that unsupported HTTP methods are rejected."""
|
||||
# Test with unsupported method (PUT)
|
||||
response = requests.put(
|
||||
f"{basic_server_url}/mcp",
|
||||
headers={
|
||||
"Accept": "application/json, text/event-stream",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={"jsonrpc": "2.0", "method": "initialize", "id": 1},
|
||||
)
|
||||
assert response.status_code == 405
|
||||
assert "Method Not Allowed" in response.text
|
||||
|
||||
|
||||
def test_session_validation(basic_server, basic_server_url):
|
||||
"""Test session ID validation."""
|
||||
# session_id not used directly in this test
|
||||
|
||||
# Test without session ID
|
||||
response = requests.post(
|
||||
f"{basic_server_url}/mcp",
|
||||
headers={
|
||||
"Accept": "application/json, text/event-stream",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={"jsonrpc": "2.0", "method": "list_tools", "id": 1},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert "Missing session ID" in response.text
|
||||
|
||||
|
||||
def test_session_id_pattern():
|
||||
"""Test that SESSION_ID_PATTERN correctly validates session IDs."""
|
||||
# Valid session IDs (visible ASCII characters from 0x21 to 0x7E)
|
||||
valid_session_ids = [
|
||||
"test-session-id",
|
||||
"1234567890",
|
||||
"session!@#$%^&*()_+-=[]{}|;:,.<>?/",
|
||||
"~`",
|
||||
]
|
||||
|
||||
for session_id in valid_session_ids:
|
||||
assert SESSION_ID_PATTERN.match(session_id) is not None
|
||||
# Ensure fullmatch matches too (whole string)
|
||||
assert SESSION_ID_PATTERN.fullmatch(session_id) is not None
|
||||
|
||||
# Invalid session IDs
|
||||
invalid_session_ids = [
|
||||
"", # Empty string
|
||||
" test", # Space (0x20)
|
||||
"test\t", # Tab
|
||||
"test\n", # Newline
|
||||
"test\r", # Carriage return
|
||||
"test" + chr(0x7F), # DEL character
|
||||
"test" + chr(0x80), # Extended ASCII
|
||||
"test" + chr(0x00), # Null character
|
||||
"test" + chr(0x20), # Space (0x20)
|
||||
]
|
||||
|
||||
for session_id in invalid_session_ids:
|
||||
# For invalid IDs, either match will fail or fullmatch will fail
|
||||
if SESSION_ID_PATTERN.match(session_id) is not None:
|
||||
# If match succeeds, fullmatch should fail (partial match case)
|
||||
assert SESSION_ID_PATTERN.fullmatch(session_id) is None
|
||||
|
||||
|
||||
def test_streamable_http_transport_init_validation():
|
||||
"""Test that StreamableHTTPServerTransport validates session ID on init."""
|
||||
# Valid session ID should initialize without errors
|
||||
valid_transport = StreamableHTTPServerTransport(mcp_session_id="valid-id")
|
||||
assert valid_transport.mcp_session_id == "valid-id"
|
||||
|
||||
# None should be accepted
|
||||
none_transport = StreamableHTTPServerTransport(mcp_session_id=None)
|
||||
assert none_transport.mcp_session_id is None
|
||||
|
||||
# Invalid session ID should raise ValueError
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
StreamableHTTPServerTransport(mcp_session_id="invalid id with space")
|
||||
assert "Session ID must only contain visible ASCII characters" in str(excinfo.value)
|
||||
|
||||
# Test with control characters
|
||||
with pytest.raises(ValueError):
|
||||
StreamableHTTPServerTransport(mcp_session_id="test\nid")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
StreamableHTTPServerTransport(mcp_session_id="test\n")
|
||||
|
||||
|
||||
def test_session_termination(basic_server, basic_server_url):
|
||||
"""Test session termination via DELETE and subsequent request handling."""
|
||||
response = requests.post(
|
||||
f"{basic_server_url}/mcp",
|
||||
headers={
|
||||
"Accept": "application/json, text/event-stream",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json=INIT_REQUEST,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Now terminate the session
|
||||
session_id = response.headers.get(MCP_SESSION_ID_HEADER)
|
||||
response = requests.delete(
|
||||
f"{basic_server_url}/mcp",
|
||||
headers={MCP_SESSION_ID_HEADER: session_id},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Try to use the terminated session
|
||||
response = requests.post(
|
||||
f"{basic_server_url}/mcp",
|
||||
headers={
|
||||
"Accept": "application/json, text/event-stream",
|
||||
"Content-Type": "application/json",
|
||||
MCP_SESSION_ID_HEADER: session_id,
|
||||
},
|
||||
json={"jsonrpc": "2.0", "method": "ping", "id": 2},
|
||||
)
|
||||
assert response.status_code == 404
|
||||
assert "Session has been terminated" in response.text
|
||||
|
||||
|
||||
def test_response(basic_server, basic_server_url):
|
||||
"""Test response handling for a valid request."""
|
||||
mcp_url = f"{basic_server_url}/mcp"
|
||||
response = requests.post(
|
||||
mcp_url,
|
||||
headers={
|
||||
"Accept": "application/json, text/event-stream",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json=INIT_REQUEST,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Now terminate the session
|
||||
session_id = response.headers.get(MCP_SESSION_ID_HEADER)
|
||||
|
||||
# Try to use the terminated session
|
||||
tools_response = requests.post(
|
||||
mcp_url,
|
||||
headers={
|
||||
"Accept": "application/json, text/event-stream",
|
||||
"Content-Type": "application/json",
|
||||
MCP_SESSION_ID_HEADER: session_id, # Use the session ID we got earlier
|
||||
},
|
||||
json={"jsonrpc": "2.0", "method": "tools/list", "id": "tools-1"},
|
||||
stream=True,
|
||||
)
|
||||
assert tools_response.status_code == 200
|
||||
assert tools_response.headers.get("Content-Type") == "text/event-stream"
|
||||
|
||||
|
||||
def test_json_response(json_response_server, json_server_url):
|
||||
"""Test response handling when is_json_response_enabled is True."""
|
||||
mcp_url = f"{json_server_url}/mcp"
|
||||
response = requests.post(
|
||||
mcp_url,
|
||||
headers={
|
||||
"Accept": "application/json, text/event-stream",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json=INIT_REQUEST,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.headers.get("Content-Type") == "application/json"
|
||||
38
uv.lock
generated
38
uv.lock
generated
@@ -10,6 +10,7 @@ members = [
|
||||
"mcp",
|
||||
"mcp-simple-prompt",
|
||||
"mcp-simple-resource",
|
||||
"mcp-simple-streamablehttp",
|
||||
"mcp-simple-tool",
|
||||
]
|
||||
|
||||
@@ -632,6 +633,43 @@ dev = [
|
||||
{ name = "ruff", specifier = ">=0.6.9" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mcp-simple-streamablehttp"
|
||||
version = "0.1.0"
|
||||
source = { editable = "examples/servers/simple-streamablehttp" }
|
||||
dependencies = [
|
||||
{ name = "anyio" },
|
||||
{ name = "click" },
|
||||
{ name = "httpx" },
|
||||
{ name = "mcp" },
|
||||
{ name = "starlette" },
|
||||
{ name = "uvicorn" },
|
||||
]
|
||||
|
||||
[package.dev-dependencies]
|
||||
dev = [
|
||||
{ name = "pyright" },
|
||||
{ name = "pytest" },
|
||||
{ name = "ruff" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "anyio", specifier = ">=4.5" },
|
||||
{ name = "click", specifier = ">=8.1.0" },
|
||||
{ name = "httpx", specifier = ">=0.27" },
|
||||
{ name = "mcp", editable = "." },
|
||||
{ name = "starlette" },
|
||||
{ name = "uvicorn" },
|
||||
]
|
||||
|
||||
[package.metadata.requires-dev]
|
||||
dev = [
|
||||
{ name = "pyright", specifier = ">=1.1.378" },
|
||||
{ name = "pytest", specifier = ">=8.3.3" },
|
||||
{ name = "ruff", specifier = ">=0.6.9" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mcp-simple-tool"
|
||||
version = "0.1.0"
|
||||
|
||||
Reference in New Issue
Block a user