Support for http request injection propagation in StreamableHttp (#833)

ihrpr
2025-05-29 15:21:06 +01:00
committed by GitHub
parent 7f94bef85e
commit 05b7156ea8
3 changed files with 259 additions and 172 deletions
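
In short, the transport now wraps the incoming Starlette Request in ServerMessageMetadata before forwarding each JSON-RPC message, so server-side handlers can read headers, method, and path from their request context. Below is a minimal sketch of that pattern, not the actual transport code (which lives inline in the first hunk); the attach_request helper is hypothetical and the mcp.shared.message import path is assumed from the SDK layout.

from starlette.requests import Request

from mcp.shared.message import ServerMessageMetadata, SessionMessage
from mcp.types import JSONRPCMessage


def attach_request(message: JSONRPCMessage, request: Request) -> SessionMessage:
    # Transport side: carry the incoming HTTP request alongside the JSON-RPC
    # message so the server can surface it on the request context.
    metadata = ServerMessageMetadata(request_context=request)
    return SessionMessage(message, metadata=metadata)


# Handler side (FastMCP), as exercised by the tests in this commit:
#   ctx.request_context.request          -> the Starlette Request (or None)
#   ctx.request_context.request.headers  -> the incoming HTTP headers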


@@ -397,7 +397,8 @@ class StreamableHTTPServerTransport:
             await response(scope, receive, send)
             # Process the message after sending the response
-            session_message = SessionMessage(message)
+            metadata = ServerMessageMetadata(request_context=request)
+            session_message = SessionMessage(message, metadata=metadata)
             await writer.send(session_message)
             return
@@ -412,7 +413,8 @@ class StreamableHTTPServerTransport:
         if self.is_json_response_enabled:
             # Process the message
-            session_message = SessionMessage(message)
+            metadata = ServerMessageMetadata(request_context=request)
+            session_message = SessionMessage(message, metadata=metadata)
             await writer.send(session_message)
             try:
                 # Process messages from the request-specific stream
@@ -511,7 +513,8 @@ class StreamableHTTPServerTransport:
             async with anyio.create_task_group() as tg:
                 tg.start_soon(response, scope, receive, send)
                 # Then send the message to be processed by the server
-                session_message = SessionMessage(message)
+                metadata = ServerMessageMetadata(request_context=request)
+                session_message = SessionMessage(message, metadata=metadata)
                 await writer.send(session_message)
         except Exception:
             logger.exception("SSE response error")


@@ -24,7 +24,6 @@ from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client
from mcp.server.fastmcp import FastMCP
from mcp.server.fastmcp.resources import FunctionResource
from mcp.server.fastmcp.server import Context
from mcp.shared.context import RequestContext
from mcp.types import (
    CreateMessageRequestParams,
@@ -196,6 +195,33 @@ def make_everything_fastmcp() -> FastMCP:
        # Since FastMCP doesn't support system messages in the same way
        return f"Context: {context}. Query: {user_query}"

    # Tool that echoes request headers from context
    @mcp.tool(description="Echo request headers from context")
    def echo_headers(ctx: Context[Any, Any, Request]) -> str:
        """Returns the request headers as JSON."""
        headers_info = {}
        if ctx.request_context.request:
            # Now the type system knows request is a Starlette Request object
            headers_info = dict(ctx.request_context.request.headers)
        return json.dumps(headers_info)

    # Tool that returns full request context
    @mcp.tool(description="Echo request context with custom data")
    def echo_context(custom_request_id: str, ctx: Context[Any, Any, Request]) -> str:
        """Returns request context including headers and custom data."""
        context_data = {
            "custom_request_id": custom_request_id,
            "headers": {},
            "method": None,
            "path": None,
        }
        if ctx.request_context.request:
            request = ctx.request_context.request
            context_data["headers"] = dict(request.headers)
            context_data["method"] = request.method
            context_data["path"] = request.url.path
        return json.dumps(context_data)

    return mcp
@@ -432,174 +458,6 @@ async def test_fastmcp_without_auth(server: None, server_url: str) -> None:
            assert tool_result.content[0].text == "Echo: hello"


def make_fastmcp_with_context_app():
    """Create a FastMCP server that can access request context."""
    mcp = FastMCP(name="ContextServer")

    # Tool that echoes request headers
    @mcp.tool(description="Echo request headers from context")
    def echo_headers(ctx: Context[Any, Any, Request]) -> str:
        """Returns the request headers as JSON."""
        headers_info = {}
        if ctx.request_context.request:
            # Now the type system knows request is a Starlette Request object
            headers_info = dict(ctx.request_context.request.headers)
        return json.dumps(headers_info)

    # Tool that returns full request context
    @mcp.tool(description="Echo request context with custom data")
    def echo_context(custom_request_id: str, ctx: Context[Any, Any, Request]) -> str:
        """Returns request context including headers and custom data."""
        context_data = {
            "custom_request_id": custom_request_id,
            "headers": {},
            "method": None,
            "path": None,
        }
        if ctx.request_context.request:
            request = ctx.request_context.request
            context_data["headers"] = dict(request.headers)
            context_data["method"] = request.method
            context_data["path"] = request.url.path
        return json.dumps(context_data)

    # Create the SSE app
    app = mcp.sse_app()
    return mcp, app


def run_context_server(server_port: int) -> None:
    """Run the context-aware FastMCP server."""
    _, app = make_fastmcp_with_context_app()
    server = uvicorn.Server(
        config=uvicorn.Config(
            app=app, host="127.0.0.1", port=server_port, log_level="error"
        )
    )
    print(f"Starting context server on port {server_port}")
    server.run()


@pytest.fixture()
def context_aware_server(server_port: int) -> Generator[None, None, None]:
    """Start the context-aware server in a separate process."""
    proc = multiprocessing.Process(
        target=run_context_server, args=(server_port,), daemon=True
    )
    print("Starting context-aware server process")
    proc.start()

    # Wait for server to be running
    max_attempts = 20
    attempt = 0
    print("Waiting for context-aware server to start")
    while attempt < max_attempts:
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.connect(("127.0.0.1", server_port))
                break
        except ConnectionRefusedError:
            time.sleep(0.1)
            attempt += 1
    else:
        raise RuntimeError(
            f"Context server failed to start after {max_attempts} attempts"
        )

    yield

    print("Killing context-aware server")
    proc.kill()
    proc.join(timeout=2)
    if proc.is_alive():
        print("Context server process failed to terminate")


@pytest.mark.anyio
async def test_fast_mcp_with_request_context(
    context_aware_server: None, server_url: str
) -> None:
    """Test that FastMCP properly propagates request context to tools."""
    # Test with custom headers
    custom_headers = {
        "Authorization": "Bearer fastmcp-test-token",
        "X-Custom-Header": "fastmcp-value",
        "X-Request-Id": "req-123",
    }

    async with sse_client(server_url + "/sse", headers=custom_headers) as streams:
        async with ClientSession(*streams) as session:
            # Initialize the session
            result = await session.initialize()
            assert isinstance(result, InitializeResult)
            assert result.serverInfo.name == "ContextServer"

            # Test 1: Call tool that echoes headers
            headers_result = await session.call_tool("echo_headers", {})
            assert len(headers_result.content) == 1
            assert isinstance(headers_result.content[0], TextContent)
            headers_data = json.loads(headers_result.content[0].text)
            assert headers_data.get("authorization") == "Bearer fastmcp-test-token"
            assert headers_data.get("x-custom-header") == "fastmcp-value"
            assert headers_data.get("x-request-id") == "req-123"

            # Test 2: Call tool that returns full context
            context_result = await session.call_tool(
                "echo_context", {"custom_request_id": "test-123"}
            )
            assert len(context_result.content) == 1
            assert isinstance(context_result.content[0], TextContent)
            context_data = json.loads(context_result.content[0].text)
            assert context_data["custom_request_id"] == "test-123"
            assert (
                context_data["headers"].get("authorization")
                == "Bearer fastmcp-test-token"
            )
            assert context_data["method"] == "POST"


@pytest.mark.anyio
async def test_fast_mcp_request_context_isolation(
    context_aware_server: None, server_url: str
) -> None:
    """Test that request contexts are isolated between different FastMCP clients."""
    contexts = []

    # Create multiple clients with different headers
    for i in range(3):
        headers = {
            "Authorization": f"Bearer token-{i}",
            "X-Request-Id": f"fastmcp-req-{i}",
            "X-Custom-Value": f"value-{i}",
        }

        async with sse_client(server_url + "/sse", headers=headers) as streams:
            async with ClientSession(*streams) as session:
                await session.initialize()

                # Call the tool that returns context
                tool_result = await session.call_tool(
                    "echo_context", {"custom_request_id": f"test-req-{i}"}
                )

                # Parse and store the result
                assert len(tool_result.content) == 1
                assert isinstance(tool_result.content[0], TextContent)
                context_data = json.loads(tool_result.content[0].text)
                contexts.append(context_data)

    # Verify each request had its own isolated context
    assert len(contexts) == 3
    for i, ctx in enumerate(contexts):
        assert ctx["custom_request_id"] == f"test-req-{i}"
        assert ctx["headers"].get("authorization") == f"Bearer token-{i}"
        assert ctx["headers"].get("x-request-id") == f"fastmcp-req-{i}"
        assert ctx["headers"].get("x-custom-value") == f"value-{i}"


@pytest.mark.anyio
async def test_fastmcp_streamable_http(
    streamable_http_server: None, http_server_url: str
@@ -967,6 +825,30 @@ async def call_all_mcp_features(
    assert isinstance(complex_result, GetPromptResult)
    assert len(complex_result.messages) >= 1

    # Test request context propagation (only works when headers are available)
    headers_result = await session.call_tool("echo_headers", {})
    assert len(headers_result.content) == 1
    assert isinstance(headers_result.content[0], TextContent)

    # If we got headers, verify they exist
    headers_data = json.loads(headers_result.content[0].text)
    # The headers depend on the transport and test setup
    print(f"Received headers: {headers_data}")

    # Test 6: Call tool that returns full context
    context_result = await session.call_tool(
        "echo_context", {"custom_request_id": "test-123"}
    )
    assert len(context_result.content) == 1
    assert isinstance(context_result.content[0], TextContent)
    context_data = json.loads(context_result.content[0].text)
    assert context_data["custom_request_id"] == "test-123"
    # The method should be POST for most transports
    if context_data["method"]:
        assert context_data["method"] == "POST"


async def sampling_callback(
    context: RequestContext[ClientSession, None],


@@ -4,6 +4,7 @@ Tests for the StreamableHTTP server and client transport.
Contains tests for both server and client sides of the StreamableHTTP transport.
"""
import json
import multiprocessing
import socket
import time
@@ -17,6 +18,7 @@ import requests
import uvicorn
from pydantic import AnyUrl
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.routing import Mount
import mcp.types as types
@@ -1223,3 +1225,203 @@ async def test_streamablehttp_server_sampling(basic_server, basic_server_url):
        captured_message_params.messages[0].content.text
        == "Server needs client sampling"
    )


# Context-aware server implementation for testing request context propagation
class ContextAwareServerTest(Server):
    def __init__(self):
        super().__init__("ContextAwareServer")

        @self.list_tools()
        async def handle_list_tools() -> list[Tool]:
            return [
                Tool(
                    name="echo_headers",
                    description="Echo request headers from context",
                    inputSchema={"type": "object", "properties": {}},
                ),
                Tool(
                    name="echo_context",
                    description="Echo request context with custom data",
                    inputSchema={
                        "type": "object",
                        "properties": {
                            "request_id": {"type": "string"},
                        },
                        "required": ["request_id"],
                    },
                ),
            ]

        @self.call_tool()
        async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
            ctx = self.request_context

            if name == "echo_headers":
                # Access the request object from context
                headers_info = {}
                if ctx.request and isinstance(ctx.request, Request):
                    headers_info = dict(ctx.request.headers)
                return [
                    TextContent(
                        type="text",
                        text=json.dumps(headers_info),
                    )
                ]
            elif name == "echo_context":
                # Return full context information
                context_data = {
                    "request_id": args.get("request_id"),
                    "headers": {},
                    "method": None,
                    "path": None,
                }
                if ctx.request and isinstance(ctx.request, Request):
                    request = ctx.request
                    context_data["headers"] = dict(request.headers)
                    context_data["method"] = request.method
                    context_data["path"] = request.url.path
                return [
                    TextContent(
                        type="text",
                        text=json.dumps(context_data),
                    )
                ]
            return [TextContent(type="text", text=f"Unknown tool: {name}")]


# Server runner for context-aware testing
def run_context_aware_server(port: int):
    """Run the context-aware test server."""
    server = ContextAwareServerTest()

    session_manager = StreamableHTTPSessionManager(
        app=server,
        event_store=None,
        json_response=False,
    )

    app = Starlette(
        debug=True,
        routes=[
            Mount("/mcp", app=session_manager.handle_request),
        ],
        lifespan=lambda app: session_manager.run(),
    )

    server_instance = uvicorn.Server(
        config=uvicorn.Config(
            app=app,
            host="127.0.0.1",
            port=port,
            log_level="error",
        )
    )
    server_instance.run()


@pytest.fixture
def context_aware_server(basic_server_port: int) -> Generator[None, None, None]:
    """Start the context-aware server in a separate process."""
    proc = multiprocessing.Process(
        target=run_context_aware_server, args=(basic_server_port,), daemon=True
    )
    proc.start()

    # Wait for server to be running
    max_attempts = 20
    attempt = 0
    while attempt < max_attempts:
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.connect(("127.0.0.1", basic_server_port))
                break
        except ConnectionRefusedError:
            time.sleep(0.1)
            attempt += 1
    else:
        raise RuntimeError(
            f"Context-aware server failed to start after {max_attempts} attempts"
        )

    yield

    proc.kill()
    proc.join(timeout=2)
    if proc.is_alive():
        print("Context-aware server process failed to terminate")


@pytest.mark.anyio
async def test_streamablehttp_request_context_propagation(
    context_aware_server: None, basic_server_url: str
) -> None:
    """Test that request context is properly propagated through StreamableHTTP."""
    custom_headers = {
        "Authorization": "Bearer test-token",
        "X-Custom-Header": "test-value",
        "X-Trace-Id": "trace-123",
    }

    async with streamablehttp_client(
        f"{basic_server_url}/mcp", headers=custom_headers
    ) as (read_stream, write_stream, _):
        async with ClientSession(read_stream, write_stream) as session:
            result = await session.initialize()
            assert isinstance(result, InitializeResult)
            assert result.serverInfo.name == "ContextAwareServer"

            # Call the tool that echoes headers back
            tool_result = await session.call_tool("echo_headers", {})

            # Parse the JSON response
            assert len(tool_result.content) == 1
            assert isinstance(tool_result.content[0], TextContent)
            headers_data = json.loads(tool_result.content[0].text)

            # Verify headers were propagated
            assert headers_data.get("authorization") == "Bearer test-token"
            assert headers_data.get("x-custom-header") == "test-value"
            assert headers_data.get("x-trace-id") == "trace-123"


@pytest.mark.anyio
async def test_streamablehttp_request_context_isolation(
    context_aware_server: None, basic_server_url: str
) -> None:
    """Test that request contexts are isolated between StreamableHTTP clients."""
    contexts = []

    # Create multiple clients with different headers
    for i in range(3):
        headers = {
            "X-Request-Id": f"request-{i}",
            "X-Custom-Value": f"value-{i}",
            "Authorization": f"Bearer token-{i}",
        }

        async with streamablehttp_client(
            f"{basic_server_url}/mcp", headers=headers
        ) as (read_stream, write_stream, _):
            async with ClientSession(read_stream, write_stream) as session:
                await session.initialize()

                # Call the tool that echoes context
                tool_result = await session.call_tool(
                    "echo_context", {"request_id": f"request-{i}"}
                )

                assert len(tool_result.content) == 1
                assert isinstance(tool_result.content[0], TextContent)
                context_data = json.loads(tool_result.content[0].text)
                contexts.append(context_data)

    # Verify each request had its own context
    assert len(contexts) == 3
    for i, ctx in enumerate(contexts):
        assert ctx["request_id"] == f"request-{i}"
        assert ctx["headers"].get("x-request-id") == f"request-{i}"
        assert ctx["headers"].get("x-custom-value") == f"value-{i}"
        assert ctx["headers"].get("authorization") == f"Bearer token-{i}"