Support for http request injection propagation to tools (#816)

This commit is contained in:
ihrpr
2025-05-28 15:59:14 +01:00
committed by GitHub
parent 532b1176f9
commit 70014a2bbb
12 changed files with 413 additions and 35 deletions

View File

@@ -49,7 +49,7 @@ from mcp.server.sse import SseServerTransport
from mcp.server.stdio import stdio_server from mcp.server.stdio import stdio_server
from mcp.server.streamable_http import EventStore from mcp.server.streamable_http import EventStore
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from mcp.shared.context import LifespanContextT, RequestContext from mcp.shared.context import LifespanContextT, RequestContext, RequestT
from mcp.types import ( from mcp.types import (
AnyFunction, AnyFunction,
EmbeddedResource, EmbeddedResource,
@@ -124,9 +124,11 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
def lifespan_wrapper( def lifespan_wrapper(
app: FastMCP, app: FastMCP,
lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]], lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]],
) -> Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[object]]: ) -> Callable[
[MCPServer[LifespanResultT, Request]], AbstractAsyncContextManager[object]
]:
@asynccontextmanager @asynccontextmanager
async def wrap(s: MCPServer[LifespanResultT]) -> AsyncIterator[object]: async def wrap(s: MCPServer[LifespanResultT, Request]) -> AsyncIterator[object]:
async with lifespan(app) as context: async with lifespan(app) as context:
yield context yield context
@@ -260,7 +262,7 @@ class FastMCP:
for info in tools for info in tools
] ]
def get_context(self) -> Context[ServerSession, object]: def get_context(self) -> Context[ServerSession, object, Request]:
""" """
Returns a Context object. Note that the context will only be valid Returns a Context object. Note that the context will only be valid
during a request; outside a request, most methods will error. during a request; outside a request, most methods will error.
@@ -893,7 +895,7 @@ def _convert_to_content(
return [TextContent(type="text", text=result)] return [TextContent(type="text", text=result)]
class Context(BaseModel, Generic[ServerSessionT, LifespanContextT]): class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]):
"""Context object providing access to MCP capabilities. """Context object providing access to MCP capabilities.
This provides a cleaner interface to MCP's RequestContext functionality. This provides a cleaner interface to MCP's RequestContext functionality.
@@ -927,13 +929,15 @@ class Context(BaseModel, Generic[ServerSessionT, LifespanContextT]):
The context is optional - tools that don't need it can omit the parameter. The context is optional - tools that don't need it can omit the parameter.
""" """
_request_context: RequestContext[ServerSessionT, LifespanContextT] | None _request_context: RequestContext[ServerSessionT, LifespanContextT, RequestT] | None
_fastmcp: FastMCP | None _fastmcp: FastMCP | None
def __init__( def __init__(
self, self,
*, *,
request_context: RequestContext[ServerSessionT, LifespanContextT] | None = None, request_context: (
RequestContext[ServerSessionT, LifespanContextT, RequestT] | None
) = None,
fastmcp: FastMCP | None = None, fastmcp: FastMCP | None = None,
**kwargs: Any, **kwargs: Any,
): ):
@@ -949,7 +953,9 @@ class Context(BaseModel, Generic[ServerSessionT, LifespanContextT]):
return self._fastmcp return self._fastmcp
@property @property
def request_context(self) -> RequestContext[ServerSessionT, LifespanContextT]: def request_context(
self,
) -> RequestContext[ServerSessionT, LifespanContextT, RequestT]:
"""Access to the underlying request context.""" """Access to the underlying request context."""
if self._request_context is None: if self._request_context is None:
raise ValueError("Context is not available outside of a request") raise ValueError("Context is not available outside of a request")

View File

