Support for http request injection propagation to tools (#816)

This commit is contained in:
ihrpr
2025-05-28 15:59:14 +01:00
committed by GitHub
parent 532b1176f9
commit 70014a2bbb
12 changed files with 413 additions and 35 deletions

View File

@@ -5,14 +5,18 @@ These tests validate the proper functioning of FastMCP in various configurations
including with and without authentication.
"""
import json
import multiprocessing
import socket
import time
from collections.abc import Generator
from typing import Any
import pytest
import uvicorn
from pydantic import AnyUrl
from starlette.applications import Starlette
from starlette.requests import Request
import mcp.types as types
from mcp.client.session import ClientSession
@@ -20,6 +24,7 @@ from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client
from mcp.server.fastmcp import FastMCP
from mcp.server.fastmcp.resources import FunctionResource
from mcp.server.fastmcp.server import Context
from mcp.shared.context import RequestContext
from mcp.types import (
CreateMessageRequestParams,
@@ -78,8 +83,6 @@ def stateless_http_server_url(stateless_http_server_port: int) -> str:
# Create a function to make the FastMCP server app
def make_fastmcp_app():
"""Create a FastMCP server without auth settings."""
from starlette.applications import Starlette
mcp = FastMCP(name="NoAuthServer")
# Add a simple tool
@@ -88,7 +91,7 @@ def make_fastmcp_app():
return f"Echo: {message}"
# Create the SSE app
app: Starlette = mcp.sse_app()
app = mcp.sse_app()
return mcp, app
@@ -198,17 +201,14 @@ def make_everything_fastmcp() -> FastMCP:
def make_everything_fastmcp_app():
"""Create a comprehensive FastMCP server with SSE transport."""
from starlette.applications import Starlette
mcp = make_everything_fastmcp()
# Create the SSE app
app: Starlette = mcp.sse_app()
app = mcp.sse_app()
return mcp, app
def make_fastmcp_streamable_http_app():
"""Create a FastMCP server with StreamableHTTP transport."""
from starlette.applications import Starlette
mcp = FastMCP(name="NoAuthServer")
@@ -225,8 +225,6 @@ def make_fastmcp_streamable_http_app():
def make_everything_fastmcp_streamable_http_app():
"""Create a comprehensive FastMCP server with StreamableHTTP transport."""
from starlette.applications import Starlette
# Create a new instance with different name for HTTP transport
mcp = make_everything_fastmcp()
# We can't change the name after creation, so we'll use the same name
@@ -237,7 +235,6 @@ def make_everything_fastmcp_streamable_http_app():
def make_fastmcp_stateless_http_app():
"""Create a FastMCP server with stateless StreamableHTTP transport."""
from starlette.applications import Starlette
mcp = FastMCP(name="StatelessServer", stateless_http=True)
@@ -435,6 +432,174 @@ async def test_fastmcp_without_auth(server: None, server_url: str) -> None:
assert tool_result.content[0].text == "Echo: hello"
def make_fastmcp_with_context_app():
"""Create a FastMCP server that can access request context."""
mcp = FastMCP(name="ContextServer")
# Tool that echoes request headers
@mcp.tool(description="Echo request headers from context")
def echo_headers(ctx: Context[Any, Any, Request]) -> str:
"""Returns the request headers as JSON."""
headers_info = {}
if ctx.request_context.request:
# Now the type system knows request is a Starlette Request object
headers_info = dict(ctx.request_context.request.headers)
return json.dumps(headers_info)
# Tool that returns full request context
@mcp.tool(description="Echo request context with custom data")
def echo_context(custom_request_id: str, ctx: Context[Any, Any, Request]) -> str:
"""Returns request context including headers and custom data."""
context_data = {
"custom_request_id": custom_request_id,
"headers": {},
"method": None,
"path": None,
}
if ctx.request_context.request:
request = ctx.request_context.request
context_data["headers"] = dict(request.headers)
context_data["method"] = request.method
context_data["path"] = request.url.path
return json.dumps(context_data)
# Create the SSE app
app = mcp.sse_app()
return mcp, app
def run_context_server(server_port: int) -> None:
"""Run the context-aware FastMCP server."""
_, app = make_fastmcp_with_context_app()
server = uvicorn.Server(
config=uvicorn.Config(
app=app, host="127.0.0.1", port=server_port, log_level="error"
)
)
print(f"Starting context server on port {server_port}")
server.run()
@pytest.fixture()
def context_aware_server(server_port: int) -> Generator[None, None, None]:
"""Start the context-aware server in a separate process."""
proc = multiprocessing.Process(
target=run_context_server, args=(server_port,), daemon=True
)
print("Starting context-aware server process")
proc.start()
# Wait for server to be running
max_attempts = 20
attempt = 0
print("Waiting for context-aware server to start")
while attempt < max_attempts:
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.connect(("127.0.0.1", server_port))
break
except ConnectionRefusedError:
time.sleep(0.1)
attempt += 1
else:
raise RuntimeError(
f"Context server failed to start after {max_attempts} attempts"
)
yield
print("Killing context-aware server")
proc.kill()
proc.join(timeout=2)
if proc.is_alive():
print("Context server process failed to terminate")
@pytest.mark.anyio
async def test_fast_mcp_with_request_context(
context_aware_server: None, server_url: str
) -> None:
"""Test that FastMCP properly propagates request context to tools."""
# Test with custom headers
custom_headers = {
"Authorization": "Bearer fastmcp-test-token",
"X-Custom-Header": "fastmcp-value",
"X-Request-Id": "req-123",
}
async with sse_client(server_url + "/sse", headers=custom_headers) as streams:
async with ClientSession(*streams) as session:
# Initialize the session
result = await session.initialize()
assert isinstance(result, InitializeResult)
assert result.serverInfo.name == "ContextServer"
# Test 1: Call tool that echoes headers
headers_result = await session.call_tool("echo_headers", {})
assert len(headers_result.content) == 1
assert isinstance(headers_result.content[0], TextContent)
headers_data = json.loads(headers_result.content[0].text)
assert headers_data.get("authorization") == "Bearer fastmcp-test-token"
assert headers_data.get("x-custom-header") == "fastmcp-value"
assert headers_data.get("x-request-id") == "req-123"
# Test 2: Call tool that returns full context
context_result = await session.call_tool(
"echo_context", {"custom_request_id": "test-123"}
)
assert len(context_result.content) == 1
assert isinstance(context_result.content[0], TextContent)
context_data = json.loads(context_result.content[0].text)
assert context_data["custom_request_id"] == "test-123"
assert (
context_data["headers"].get("authorization")
== "Bearer fastmcp-test-token"
)
assert context_data["method"] == "POST" #
@pytest.mark.anyio
async def test_fast_mcp_request_context_isolation(
context_aware_server: None, server_url: str
) -> None:
"""Test that request contexts are isolated between different FastMCP clients."""
contexts = []
# Create multiple clients with different headers
for i in range(3):
headers = {
"Authorization": f"Bearer token-{i}",
"X-Request-Id": f"fastmcp-req-{i}",
"X-Custom-Value": f"value-{i}",
}
async with sse_client(server_url + "/sse", headers=headers) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
# Call the tool that returns context
tool_result = await session.call_tool(
"echo_context", {"custom_request_id": f"test-req-{i}"}
)
# Parse and store the result
assert len(tool_result.content) == 1
assert isinstance(tool_result.content[0], TextContent)
context_data = json.loads(tool_result.content[0].text)
contexts.append(context_data)
# Verify each request had its own isolated context
assert len(contexts) == 3
for i, ctx in enumerate(contexts):
assert ctx["custom_request_id"] == f"test-req-{i}"
assert ctx["headers"].get("authorization") == f"Bearer token-{i}"
assert ctx["headers"].get("x-request-id") == f"fastmcp-req-{i}"
assert ctx["headers"].get("x-custom-value") == f"value-{i}"
@pytest.mark.anyio
async def test_fastmcp_streamable_http(
streamable_http_server: None, http_server_url: str

View File

@@ -9,7 +9,7 @@ from mcp.server.fastmcp.exceptions import ToolError
from mcp.server.fastmcp.tools import Tool, ToolManager
from mcp.server.fastmcp.utilities.func_metadata import ArgModelBase, FuncMetadata
from mcp.server.session import ServerSessionT
from mcp.shared.context import LifespanContextT
from mcp.shared.context import LifespanContextT, RequestT
from mcp.types import ToolAnnotations
@@ -347,7 +347,7 @@ class TestContextHandling:
assert tool.context_kwarg is None
def tool_with_parametrized_context(
x: int, ctx: Context[ServerSessionT, LifespanContextT]
x: int, ctx: Context[ServerSessionT, LifespanContextT, RequestT]
) -> str:
return str(x)

View File

@@ -1,3 +1,4 @@
import json
import multiprocessing
import socket
import time
@@ -318,3 +319,187 @@ async def test_sse_client_basic_connection_mounted_app(
# Test ping
ping_result = await session.send_ping()
assert isinstance(ping_result, EmptyResult)
# Test server with request context that returns headers in the response
class RequestContextServer(Server[object, Request]):
def __init__(self):
super().__init__("request_context_server")
@self.call_tool()
async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
headers_info = {}
context = self.request_context
if context.request:
headers_info = dict(context.request.headers)
if name == "echo_headers":
return [TextContent(type="text", text=json.dumps(headers_info))]
elif name == "echo_context":
context_data = {
"request_id": args.get("request_id"),
"headers": headers_info,
}
return [TextContent(type="text", text=json.dumps(context_data))]
return [TextContent(type="text", text=f"Called {name}")]
@self.list_tools()
async def handle_list_tools() -> list[Tool]:
return [
Tool(
name="echo_headers",
description="Echoes request headers",
inputSchema={"type": "object", "properties": {}},
),
Tool(
name="echo_context",
description="Echoes request context",
inputSchema={
"type": "object",
"properties": {"request_id": {"type": "string"}},
"required": ["request_id"],
},
),
]
def run_context_server(server_port: int) -> None:
"""Run a server that captures request context"""
sse = SseServerTransport("/messages/")
context_server = RequestContextServer()
async def handle_sse(request: Request) -> Response:
async with sse.connect_sse(
request.scope, request.receive, request._send
) as streams:
await context_server.run(
streams[0], streams[1], context_server.create_initialization_options()
)
return Response()
app = Starlette(
routes=[
Route("/sse", endpoint=handle_sse),
Mount("/messages/", app=sse.handle_post_message),
]
)
server = uvicorn.Server(
config=uvicorn.Config(
app=app, host="127.0.0.1", port=server_port, log_level="error"
)
)
print(f"starting context server on {server_port}")
server.run()
@pytest.fixture()
def context_server(server_port: int) -> Generator[None, None, None]:
"""Fixture that provides a server with request context capture"""
proc = multiprocessing.Process(
target=run_context_server, kwargs={"server_port": server_port}, daemon=True
)
print("starting context server process")
proc.start()
# Wait for server to be running
max_attempts = 20
attempt = 0
print("waiting for context server to start")
while attempt < max_attempts:
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.connect(("127.0.0.1", server_port))
break
except ConnectionRefusedError:
time.sleep(0.1)
attempt += 1
else:
raise RuntimeError(
f"Context server failed to start after {max_attempts} attempts"
)
yield
print("killing context server")
proc.kill()
proc.join(timeout=2)
if proc.is_alive():
print("context server process failed to terminate")
@pytest.mark.anyio
async def test_request_context_propagation(
context_server: None, server_url: str
) -> None:
"""Test that request context is properly propagated through SSE transport."""
# Test with custom headers
custom_headers = {
"Authorization": "Bearer test-token",
"X-Custom-Header": "test-value",
"X-Trace-Id": "trace-123",
}
async with sse_client(server_url + "/sse", headers=custom_headers) as (
read_stream,
write_stream,
):
async with ClientSession(read_stream, write_stream) as session:
# Initialize the session
result = await session.initialize()
assert isinstance(result, InitializeResult)
# Call the tool that echoes headers back
tool_result = await session.call_tool("echo_headers", {})
# Parse the JSON response
assert len(tool_result.content) == 1
headers_data = json.loads(
tool_result.content[0].text
if tool_result.content[0].type == "text"
else "{}"
)
# Verify headers were propagated
assert headers_data.get("authorization") == "Bearer test-token"
assert headers_data.get("x-custom-header") == "test-value"
assert headers_data.get("x-trace-id") == "trace-123"
@pytest.mark.anyio
async def test_request_context_isolation(context_server: None, server_url: str) -> None:
"""Test that request contexts are isolated between different SSE clients."""
contexts = []
# Create multiple clients with different headers
for i in range(3):
headers = {"X-Request-Id": f"request-{i}", "X-Custom-Value": f"value-{i}"}
async with sse_client(server_url + "/sse", headers=headers) as (
read_stream,
write_stream,
):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
# Call the tool that echoes context
tool_result = await session.call_tool(
"echo_context", {"request_id": f"request-{i}"}
)
assert len(tool_result.content) == 1
context_data = json.loads(
tool_result.content[0].text
if tool_result.content[0].type == "text"
else "{}"
)
contexts.append(context_data)
# Verify each request had its own context
assert len(contexts) == 3
for i, ctx in enumerate(contexts):
assert ctx["request_id"] == f"request-{i}"
assert ctx["headers"].get("x-request-id") == f"request-{i}"
assert ctx["headers"].get("x-custom-value") == f"value-{i}"