Support for http request injection propagation to tools (#816)

This commit is contained in:
ihrpr
2025-05-28 15:59:14 +01:00
committed by GitHub
parent 532b1176f9
commit 70014a2bbb
12 changed files with 413 additions and 35 deletions

View File

@@ -49,7 +49,7 @@ from mcp.server.sse import SseServerTransport
from mcp.server.stdio import stdio_server
from mcp.server.streamable_http import EventStore
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from mcp.shared.context import LifespanContextT, RequestContext
from mcp.shared.context import LifespanContextT, RequestContext, RequestT
from mcp.types import (
AnyFunction,
EmbeddedResource,
@@ -124,9 +124,11 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
def lifespan_wrapper(
app: FastMCP,
lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]],
) -> Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[object]]:
) -> Callable[
[MCPServer[LifespanResultT, Request]], AbstractAsyncContextManager[object]
]:
@asynccontextmanager
async def wrap(s: MCPServer[LifespanResultT]) -> AsyncIterator[object]:
async def wrap(s: MCPServer[LifespanResultT, Request]) -> AsyncIterator[object]:
async with lifespan(app) as context:
yield context
@@ -260,7 +262,7 @@ class FastMCP:
for info in tools
]
def get_context(self) -> Context[ServerSession, object]:
def get_context(self) -> Context[ServerSession, object, Request]:
"""
Returns a Context object. Note that the context will only be valid
during a request; outside a request, most methods will error.
@@ -893,7 +895,7 @@ def _convert_to_content(
return [TextContent(type="text", text=result)]
class Context(BaseModel, Generic[ServerSessionT, LifespanContextT]):
class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]):
"""Context object providing access to MCP capabilities.
This provides a cleaner interface to MCP's RequestContext functionality.
@@ -927,13 +929,15 @@ class Context(BaseModel, Generic[ServerSessionT, LifespanContextT]):
The context is optional - tools that don't need it can omit the parameter.
"""
_request_context: RequestContext[ServerSessionT, LifespanContextT] | None
_request_context: RequestContext[ServerSessionT, LifespanContextT, RequestT] | None
_fastmcp: FastMCP | None
def __init__(
self,
*,
request_context: RequestContext[ServerSessionT, LifespanContextT] | None = None,
request_context: (
RequestContext[ServerSessionT, LifespanContextT, RequestT] | None
) = None,
fastmcp: FastMCP | None = None,
**kwargs: Any,
):
@@ -949,7 +953,9 @@ class Context(BaseModel, Generic[ServerSessionT, LifespanContextT]):
return self._fastmcp
@property
def request_context(self) -> RequestContext[ServerSessionT, LifespanContextT]:
def request_context(
self,
) -> RequestContext[ServerSessionT, LifespanContextT, RequestT]:
"""Access to the underlying request context."""
if self._request_context is None:
raise ValueError("Context is not available outside of a request")

View File

@@ -14,7 +14,7 @@ from mcp.types import ToolAnnotations
if TYPE_CHECKING:
from mcp.server.fastmcp.server import Context
from mcp.server.session import ServerSessionT
from mcp.shared.context import LifespanContextT
from mcp.shared.context import LifespanContextT, RequestT
class Tool(BaseModel):
@@ -85,7 +85,7 @@ class Tool(BaseModel):
async def run(
self,
arguments: dict[str, Any],
context: Context[ServerSessionT, LifespanContextT] | None = None,
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
) -> Any:
"""Run the tool with arguments."""
try:

View File

@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any
from mcp.server.fastmcp.exceptions import ToolError
from mcp.server.fastmcp.tools.base import Tool
from mcp.server.fastmcp.utilities.logging import get_logger
from mcp.shared.context import LifespanContextT
from mcp.shared.context import LifespanContextT, RequestT
from mcp.types import ToolAnnotations
if TYPE_CHECKING:
@@ -65,7 +65,7 @@ class ToolManager:
self,
name: str,
arguments: dict[str, Any],
context: Context[ServerSessionT, LifespanContextT] | None = None,
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
) -> Any:
"""Call a tool by name with arguments."""
tool = self.get_tool(name)

View File

@@ -72,11 +72,12 @@ import logging
import warnings
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
from typing import Any, Generic, TypeVar
from typing import Any, Generic
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl
from typing_extensions import TypeVar
import mcp.types as types
from mcp.server.lowlevel.helper_types import ReadResourceContents
@@ -85,15 +86,16 @@ from mcp.server.session import ServerSession
from mcp.server.stdio import stdio_server as stdio_server
from mcp.shared.context import RequestContext
from mcp.shared.exceptions import McpError
from mcp.shared.message import SessionMessage
from mcp.shared.message import ServerMessageMetadata, SessionMessage
from mcp.shared.session import RequestResponder
logger = logging.getLogger(__name__)
LifespanResultT = TypeVar("LifespanResultT")
RequestT = TypeVar("RequestT", default=Any)
# This will be properly typed in each Server instance's context
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any]] = (
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = (
contextvars.ContextVar("request_ctx")
)
@@ -111,7 +113,7 @@ class NotificationOptions:
@asynccontextmanager
async def lifespan(server: Server[LifespanResultT]) -> AsyncIterator[object]:
async def lifespan(server: Server[LifespanResultT, RequestT]) -> AsyncIterator[object]:
"""Default lifespan context manager that does nothing.
Args:
@@ -123,14 +125,15 @@ async def lifespan(server: Server[LifespanResultT]) -> AsyncIterator[object]:
yield {}
class Server(Generic[LifespanResultT]):
class Server(Generic[LifespanResultT, RequestT]):
def __init__(
self,
name: str,
version: str | None = None,
instructions: str | None = None,
lifespan: Callable[
[Server[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]
[Server[LifespanResultT, RequestT]],
AbstractAsyncContextManager[LifespanResultT],
] = lifespan,
):
self.name = name
@@ -215,7 +218,9 @@ class Server(Generic[LifespanResultT]):
)
@property
def request_context(self) -> RequestContext[ServerSession, LifespanResultT]:
def request_context(
self,
) -> RequestContext[ServerSession, LifespanResultT, RequestT]:
"""If called outside of a request context, this will raise a LookupError."""
return request_ctx.get()
@@ -555,6 +560,13 @@ class Server(Generic[LifespanResultT]):
token = None
try:
# Extract request context from message metadata
request_data = None
if message.message_metadata is not None and isinstance(
message.message_metadata, ServerMessageMetadata
):
request_data = message.message_metadata.request_context
# Set our global state that can be retrieved via
# app.get_request_context()
token = request_ctx.set(
@@ -563,6 +575,7 @@ class Server(Generic[LifespanResultT]):
message.request_meta,
session,
lifespan_context,
request=request_data,
)
)
response = await handler(req)