@@ -14,7 +14,7 @@ from mcp.types import ToolAnnotations
if TYPE_CHECKING: if TYPE_CHECKING:
from mcp.server.fastmcp.server import Context from mcp.server.fastmcp.server import Context
from mcp.server.session import ServerSessionT from mcp.server.session import ServerSessionT
from mcp.shared.context import LifespanContextT from mcp.shared.context import LifespanContextT, RequestT
class Tool(BaseModel): class Tool(BaseModel):
@@ -85,7 +85,7 @@ class Tool(BaseModel):
async def run( async def run(
self, self,
arguments: dict[str, Any], arguments: dict[str, Any],
context: Context[ServerSessionT, LifespanContextT] | None = None, context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
) -> Any: ) -> Any:
"""Run the tool with arguments.""" """Run the tool with arguments."""
try: try:

View File

@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any
from mcp.server.fastmcp.exceptions import ToolError from mcp.server.fastmcp.exceptions import ToolError
from mcp.server.fastmcp.tools.base import Tool from mcp.server.fastmcp.tools.base import Tool
from mcp.server.fastmcp.utilities.logging import get_logger from mcp.server.fastmcp.utilities.logging import get_logger
from mcp.shared.context import LifespanContextT from mcp.shared.context import LifespanContextT, RequestT
from mcp.types import ToolAnnotations from mcp.types import ToolAnnotations
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -65,7 +65,7 @@ class ToolManager:
self, self,
name: str, name: str,
arguments: dict[str, Any], arguments: dict[str, Any],
context: Context[ServerSessionT, LifespanContextT] | None = None, context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
) -> Any: ) -> Any:
"""Call a tool by name with arguments.""" """Call a tool by name with arguments."""
tool = self.get_tool(name) tool = self.get_tool(name)

View File

@@ -72,11 +72,12 @@ import logging
import warnings import warnings
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable from collections.abc import AsyncIterator, Awaitable, Callable, Iterable
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
from typing import Any, Generic, TypeVar from typing import Any, Generic
import anyio import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl from pydantic import AnyUrl
from typing_extensions import TypeVar
import mcp.types as types import mcp.types as types
from mcp.server.lowlevel.helper_types import ReadResourceContents from mcp.server.lowlevel.helper_types import ReadResourceContents
@@ -85,15 +86,16 @@ from mcp.server.session import ServerSession
from mcp.server.stdio import stdio_server as stdio_server from mcp.server.stdio import stdio_server as stdio_server
from mcp.shared.context import RequestContext from mcp.shared.context import RequestContext
from mcp.shared.exceptions import McpError from mcp.shared.exceptions import McpError
from mcp.shared.message import SessionMessage from mcp.shared.message import ServerMessageMetadata, SessionMessage
from mcp.shared.session import RequestResponder from mcp.shared.session import RequestResponder
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
LifespanResultT = TypeVar("LifespanResultT") LifespanResultT = TypeVar("LifespanResultT")
RequestT = TypeVar("RequestT", default=Any)
# This will be properly typed in each Server instance's context # This will be properly typed in each Server instance's context
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any]] = ( request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = (
contextvars.ContextVar("request_ctx") contextvars.ContextVar("request_ctx")
) )
@@ -111,7 +113,7 @@ class NotificationOptions:
@asynccontextmanager @asynccontextmanager
async def lifespan(server: Server[LifespanResultT]) -> AsyncIterator[object]: async def lifespan(server: Server[LifespanResultT, RequestT]) -> AsyncIterator[object]:
"""Default lifespan context manager that does nothing. """Default lifespan context manager that does nothing.
Args: Args:
@@ -123,14 +125,15 @@ async def lifespan(server: Server[LifespanResultT]) -> AsyncIterator[object]:
yield {} yield {}
class Server(Generic[LifespanResultT]): class Server(Generic[LifespanResultT, RequestT]):
def __init__( def __init__(
self, self,
name: str, name: str,
version: str | None = None, version: str | None = None,
instructions: str | None = None, instructions: str | None = None,
lifespan: Callable[ lifespan: Callable[
[Server[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT] [Server[LifespanResultT, RequestT]],
AbstractAsyncContextManager[LifespanResultT],
] = lifespan, ] = lifespan,
): ):
self.name = name self.name = name
@@ -215,7 +218,9 @@ class Server(Generic[LifespanResultT]):
) )
@property @property
def request_context(self) -> RequestContext[ServerSession, LifespanResultT]: def request_context(
self,
) -> RequestContext[ServerSession, LifespanResultT, RequestT]:
"""If called outside of a request context, this will raise a LookupError.""" """If called outside of a request context, this will raise a LookupError."""
return request_ctx.get() return request_ctx.get()
@@ -555,6 +560,13 @@ class Server(Generic[LifespanResultT]):
token = None token = None
try: try:
# Extract request context from message metadata
request_data = None
if message.message_metadata is not None and isinstance(
message.message_metadata, ServerMessageMetadata
):
request_data = message.message_metadata.request_context
# Set our global state that can be retrieved via # Set our global state that can be retrieved via
# app.get_request_context() # app.get_request_context()
token = request_ctx.set( token = request_ctx.set(
@@ -563,6 +575,7 @@ class Server(Generic[LifespanResultT]):
message.request_meta, message.request_meta,
session, session,
lifespan_context, lifespan_context,
request=request_data,
) )
) )
response = await handler(req) response = await handler(req)

