Support for http request injection propagation to tools (#816)

This commit is contained in:
ihrpr
2025-05-28 15:59:14 +01:00
committed by GitHub
parent 532b1176f9
commit 70014a2bbb
12 changed files with 413 additions and 35 deletions

View File

@@ -49,7 +49,7 @@ from mcp.server.sse import SseServerTransport
from mcp.server.stdio import stdio_server
from mcp.server.streamable_http import EventStore
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from mcp.shared.context import LifespanContextT, RequestContext
from mcp.shared.context import LifespanContextT, RequestContext, RequestT
from mcp.types import (
AnyFunction,
EmbeddedResource,
@@ -124,9 +124,11 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
def lifespan_wrapper(
app: FastMCP,
lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]],
) -> Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[object]]:
) -> Callable[
[MCPServer[LifespanResultT, Request]], AbstractAsyncContextManager[object]
]:
@asynccontextmanager
async def wrap(s: MCPServer[LifespanResultT]) -> AsyncIterator[object]:
async def wrap(s: MCPServer[LifespanResultT, Request]) -> AsyncIterator[object]:
async with lifespan(app) as context:
yield context
@@ -260,7 +262,7 @@ class FastMCP:
for info in tools
]
def get_context(self) -> Context[ServerSession, object]:
def get_context(self) -> Context[ServerSession, object, Request]:
"""
Returns a Context object. Note that the context will only be valid
during a request; outside a request, most methods will error.
@@ -893,7 +895,7 @@ def _convert_to_content(
return [TextContent(type="text", text=result)]
class Context(BaseModel, Generic[ServerSessionT, LifespanContextT]):
class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]):
"""Context object providing access to MCP capabilities.
This provides a cleaner interface to MCP's RequestContext functionality.
@@ -927,13 +929,15 @@ class Context(BaseModel, Generic[ServerSessionT, LifespanContextT]):
The context is optional - tools that don't need it can omit the parameter.
"""
_request_context: RequestContext[ServerSessionT, LifespanContextT] | None
_request_context: RequestContext[ServerSessionT, LifespanContextT, RequestT] | None
_fastmcp: FastMCP | None
def __init__(
self,
*,
request_context: RequestContext[ServerSessionT, LifespanContextT] | None = None,
request_context: (
RequestContext[ServerSessionT, LifespanContextT, RequestT] | None
) = None,
fastmcp: FastMCP | None = None,
**kwargs: Any,
):
@@ -949,7 +953,9 @@ class Context(BaseModel, Generic[ServerSessionT, LifespanContextT]):
return self._fastmcp
@property
def request_context(self) -> RequestContext[ServerSessionT, LifespanContextT]:
def request_context(
self,
) -> RequestContext[ServerSessionT, LifespanContextT, RequestT]:
"""Access to the underlying request context."""
if self._request_context is None:
raise ValueError("Context is not available outside of a request")

View File

@@ -14,7 +14,7 @@ from mcp.types import ToolAnnotations
if TYPE_CHECKING:
from mcp.server.fastmcp.server import Context
from mcp.server.session import ServerSessionT
from mcp.shared.context import LifespanContextT
from mcp.shared.context import LifespanContextT, RequestT
class Tool(BaseModel):
@@ -85,7 +85,7 @@ class Tool(BaseModel):
async def run(
self,
arguments: dict[str, Any],
context: Context[ServerSessionT, LifespanContextT] | None = None,
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
) -> Any:
"""Run the tool with arguments."""
try:

View File

@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any
from mcp.server.fastmcp.exceptions import ToolError
from mcp.server.fastmcp.tools.base import Tool
from mcp.server.fastmcp.utilities.logging import get_logger
from mcp.shared.context import LifespanContextT
from mcp.shared.context import LifespanContextT, RequestT
from mcp.types import ToolAnnotations
if TYPE_CHECKING:
@@ -65,7 +65,7 @@ class ToolManager:
self,
name: str,
arguments: dict[str, Any],
context: Context[ServerSessionT, LifespanContextT] | None = None,
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
) -> Any:
"""Call a tool by name with arguments."""
tool = self.get_tool(name)

View File

@@ -72,11 +72,12 @@ import logging
import warnings
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
from typing import Any, Generic, TypeVar
from typing import Any, Generic
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl
from typing_extensions import TypeVar
import mcp.types as types
from mcp.server.lowlevel.helper_types import ReadResourceContents
@@ -85,15 +86,16 @@ from mcp.server.session import ServerSession
from mcp.server.stdio import stdio_server as stdio_server
from mcp.shared.context import RequestContext
from mcp.shared.exceptions import McpError
from mcp.shared.message import SessionMessage
from mcp.shared.message import ServerMessageMetadata, SessionMessage
from mcp.shared.session import RequestResponder
logger = logging.getLogger(__name__)
LifespanResultT = TypeVar("LifespanResultT")
RequestT = TypeVar("RequestT", default=Any)
# This will be properly typed in each Server instance's context
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any]] = (
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = (
contextvars.ContextVar("request_ctx")
)
@@ -111,7 +113,7 @@ class NotificationOptions:
@asynccontextmanager
async def lifespan(server: Server[LifespanResultT]) -> AsyncIterator[object]:
async def lifespan(server: Server[LifespanResultT, RequestT]) -> AsyncIterator[object]:
"""Default lifespan context manager that does nothing.
Args:
@@ -123,14 +125,15 @@ async def lifespan(server: Server[LifespanResultT]) -> AsyncIterator[object]:
yield {}
class Server(Generic[LifespanResultT]):
class Server(Generic[LifespanResultT, RequestT]):
def __init__(
self,
name: str,
version: str | None = None,
instructions: str | None = None,
lifespan: Callable[
[Server[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]
[Server[LifespanResultT, RequestT]],
AbstractAsyncContextManager[LifespanResultT],
] = lifespan,
):
self.name = name
@@ -215,7 +218,9 @@ class Server(Generic[LifespanResultT]):
)
@property
def request_context(self) -> RequestContext[ServerSession, LifespanResultT]:
def request_context(
self,
) -> RequestContext[ServerSession, LifespanResultT, RequestT]:
"""If called outside of a request context, this will raise a LookupError."""
return request_ctx.get()
@@ -555,6 +560,13 @@ class Server(Generic[LifespanResultT]):
token = None
try:
# Extract request context from message metadata
request_data = None
if message.message_metadata is not None and isinstance(
message.message_metadata, ServerMessageMetadata
):
request_data = message.message_metadata.request_context
# Set our global state that can be retrieved via
# app.get_request_context()
token = request_ctx.set(
@@ -563,6 +575,7 @@ class Server(Generic[LifespanResultT]):
message.request_meta,
session,
lifespan_context,
request=request_data,
)
)
response = await handler(req)

