mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 06:54:18 +01:00
Support for http request injection propagation to tools (#816)
This commit is contained in:
@@ -5,14 +5,18 @@ These tests validate the proper functioning of FastMCP in various configurations
|
||||
including with and without authentication.
|
||||
"""
|
||||
|
||||
import json
|
||||
import multiprocessing
|
||||
import socket
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import uvicorn
|
||||
from pydantic import AnyUrl
|
||||
from starlette.applications import Starlette
|
||||
from starlette.requests import Request
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.client.session import ClientSession
|
||||
@@ -20,6 +24,7 @@ from mcp.client.sse import sse_client
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
from mcp.server.fastmcp.resources import FunctionResource
|
||||
from mcp.server.fastmcp.server import Context
|
||||
from mcp.shared.context import RequestContext
|
||||
from mcp.types import (
|
||||
CreateMessageRequestParams,
|
||||
@@ -78,8 +83,6 @@ def stateless_http_server_url(stateless_http_server_port: int) -> str:
|
||||
# Create a function to make the FastMCP server app
|
||||
def make_fastmcp_app():
|
||||
"""Create a FastMCP server without auth settings."""
|
||||
from starlette.applications import Starlette
|
||||
|
||||
mcp = FastMCP(name="NoAuthServer")
|
||||
|
||||
# Add a simple tool
|
||||
@@ -88,7 +91,7 @@ def make_fastmcp_app():
|
||||
return f"Echo: {message}"
|
||||
|
||||
# Create the SSE app
|
||||
app: Starlette = mcp.sse_app()
|
||||
app = mcp.sse_app()
|
||||
|
||||
return mcp, app
|
||||
|
||||
@@ -198,17 +201,14 @@ def make_everything_fastmcp() -> FastMCP:
|
||||
|
||||
def make_everything_fastmcp_app():
|
||||
"""Create a comprehensive FastMCP server with SSE transport."""
|
||||
from starlette.applications import Starlette
|
||||
|
||||
mcp = make_everything_fastmcp()
|
||||
# Create the SSE app
|
||||
app: Starlette = mcp.sse_app()
|
||||
app = mcp.sse_app()
|
||||
return mcp, app
|
||||
|
||||
|
||||
def make_fastmcp_streamable_http_app():
|
||||
"""Create a FastMCP server with StreamableHTTP transport."""
|
||||
from starlette.applications import Starlette
|
||||
|
||||
mcp = FastMCP(name="NoAuthServer")
|
||||
|
||||
@@ -225,8 +225,6 @@ def make_fastmcp_streamable_http_app():
|
||||
|
||||
def make_everything_fastmcp_streamable_http_app():
|
||||
"""Create a comprehensive FastMCP server with StreamableHTTP transport."""
|
||||
from starlette.applications import Starlette
|
||||
|
||||
# Create a new instance with different name for HTTP transport
|
||||
mcp = make_everything_fastmcp()
|
||||
# We can't change the name after creation, so we'll use the same name
|
||||
@@ -237,7 +235,6 @@ def make_everything_fastmcp_streamable_http_app():
|
||||
|
||||
def make_fastmcp_stateless_http_app():
|
||||
"""Create a FastMCP server with stateless StreamableHTTP transport."""
|
||||
from starlette.applications import Starlette
|
||||
|
||||
mcp = FastMCP(name="StatelessServer", stateless_http=True)
|
||||
|
||||
@@ -435,6 +432,174 @@ async def test_fastmcp_without_auth(server: None, server_url: str) -> None:
|
||||
assert tool_result.content[0].text == "Echo: hello"
|
||||
|
||||
|
||||
def make_fastmcp_with_context_app():
|
||||
"""Create a FastMCP server that can access request context."""
|
||||
|
||||
mcp = FastMCP(name="ContextServer")
|
||||
|
||||
# Tool that echoes request headers
|
||||
@mcp.tool(description="Echo request headers from context")
|
||||
def echo_headers(ctx: Context[Any, Any, Request]) -> str:
|
||||
"""Returns the request headers as JSON."""
|
||||
headers_info = {}
|
||||
if ctx.request_context.request:
|
||||
# Now the type system knows request is a Starlette Request object
|
||||
headers_info = dict(ctx.request_context.request.headers)
|
||||
return json.dumps(headers_info)
|
||||
|
||||
# Tool that returns full request context
|
||||
@mcp.tool(description="Echo request context with custom data")
|
||||
def echo_context(custom_request_id: str, ctx: Context[Any, Any, Request]) -> str:
|
||||
"""Returns request context including headers and custom data."""
|
||||
context_data = {
|
||||
"custom_request_id": custom_request_id,
|
||||
"headers": {},
|
||||
"method": None,
|
||||
"path": None,
|
||||
}
|
||||
if ctx.request_context.request:
|
||||
request = ctx.request_context.request
|
||||
context_data["headers"] = dict(request.headers)
|
||||
context_data["method"] = request.method
|
||||
context_data["path"] = request.url.path
|
||||
return json.dumps(context_data)
|
||||
|
||||
# Create the SSE app
|
||||
app = mcp.sse_app()
|
||||
return mcp, app
|
||||
|
||||
|
||||
def run_context_server(server_port: int) -> None:
|
||||
"""Run the context-aware FastMCP server."""
|
||||
_, app = make_fastmcp_with_context_app()
|
||||
server = uvicorn.Server(
|
||||
config=uvicorn.Config(
|
||||
app=app, host="127.0.0.1", port=server_port, log_level="error"
|
||||
)
|
||||
)
|
||||
print(f"Starting context server on port {server_port}")
|
||||
server.run()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def context_aware_server(server_port: int) -> Generator[None, None, None]:
|
||||
"""Start the context-aware server in a separate process."""
|
||||
proc = multiprocessing.Process(
|
||||
target=run_context_server, args=(server_port,), daemon=True
|
||||
)
|
||||
print("Starting context-aware server process")
|
||||
proc.start()
|
||||
|
||||
# Wait for server to be running
|
||||
max_attempts = 20
|
||||
attempt = 0
|
||||
print("Waiting for context-aware server to start")
|
||||
while attempt < max_attempts:
|
||||
try:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.connect(("127.0.0.1", server_port))
|
||||
break
|
||||
except ConnectionRefusedError:
|
||||
time.sleep(0.1)
|
||||
attempt += 1
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Context server failed to start after {max_attempts} attempts"
|
||||
)
|
||||
|
||||
yield
|
||||
|
||||
print("Killing context-aware server")
|
||||
proc.kill()
|
||||
proc.join(timeout=2)
|
||||
if proc.is_alive():
|
||||
print("Context server process failed to terminate")
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_fast_mcp_with_request_context(
|
||||
context_aware_server: None, server_url: str
|
||||
) -> None:
|
||||
"""Test that FastMCP properly propagates request context to tools."""
|
||||
# Test with custom headers
|
||||
custom_headers = {
|
||||
"Authorization": "Bearer fastmcp-test-token",
|
||||
"X-Custom-Header": "fastmcp-value",
|
||||
"X-Request-Id": "req-123",
|
||||
}
|
||||
|
||||
async with sse_client(server_url + "/sse", headers=custom_headers) as streams:
|
||||
async with ClientSession(*streams) as session:
|
||||
# Initialize the session
|
||||
result = await session.initialize()
|
||||
assert isinstance(result, InitializeResult)
|
||||
assert result.serverInfo.name == "ContextServer"
|
||||
|
||||
# Test 1: Call tool that echoes headers
|
||||
headers_result = await session.call_tool("echo_headers", {})
|
||||
assert len(headers_result.content) == 1
|
||||
assert isinstance(headers_result.content[0], TextContent)
|
||||
|
||||
headers_data = json.loads(headers_result.content[0].text)
|
||||
assert headers_data.get("authorization") == "Bearer fastmcp-test-token"
|
||||
assert headers_data.get("x-custom-header") == "fastmcp-value"
|
||||
assert headers_data.get("x-request-id") == "req-123"
|
||||
|
||||
# Test 2: Call tool that returns full context
|
||||
context_result = await session.call_tool(
|
||||
"echo_context", {"custom_request_id": "test-123"}
|
||||
)
|
||||
assert len(context_result.content) == 1
|
||||
assert isinstance(context_result.content[0], TextContent)
|
||||
|
||||
context_data = json.loads(context_result.content[0].text)
|
||||
assert context_data["custom_request_id"] == "test-123"
|
||||
assert (
|
||||
context_data["headers"].get("authorization")
|
||||
== "Bearer fastmcp-test-token"
|
||||
)
|
||||
assert context_data["method"] == "POST" #
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_fast_mcp_request_context_isolation(
|
||||
context_aware_server: None, server_url: str
|
||||
) -> None:
|
||||
"""Test that request contexts are isolated between different FastMCP clients."""
|
||||
contexts = []
|
||||
|
||||
# Create multiple clients with different headers
|
||||
for i in range(3):
|
||||
headers = {
|
||||
"Authorization": f"Bearer token-{i}",
|
||||
"X-Request-Id": f"fastmcp-req-{i}",
|
||||
"X-Custom-Value": f"value-{i}",
|
||||
}
|
||||
|
||||
async with sse_client(server_url + "/sse", headers=headers) as streams:
|
||||
async with ClientSession(*streams) as session:
|
||||
await session.initialize()
|
||||
|
||||
# Call the tool that returns context
|
||||
tool_result = await session.call_tool(
|
||||
"echo_context", {"custom_request_id": f"test-req-{i}"}
|
||||
)
|
||||
|
||||
# Parse and store the result
|
||||
assert len(tool_result.content) == 1
|
||||
assert isinstance(tool_result.content[0], TextContent)
|
||||
context_data = json.loads(tool_result.content[0].text)
|
||||
contexts.append(context_data)
|
||||
|
||||
# Verify each request had its own isolated context
|
||||
assert len(contexts) == 3
|
||||
for i, ctx in enumerate(contexts):
|
||||
assert ctx["custom_request_id"] == f"test-req-{i}"
|
||||
assert ctx["headers"].get("authorization") == f"Bearer token-{i}"
|
||||
assert ctx["headers"].get("x-request-id") == f"fastmcp-req-{i}"
|
||||
assert ctx["headers"].get("x-custom-value") == f"value-{i}"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_fastmcp_streamable_http(
|
||||
streamable_http_server: None, http_server_url: str
|
||||
|
||||
Reference in New Issue
Block a user