View File

@@ -52,7 +52,7 @@ from starlette.responses import Response
from starlette.types import Receive, Scope, Send from starlette.types import Receive, Scope, Send
import mcp.types as types import mcp.types as types
from mcp.shared.message import SessionMessage from mcp.shared.message import ServerMessageMetadata, SessionMessage
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -203,7 +203,9 @@ class SseServerTransport:
await writer.send(err) await writer.send(err)
return return
session_message = SessionMessage(message) # Pass the ASGI scope for framework-agnostic access to request data
metadata = ServerMessageMetadata(request_context=request)
session_message = SessionMessage(message, metadata=metadata)
logger.debug(f"Sending session message to writer: {session_message}") logger.debug(f"Sending session message to writer: {session_message}")
response = Response("Accepted", status_code=202) response = Response("Accepted", status_code=202)
await response(scope, receive, send) await response(scope, receive, send)

View File

@@ -56,7 +56,7 @@ class StreamableHTTPSessionManager:
def __init__( def __init__(
self, self,
app: MCPServer[Any], app: MCPServer[Any, Any],
event_store: EventStore | None = None, event_store: EventStore | None = None,
json_response: bool = False, json_response: bool = False,
stateless: bool = False, stateless: bool = False,

View File

@@ -8,11 +8,13 @@ from mcp.types import RequestId, RequestParams
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any]) SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
LifespanContextT = TypeVar("LifespanContextT") LifespanContextT = TypeVar("LifespanContextT")
RequestT = TypeVar("RequestT", default=Any)
@dataclass @dataclass
class RequestContext(Generic[SessionT, LifespanContextT]): class RequestContext(Generic[SessionT, LifespanContextT, RequestT]):
request_id: RequestId request_id: RequestId
meta: RequestParams.Meta | None meta: RequestParams.Meta | None
session: SessionT session: SessionT
lifespan_context: LifespanContextT lifespan_context: LifespanContextT
request: RequestT | None = None

View File

@@ -30,6 +30,8 @@ class ServerMessageMetadata:
"""Metadata specific to server messages.""" """Metadata specific to server messages."""
related_request_id: RequestId | None = None related_request_id: RequestId | None = None
# Request-specific context (e.g., headers, auth info)
request_context: object | None = None
MessageMetadata = ClientMessageMetadata | ServerMessageMetadata | None MessageMetadata = ClientMessageMetadata | ServerMessageMetadata | None

View File

@@ -81,10 +81,12 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
ReceiveNotificationT ReceiveNotificationT
]""", ]""",
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any], on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
message_metadata: MessageMetadata = None,
) -> None: ) -> None:
self.request_id = request_id self.request_id = request_id
self.request_meta = request_meta self.request_meta = request_meta
self.request = request self.request = request
self.message_metadata = message_metadata
self._session = session self._session = session
self._completed = False self._completed = False
self._cancel_scope = anyio.CancelScope() self._cancel_scope = anyio.CancelScope()
@@ -365,6 +367,7 @@ class BaseSession(
request=validated_request, request=validated_request,
session=self, session=self,
on_complete=lambda r: self._in_flight.pop(r.request_id, None), on_complete=lambda r: self._in_flight.pop(r.request_id, None),
message_metadata=message.metadata,
) )
self._in_flight[responder.request_id] = responder self._in_flight[responder.request_id] = responder