View File

@@ -52,7 +52,7 @@ from starlette.responses import Response
from starlette.types import Receive, Scope, Send
import mcp.types as types
from mcp.shared.message import SessionMessage
from mcp.shared.message import ServerMessageMetadata, SessionMessage
logger = logging.getLogger(__name__)
@@ -203,7 +203,9 @@ class SseServerTransport:
await writer.send(err)
return
session_message = SessionMessage(message)
# Pass the ASGI scope for framework-agnostic access to request data
metadata = ServerMessageMetadata(request_context=request)
session_message = SessionMessage(message, metadata=metadata)
logger.debug(f"Sending session message to writer: {session_message}")
response = Response("Accepted", status_code=202)
await response(scope, receive, send)

View File

@@ -56,7 +56,7 @@ class StreamableHTTPSessionManager:
def __init__(
self,
app: MCPServer[Any],
app: MCPServer[Any, Any],
event_store: EventStore | None = None,
json_response: bool = False,
stateless: bool = False,

View File

@@ -8,11 +8,13 @@ from mcp.types import RequestId, RequestParams
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
LifespanContextT = TypeVar("LifespanContextT")
RequestT = TypeVar("RequestT", default=Any)
@dataclass
class RequestContext(Generic[SessionT, LifespanContextT]):
class RequestContext(Generic[SessionT, LifespanContextT, RequestT]):
request_id: RequestId
meta: RequestParams.Meta | None
session: SessionT
lifespan_context: LifespanContextT
request: RequestT | None = None

View File

@@ -30,6 +30,8 @@ class ServerMessageMetadata:
"""Metadata specific to server messages."""
related_request_id: RequestId | None = None
# Request-specific context (e.g., headers, auth info)
request_context: object | None = None
MessageMetadata = ClientMessageMetadata | ServerMessageMetadata | None

View File

@@ -81,10 +81,12 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
ReceiveNotificationT
]""",
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
message_metadata: MessageMetadata = None,
) -> None:
self.request_id = request_id
self.request_meta = request_meta
self.request = request
self.message_metadata = message_metadata
self._session = session
self._completed = False
self._cancel_scope = anyio.CancelScope()
@@ -365,6 +367,7 @@ class BaseSession(
request=validated_request,
session=self,
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
message_metadata=message.metadata,
)
self._in_flight[responder.request_id] = responder