View File

@@ -52,7 +52,7 @@ from starlette.responses import Response
from starlette.types import Receive, Scope, Send
import mcp.types as types
from mcp.shared.message import SessionMessage
from mcp.shared.message import ServerMessageMetadata, SessionMessage
logger = logging.getLogger(__name__)
@@ -203,7 +203,9 @@ class SseServerTransport:
await writer.send(err)
return
session_message = SessionMessage(message)
# Pass the ASGI scope for framework-agnostic access to request data
metadata = ServerMessageMetadata(request_context=request)
session_message = SessionMessage(message, metadata=metadata)
logger.debug(f"Sending session message to writer: {session_message}")
response = Response("Accepted", status_code=202)
await response(scope, receive, send)

View File

@@ -56,7 +56,7 @@ class StreamableHTTPSessionManager:
def __init__(
self,
app: MCPServer[Any],
app: MCPServer[Any, Any],
event_store: EventStore | None = None,
json_response: bool = False,
stateless: bool = False,

View File

@@ -8,11 +8,13 @@ from mcp.types import RequestId, RequestParams
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
LifespanContextT = TypeVar("LifespanContextT")
RequestT = TypeVar("RequestT", default=Any)
@dataclass
class RequestContext(Generic[SessionT, LifespanContextT]):
class RequestContext(Generic[SessionT, LifespanContextT, RequestT]):
request_id: RequestId
meta: RequestParams.Meta | None
session: SessionT
lifespan_context: LifespanContextT
request: RequestT | None = None

View File

@@ -30,6 +30,8 @@ class ServerMessageMetadata:
"""Metadata specific to server messages."""
related_request_id: RequestId | None = None
# Request-specific context (e.g., headers, auth info)
request_context: object | None = None
MessageMetadata = ClientMessageMetadata | ServerMessageMetadata | None

View File

@@ -81,10 +81,12 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
ReceiveNotificationT
]""",
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
message_metadata: MessageMetadata = None,
) -> None:
self.request_id = request_id
self.request_meta = request_meta
self.request = request
self.message_metadata = message_metadata
self._session = session
self._completed = False
self._cancel_scope = anyio.CancelScope()
@@ -365,6 +367,7 @@ class BaseSession(
request=validated_request,
session=self,
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
message_metadata=message.metadata,
)
self._in_flight[responder.request_id] = responder

View File

@@ -5,14 +5,18 @@ These tests validate the proper functioning of FastMCP in various configurations
including with and without authentication.
"""
import json
import multiprocessing
import socket
import time
from collections.abc import Generator
from typing import Any
import pytest
import uvicorn
from pydantic import AnyUrl
from starlette.applications import Starlette
from starlette.requests import Request
import mcp.types as types
from mcp.client.session import ClientSession
@@ -20,6 +24,7 @@ from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client
from mcp.server.fastmcp import FastMCP
from mcp.server.fastmcp.resources import FunctionResource
from mcp.server.fastmcp.server import Context
from mcp.shared.context import RequestContext
from mcp.types import (
CreateMessageRequestParams,
@@ -78,8 +83,6 @@ def stateless_http_server_url(stateless_http_server_port: int) -> str:
# Create a function to make the FastMCP server app
def make_fastmcp_app():
"""Create a FastMCP server without auth settings."""
from starlette.applications import Starlette
mcp = FastMCP(name="NoAuthServer")
# Add a simple tool
@@ -88,7 +91,7 @@ def make_fastmcp_app():
return f"Echo: {message}"
# Create the SSE app
app: Starlette = mcp.sse_app()
app = mcp.sse_app()
return mcp, app
@@ -198,17 +201,14 @@ def make_everything_fastmcp() -> FastMCP:
def make_everything_fastmcp_app():
"""Create a comprehensive FastMCP server with SSE transport."""
from starlette.applications import Starlette
mcp = make_everything_fastmcp()
# Create the SSE app
app: Starlette = mcp.sse_app()
app = mcp.sse_app()
return mcp, app
def make_fastmcp_streamable_http_app():
"""Create a FastMCP server with StreamableHTTP transport."""
from starlette.applications import Starlette
mcp = FastMCP(name="NoAuthServer")
@@ -225,8 +225,6 @@ def make_fastmcp_streamable_http_app():
def make_everything_fastmcp_streamable_http_app():
"""Create a comprehensive FastMCP server with StreamableHTTP transport."""
from starlette.applications import Starlette
# Create a new instance with different name for HTTP transport
mcp = make_everything_fastmcp()
# We can't change the name after creation, so we'll use the same name
@@ -237,7 +235,6 @@ def make_everything_fastmcp_streamable_http_app():
def make_fastmcp_stateless_http_app():
"""Create a FastMCP server with stateless StreamableHTTP transport."""
from starlette.applications import Starlette
mcp = FastMCP(name="StatelessServer", stateless_http=True)
@@ -435,6 +432,174 @@ async def test_fastmcp_without_auth(server: None, server_url: str) -> None:
assert tool_result.content[0].text == "Echo: hello"
def make_fastmcp_with_context_app():
"""Create a FastMCP server that can access request context."""
mcp = FastMCP(name="ContextServer")
# Tool that echoes request headers
@mcp.tool(description="Echo request headers from context")
def echo_headers(ctx: Context[Any, Any, Request]) -> str:
"""Returns the request headers as JSON."""
headers_info = {}
if ctx.request_context.request:
# Now the type system knows request is a Starlette Request object
headers_info = dict(ctx.request_context.request.headers)
return json.dumps(headers_info)
# Tool that returns full request context
@mcp.tool(description="Echo request context with custom data")
def echo_context(custom_request_id: str, ctx: Context[Any, Any, Request]) -> str:
"""Returns request context including headers and custom data."""
context_data = {
"custom_request_id": custom_request_id,
"headers": {},
"method": None,
"path": None,
}
if ctx.request_context.request:
request = ctx.request_context.request
context_data["headers"] = dict(request.headers)
context_data["method"] = request.method
context_data["path"] = request.url.path
return json.dumps(context_data)
# Create the SSE app
app = mcp.sse_app()
return mcp, app
def run_context_server(server_port: int) -> None:
"""Run the context-aware FastMCP server."""
_, app = make_fastmcp_with_context_app()
server = uvicorn.Server(
config=uvicorn.Config(
app=app, host="127.0.0.1", port=server_port, log_level="error"
)
)
print(f"Starting context server on port {server_port}")
server.run()
@pytest.fixture()
def context_aware_server(server_port: int) -> Generator[None, None, None]:
"""Start the context-aware server in a separate process."""
proc = multiprocessing.Process(
target=run_context_server, args=(server_port,), daemon=True
)
print("Starting context-aware server process")
proc.start()
# Wait for server to be running
max_attempts = 20
attempt = 0
print("Waiting for context-aware server to start")
while attempt < max_attempts:
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.connect(("127.0.0.1", server_port))
break
except ConnectionRefusedError:
time.sleep(0.1)
attempt += 1
else:
raise RuntimeError(
f"Context server failed to start after {max_attempts} attempts"
)
yield
print("Killing context-aware server")
proc.kill()
proc.join(timeout=2)
if proc.is_alive():
print("Context server process failed to terminate")
@pytest.mark.anyio
async def test_fast_mcp_with_request_context(
context_aware_server: None, server_url: str
) -> None:
"""Test that FastMCP properly propagates request context to tools."""
# Test with custom headers
custom_headers = {
"Authorization": "Bearer fastmcp-test-token",
"X-Custom-Header": "fastmcp-value",
"X-Request-Id": "req-123",
}
async with sse_client(server_url + "/sse", headers=custom_headers) as streams:
async with ClientSession(*streams) as session:
# Initialize the session
result = await session.initialize()
assert isinstance(result, InitializeResult)
assert result.serverInfo.name == "ContextServer"
# Test 1: Call tool that echoes headers
headers_result = await session.call_tool("echo_headers", {})
assert len(headers_result.content) == 1
assert isinstance(headers_result.content[0], TextContent)
headers_data = json.loads(headers_result.content[0].text)
assert headers_data.get("authorization") == "Bearer fastmcp-test-token"
assert headers_data.get("x-custom-header") == "fastmcp-value"
assert headers_data.get("x-request-id") == "req-123"
# Test 2: Call tool that returns full context
context_result = await session.call_tool(
"echo_context", {"custom_request_id": "test-123"}
)
assert len(context_result.content) == 1
assert isinstance(context_result.content[0], TextContent)
context_data = json.loads(context_result.content[0].text)
assert context_data["custom_request_id"] == "test-123"
assert (
context_data["headers"].get("authorization")
== "Bearer fastmcp-test-token"
)
assert context_data["method"] == "POST" #
@pytest.mark.anyio
async def test_fast_mcp_request_context_isolation(
context_aware_server: None, server_url: str
) -> None:
"""Test that request contexts are isolated between different FastMCP clients."""
contexts = []
# Create multiple clients with different headers
for i in range(3):
headers = {
"Authorization": f"Bearer token-{i}",
"X-Request-Id": f"fastmcp-req-{i}",
"X-Custom-Value": f"value-{i}",
}
async with sse_client(server_url + "/sse", headers=headers) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
# Call the tool that returns context
tool_result = await session.call_tool(
"echo_context", {"custom_request_id": f"test-req-{i}"}
)
# Parse and store the result
assert len(tool_result.content) == 1
assert isinstance(tool_result.content[0], TextContent)
context_data = json.loads(tool_result.content[0].text)
contexts.append(context_data)
# Verify each request had its own isolated context
assert len(contexts) == 3
for i, ctx in enumerate(contexts):
assert ctx["custom_request_id"] == f"test-req-{i}"
assert ctx["headers"].get("authorization") == f"Bearer token-{i}"
assert ctx["headers"].get("x-request-id") == f"fastmcp-req-{i}"
assert ctx["headers"].get("x-custom-value") == f"value-{i}"
@pytest.mark.anyio
async def test_fastmcp_streamable_http(
streamable_http_server: None, http_server_url: str

View File

@@ -9,7 +9,7 @@ from mcp.server.fastmcp.exceptions import ToolError
from mcp.server.fastmcp.tools import Tool, ToolManager
from mcp.server.fastmcp.utilities.func_metadata import ArgModelBase, FuncMetadata
from mcp.server.session import ServerSessionT
from mcp.shared.context import LifespanContextT
from mcp.shared.context import LifespanContextT, RequestT
from mcp.types import ToolAnnotations
@@ -347,7 +347,7 @@ class TestContextHandling:
assert tool.context_kwarg is None
def tool_with_parametrized_context(
x: int, ctx: Context[ServerSessionT, LifespanContextT]
x: int, ctx: Context[ServerSessionT, LifespanContextT, RequestT]
) -> str:
return str(x)

View File

@@ -1,3 +1,4 @@
import json
import multiprocessing
import socket
import time
@@ -318,3 +319,187 @@ async def test_sse_client_basic_connection_mounted_app(
# Test ping
ping_result = await session.send_ping()
assert isinstance(ping_result, EmptyResult)
# Test server with request context that returns headers in the response
class RequestContextServer(Server[object, Request]):
def __init__(self):
super().__init__("request_context_server")
@self.call_tool()
async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
headers_info = {}
context = self.request_context
if context.request:
headers_info = dict(context.request.headers)
if name == "echo_headers":
return [TextContent(type="text", text=json.dumps(headers_info))]
elif name == "echo_context":
context_data = {
"request_id": args.get("request_id"),
"headers": headers_info,
}
return [TextContent(type="text", text=json.dumps(context_data))]
return [TextContent(type="text", text=f"Called {name}")]
@self.list_tools()
async def handle_list_tools() -> list[Tool]:
return [
Tool(
name="echo_headers",
description="Echoes request headers",
inputSchema={"type": "object", "properties": {}},
),
Tool(
name="echo_context",
description="Echoes request context",
inputSchema={
"type": "object",
"properties": {"request_id": {"type": "string"}},
"required": ["request_id"],
},
),
]
def run_context_server(server_port: int) -> None:
"""Run a server that captures request context"""
sse = SseServerTransport("/messages/")
context_server = RequestContextServer()
async def handle_sse(request: Request) -> Response:
async with sse.connect_sse(
request.scope, request.receive, request._send
) as streams:
await context_server.run(
streams[0], streams[1], context_server.create_initialization_options()
)
return Response()
app = Starlette(
routes=[
Route("/sse", endpoint=handle_sse),
Mount("/messages/", app=sse.handle_post_message),
]
)
server = uvicorn.Server(
config=uvicorn.Config(
app=app, host="127.0.0.1", port=server_port, log_level="error"
)
)
print(f"starting context server on {server_port}")
server.run()
@pytest.fixture()
def context_server(server_port: int) -> Generator[None, None, None]:
"""Fixture that provides a server with request context capture"""
proc = multiprocessing.Process(
target=run_context_server, kwargs={"server_port": server_port}, daemon=True
)
print("starting context server process")
proc.start()
# Wait for server to be running
max_attempts = 20
attempt = 0
print("waiting for context server to start")
while attempt < max_attempts:
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.connect(("127.0.0.1", server_port))
break
except ConnectionRefusedError:
time.sleep(0.1)
attempt += 1
else:
raise RuntimeError(
f"Context server failed to start after {max_attempts} attempts"
)
yield
print("killing context server")
proc.kill()
proc.join(timeout=2)
if proc.is_alive():
print("context server process failed to terminate")
@pytest.mark.anyio
async def test_request_context_propagation(
context_server: None, server_url: str
) -> None:
"""Test that request context is properly propagated through SSE transport."""
# Test with custom headers
custom_headers = {
"Authorization": "Bearer test-token",
"X-Custom-Header": "test-value",
"X-Trace-Id": "trace-123",
}
async with sse_client(server_url + "/sse", headers=custom_headers) as (
read_stream,
write_stream,
):
async with ClientSession(read_stream, write_stream) as session:
# Initialize the session
result = await session.initialize()
assert isinstance(result, InitializeResult)
# Call the tool that echoes headers back
tool_result = await session.call_tool("echo_headers", {})
# Parse the JSON response
assert len(tool_result.content) == 1
headers_data = json.loads(
tool_result.content[0].text
if tool_result.content[0].type == "text"
else "{}"
)
# Verify headers were propagated
assert headers_data.get("authorization") == "Bearer test-token"
assert headers_data.get("x-custom-header") == "test-value"
assert headers_data.get("x-trace-id") == "trace-123"
@pytest.mark.anyio
async def test_request_context_isolation(context_server: None, server_url: str) -> None:
"""Test that request contexts are isolated between different SSE clients."""
contexts = []
# Create multiple clients with different headers
for i in range(3):
headers = {"X-Request-Id": f"request-{i}", "X-Custom-Value": f"value-{i}"}
async with sse_client(server_url + "/sse", headers=headers) as (
read_stream,
write_stream,
):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
# Call the tool that echoes context
tool_result = await session.call_tool(
"echo_context", {"request_id": f"request-{i}"}
)
assert len(tool_result.content) == 1
context_data = json.loads(
tool_result.content[0].text
if tool_result.content[0].type == "text"
else "{}"
)
contexts.append(context_data)
# Verify each request had its own context
assert len(contexts) == 3
for i, ctx in enumerate(contexts):
assert ctx["request_id"] == f"request-{i}"
assert ctx["headers"].get("x-request-id") == f"request-{i}"
assert ctx["headers"].get("x-custom-value") == f"value-{i}"