View File

@@ -5,14 +5,18 @@ These tests validate the proper functioning of FastMCP in various configurations
including with and without authentication. including with and without authentication.
""" """
import json
import multiprocessing import multiprocessing
import socket import socket
import time import time
from collections.abc import Generator from collections.abc import Generator
from typing import Any
import pytest import pytest
import uvicorn import uvicorn
from pydantic import AnyUrl from pydantic import AnyUrl
from starlette.applications import Starlette
from starlette.requests import Request
import mcp.types as types import mcp.types as types
from mcp.client.session import ClientSession from mcp.client.session import ClientSession
@@ -20,6 +24,7 @@ from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client from mcp.client.streamable_http import streamablehttp_client
from mcp.server.fastmcp import FastMCP from mcp.server.fastmcp import FastMCP
from mcp.server.fastmcp.resources import FunctionResource from mcp.server.fastmcp.resources import FunctionResource
from mcp.server.fastmcp.server import Context
from mcp.shared.context import RequestContext from mcp.shared.context import RequestContext
from mcp.types import ( from mcp.types import (
CreateMessageRequestParams, CreateMessageRequestParams,
@@ -78,8 +83,6 @@ def stateless_http_server_url(stateless_http_server_port: int) -> str:
# Create a function to make the FastMCP server app # Create a function to make the FastMCP server app
def make_fastmcp_app(): def make_fastmcp_app():
"""Create a FastMCP server without auth settings.""" """Create a FastMCP server without auth settings."""
from starlette.applications import Starlette
mcp = FastMCP(name="NoAuthServer") mcp = FastMCP(name="NoAuthServer")
# Add a simple tool # Add a simple tool
@@ -88,7 +91,7 @@ def make_fastmcp_app():
return f"Echo: {message}" return f"Echo: {message}"
# Create the SSE app # Create the SSE app
app: Starlette = mcp.sse_app() app = mcp.sse_app()
return mcp, app return mcp, app
@@ -198,17 +201,14 @@ def make_everything_fastmcp() -> FastMCP:
def make_everything_fastmcp_app(): def make_everything_fastmcp_app():
"""Create a comprehensive FastMCP server with SSE transport.""" """Create a comprehensive FastMCP server with SSE transport."""
from starlette.applications import Starlette
mcp = make_everything_fastmcp() mcp = make_everything_fastmcp()
# Create the SSE app # Create the SSE app
app: Starlette = mcp.sse_app() app = mcp.sse_app()
return mcp, app return mcp, app
def make_fastmcp_streamable_http_app(): def make_fastmcp_streamable_http_app():
"""Create a FastMCP server with StreamableHTTP transport.""" """Create a FastMCP server with StreamableHTTP transport."""
from starlette.applications import Starlette
mcp = FastMCP(name="NoAuthServer") mcp = FastMCP(name="NoAuthServer")
@@ -225,8 +225,6 @@ def make_fastmcp_streamable_http_app():
def make_everything_fastmcp_streamable_http_app(): def make_everything_fastmcp_streamable_http_app():
"""Create a comprehensive FastMCP server with StreamableHTTP transport.""" """Create a comprehensive FastMCP server with StreamableHTTP transport."""
from starlette.applications import Starlette
# Create a new instance with different name for HTTP transport # Create a new instance with different name for HTTP transport
mcp = make_everything_fastmcp() mcp = make_everything_fastmcp()
# We can't change the name after creation, so we'll use the same name # We can't change the name after creation, so we'll use the same name
@@ -237,7 +235,6 @@ def make_everything_fastmcp_streamable_http_app():
def make_fastmcp_stateless_http_app(): def make_fastmcp_stateless_http_app():
"""Create a FastMCP server with stateless StreamableHTTP transport.""" """Create a FastMCP server with stateless StreamableHTTP transport."""
from starlette.applications import Starlette
mcp = FastMCP(name="StatelessServer", stateless_http=True) mcp = FastMCP(name="StatelessServer", stateless_http=True)
@@ -435,6 +432,174 @@ async def test_fastmcp_without_auth(server: None, server_url: str) -> None:
assert tool_result.content[0].text == "Echo: hello" assert tool_result.content[0].text == "Echo: hello"
def make_fastmcp_with_context_app():
"""Create a FastMCP server that can access request context."""
mcp = FastMCP(name="ContextServer")
# Tool that echoes request headers
@mcp.tool(description="Echo request headers from context")
def echo_headers(ctx: Context[Any, Any, Request]) -> str:
"""Returns the request headers as JSON."""
headers_info = {}
if ctx.request_context.request:
# Now the type system knows request is a Starlette Request object
headers_info = dict(ctx.request_context.request.headers)
return json.dumps(headers_info)
# Tool that returns full request context
@mcp.tool(description="Echo request context with custom data")
def echo_context(custom_request_id: str, ctx: Context[Any, Any, Request]) -> str:
"""Returns request context including headers and custom data."""
context_data = {
"custom_request_id": custom_request_id,
"headers": {},
"method": None,
"path": None,
}
if ctx.request_context.request:
request = ctx.request_context.request
context_data["headers"] = dict(request.headers)
context_data["method"] = request.method
context_data["path"] = request.url.path
return json.dumps(context_data)
# Create the SSE app
app = mcp.sse_app()
return mcp, app
def run_context_server(server_port: int) -> None:
"""Run the context-aware FastMCP server."""
_, app = make_fastmcp_with_context_app()
server = uvicorn.Server(
config=uvicorn.Config(
app=app, host="127.0.0.1", port=server_port, log_level="error"
)
)
print(f"Starting context server on port {server_port}")
server.run()
@pytest.fixture()
def context_aware_server(server_port: int) -> Generator[None, None, None]:
"""Start the context-aware server in a separate process."""
proc = multiprocessing.Process(
target=run_context_server, args=(server_port,), daemon=True
)
print("Starting context-aware server process")
proc.start()
# Wait for server to be running
max_attempts = 20
attempt = 0
print("Waiting for context-aware server to start")
while attempt < max_attempts:
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.connect(("127.0.0.1", server_port))
break
except ConnectionRefusedError:
time.sleep(0.1)
attempt += 1
else:
raise RuntimeError(
f"Context server failed to start after {max_attempts} attempts"
)
yield
print("Killing context-aware server")
proc.kill()
proc.join(timeout=2)
if proc.is_alive():
print("Context server process failed to terminate")
@pytest.mark.anyio
async def test_fast_mcp_with_request_context(
context_aware_server: None, server_url: str
) -> None:
"""Test that FastMCP properly propagates request context to tools."""
# Test with custom headers
custom_headers = {
"Authorization": "Bearer fastmcp-test-token",
"X-Custom-Header": "fastmcp-value",
"X-Request-Id": "req-123",
}
async with sse_client(server_url + "/sse", headers=custom_headers) as streams:
async with ClientSession(*streams) as session:
# Initialize the session
result = await session.initialize()
assert isinstance(result, InitializeResult)
assert result.serverInfo.name == "ContextServer"
# Test 1: Call tool that echoes headers
headers_result = await session.call_tool("echo_headers", {})
assert len(headers_result.content) == 1
assert isinstance(headers_result.content[0], TextContent)
headers_data = json.loads(headers_result.content[0].text)
assert headers_data.get("authorization") == "Bearer fastmcp-test-token"
assert headers_data.get("x-custom-header") == "fastmcp-value"
assert headers_data.get("x-request-id") == "req-123"
# Test 2: Call tool that returns full context
context_result = await session.call_tool(
"echo_context", {"custom_request_id": "test-123"}
)
assert len(context_result.content) == 1
assert isinstance(context_result.content[0], TextContent)
context_data = json.loads(context_result.content[0].text)
assert context_data["custom_request_id"] == "test-123"
assert (
context_data["headers"].get("authorization")
== "Bearer fastmcp-test-token"
)
assert context_data["method"] == "POST" #
@pytest.mark.anyio
async def test_fast_mcp_request_context_isolation(
context_aware_server: None, server_url: str
) -> None:
"""Test that request contexts are isolated between different FastMCP clients."""
contexts = []
# Create multiple clients with different headers
for i in range(3):
headers = {
"Authorization": f"Bearer token-{i}",
"X-Request-Id": f"fastmcp-req-{i}",
"X-Custom-Value": f"value-{i}",
}
async with sse_client(server_url + "/sse", headers=headers) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
# Call the tool that returns context
tool_result = await session.call_tool(
"echo_context", {"custom_request_id": f"test-req-{i}"}
)
# Parse and store the result
assert len(tool_result.content) == 1
assert isinstance(tool_result.content[0], TextContent)
context_data = json.loads(tool_result.content[0].text)
contexts.append(context_data)
# Verify each request had its own isolated context
assert len(contexts) == 3
for i, ctx in enumerate(contexts):
assert ctx["custom_request_id"] == f"test-req-{i}"
assert ctx["headers"].get("authorization") == f"Bearer token-{i}"
assert ctx["headers"].get("x-request-id") == f"fastmcp-req-{i}"
assert ctx["headers"].get("x-custom-value") == f"value-{i}"
@pytest.mark.anyio @pytest.mark.anyio
async def test_fastmcp_streamable_http( async def test_fastmcp_streamable_http(
streamable_http_server: None, http_server_url: str streamable_http_server: None, http_server_url: str

View File

@@ -9,7 +9,7 @@ from mcp.server.fastmcp.exceptions import ToolError
from mcp.server.fastmcp.tools import Tool, ToolManager from mcp.server.fastmcp.tools import Tool, ToolManager
from mcp.server.fastmcp.utilities.func_metadata import ArgModelBase, FuncMetadata from mcp.server.fastmcp.utilities.func_metadata import ArgModelBase, FuncMetadata
from mcp.server.session import ServerSessionT from mcp.server.session import ServerSessionT
from mcp.shared.context import LifespanContextT from mcp.shared.context import LifespanContextT, RequestT
from mcp.types import ToolAnnotations from mcp.types import ToolAnnotations
@@ -347,7 +347,7 @@ class TestContextHandling:
assert tool.context_kwarg is None assert tool.context_kwarg is None
def tool_with_parametrized_context( def tool_with_parametrized_context(
x: int, ctx: Context[ServerSessionT, LifespanContextT] x: int, ctx: Context[ServerSessionT, LifespanContextT, RequestT]
) -> str: ) -> str:
return str(x) return str(x)

View File

@@ -1,3 +1,4 @@
import json
import multiprocessing import multiprocessing
import socket import socket
import time import time
@@ -318,3 +319,187 @@ async def test_sse_client_basic_connection_mounted_app(
# Test ping # Test ping
ping_result = await session.send_ping() ping_result = await session.send_ping()
assert isinstance(ping_result, EmptyResult) assert isinstance(ping_result, EmptyResult)
# Test server with request context that returns headers in the response
class RequestContextServer(Server[object, Request]):
def __init__(self):
super().__init__("request_context_server")
@self.call_tool()
async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
headers_info = {}
context = self.request_context
if context.request:
headers_info = dict(context.request.headers)
if name == "echo_headers":
return [TextContent(type="text", text=json.dumps(headers_info))]
elif name == "echo_context":
context_data = {
"request_id": args.get("request_id"),
"headers": headers_info,
}
return [TextContent(type="text", text=json.dumps(context_data))]
return [TextContent(type="text", text=f"Called {name}")]
@self.list_tools()
async def handle_list_tools() -> list[Tool]:
return [
Tool(
name="echo_headers",
description="Echoes request headers",
inputSchema={"type": "object", "properties": {}},
),
Tool(
name="echo_context",
description="Echoes request context",
inputSchema={
"type": "object",
"properties": {"request_id": {"type": "string"}},
"required": ["request_id"],
},
),
]
def run_context_server(server_port: int) -> None:
"""Run a server that captures request context"""
sse = SseServerTransport("/messages/")
context_server = RequestContextServer()
async def handle_sse(request: Request) -> Response:
async with sse.connect_sse(
request.scope, request.receive, request._send
) as streams:
await context_server.run(
streams[0], streams[1], context_server.create_initialization_options()
)
return Response()
app = Starlette(
routes=[
Route("/sse", endpoint=handle_sse),
Mount("/messages/", app=sse.handle_post_message),
]
)
server = uvicorn.Server(
config=uvicorn.Config(
app=app, host="127.0.0.1", port=server_port, log_level="error"
)
)
print(f"starting context server on {server_port}")
server.run()
@pytest.fixture()
def context_server(server_port: int) -> Generator[None, None, None]:
"""Fixture that provides a server with request context capture"""
proc = multiprocessing.Process(
target=run_context_server, kwargs={"server_port": server_port}, daemon=True
)
print("starting context server process")
proc.start()
# Wait for server to be running
max_attempts = 20
attempt = 0
print("waiting for context server to start")
while attempt < max_attempts:
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.connect(("127.0.0.1", server_port))
break
except ConnectionRefusedError:
time.sleep(0.1)
attempt += 1
else:
raise RuntimeError(
f"Context server failed to start after {max_attempts} attempts"
)
yield
print("killing context server")
proc.kill()
proc.join(timeout=2)
if proc.is_alive():
print("context server process failed to terminate")
@pytest.mark.anyio
async def test_request_context_propagation(
context_server: None, server_url: str
) -> None:
"""Test that request context is properly propagated through SSE transport."""
# Test with custom headers
custom_headers = {
"Authorization": "Bearer test-token",
"X-Custom-Header": "test-value",
"X-Trace-Id": "trace-123",
}
async with sse_client(server_url + "/sse", headers=custom_headers) as (
read_stream,
write_stream,
):
async with ClientSession(read_stream, write_stream) as session:
# Initialize the session
result = await session.initialize()
assert isinstance(result, InitializeResult)
# Call the tool that echoes headers back
tool_result = await session.call_tool("echo_headers", {})
# Parse the JSON response
assert len(tool_result.content) == 1
headers_data = json.loads(
tool_result.content[0].text
if tool_result.content[0].type == "text"
else "{}"
)
# Verify headers were propagated
assert headers_data.get("authorization") == "Bearer test-token"
assert headers_data.get("x-custom-header") == "test-value"
assert headers_data.get("x-trace-id") == "trace-123"
@pytest.mark.anyio
async def test_request_context_isolation(context_server: None, server_url: str) -> None:
"""Test that request contexts are isolated between different SSE clients."""
contexts = []
# Create multiple clients with different headers
for i in range(3):
headers = {"X-Request-Id": f"request-{i}", "X-Custom-Value": f"value-{i}"}
async with sse_client(server_url + "/sse", headers=headers) as (
read_stream,
write_stream,
):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
# Call the tool that echoes context
tool_result = await session.call_tool(
"echo_context", {"request_id": f"request-{i}"}
)
assert len(tool_result.content) == 1
context_data = json.loads(
tool_result.content[0].text
if tool_result.content[0].type == "text"
else "{}"
)
contexts.append(context_data)
# Verify each request had its own context
assert len(contexts) == 3
for i, ctx in enumerate(contexts):
assert ctx["request_id"] == f"request-{i}"
assert ctx["headers"].get("x-request-id") == f"request-{i}"
assert ctx["headers"].get("x-custom-value") == f"value-{i}"