mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-20 07:14:24 +01:00
Support for http request injection propagation to tools (#816)
This commit is contained in:
@@ -49,7 +49,7 @@ from mcp.server.sse import SseServerTransport
|
||||
from mcp.server.stdio import stdio_server
|
||||
from mcp.server.streamable_http import EventStore
|
||||
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
|
||||
from mcp.shared.context import LifespanContextT, RequestContext
|
||||
from mcp.shared.context import LifespanContextT, RequestContext, RequestT
|
||||
from mcp.types import (
|
||||
AnyFunction,
|
||||
EmbeddedResource,
|
||||
@@ -124,9 +124,11 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
|
||||
def lifespan_wrapper(
|
||||
app: FastMCP,
|
||||
lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]],
|
||||
) -> Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[object]]:
|
||||
) -> Callable[
|
||||
[MCPServer[LifespanResultT, Request]], AbstractAsyncContextManager[object]
|
||||
]:
|
||||
@asynccontextmanager
|
||||
async def wrap(s: MCPServer[LifespanResultT]) -> AsyncIterator[object]:
|
||||
async def wrap(s: MCPServer[LifespanResultT, Request]) -> AsyncIterator[object]:
|
||||
async with lifespan(app) as context:
|
||||
yield context
|
||||
|
||||
@@ -260,7 +262,7 @@ class FastMCP:
|
||||
for info in tools
|
||||
]
|
||||
|
||||
def get_context(self) -> Context[ServerSession, object]:
|
||||
def get_context(self) -> Context[ServerSession, object, Request]:
|
||||
"""
|
||||
Returns a Context object. Note that the context will only be valid
|
||||
during a request; outside a request, most methods will error.
|
||||
@@ -893,7 +895,7 @@ def _convert_to_content(
|
||||
return [TextContent(type="text", text=result)]
|
||||
|
||||
|
||||
class Context(BaseModel, Generic[ServerSessionT, LifespanContextT]):
|
||||
class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]):
|
||||
"""Context object providing access to MCP capabilities.
|
||||
|
||||
This provides a cleaner interface to MCP's RequestContext functionality.
|
||||
@@ -927,13 +929,15 @@ class Context(BaseModel, Generic[ServerSessionT, LifespanContextT]):
|
||||
The context is optional - tools that don't need it can omit the parameter.
|
||||
"""
|
||||
|
||||
_request_context: RequestContext[ServerSessionT, LifespanContextT] | None
|
||||
_request_context: RequestContext[ServerSessionT, LifespanContextT, RequestT] | None
|
||||
_fastmcp: FastMCP | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
request_context: RequestContext[ServerSessionT, LifespanContextT] | None = None,
|
||||
request_context: (
|
||||
RequestContext[ServerSessionT, LifespanContextT, RequestT] | None
|
||||
) = None,
|
||||
fastmcp: FastMCP | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
@@ -949,7 +953,9 @@ class Context(BaseModel, Generic[ServerSessionT, LifespanContextT]):
|
||||
return self._fastmcp
|
||||
|
||||
@property
|
||||
def request_context(self) -> RequestContext[ServerSessionT, LifespanContextT]:
|
||||
def request_context(
|
||||
self,
|
||||
) -> RequestContext[ServerSessionT, LifespanContextT, RequestT]:
|
||||
"""Access to the underlying request context."""
|
||||
if self._request_context is None:
|
||||
raise ValueError("Context is not available outside of a request")
|
||||
|
||||
@@ -14,7 +14,7 @@ from mcp.types import ToolAnnotations
|
||||
if TYPE_CHECKING:
|
||||
from mcp.server.fastmcp.server import Context
|
||||
from mcp.server.session import ServerSessionT
|
||||
from mcp.shared.context import LifespanContextT
|
||||
from mcp.shared.context import LifespanContextT, RequestT
|
||||
|
||||
|
||||
class Tool(BaseModel):
|
||||
@@ -85,7 +85,7 @@ class Tool(BaseModel):
|
||||
async def run(
|
||||
self,
|
||||
arguments: dict[str, Any],
|
||||
context: Context[ServerSessionT, LifespanContextT] | None = None,
|
||||
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
|
||||
) -> Any:
|
||||
"""Run the tool with arguments."""
|
||||
try:
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any
|
||||
from mcp.server.fastmcp.exceptions import ToolError
|
||||
from mcp.server.fastmcp.tools.base import Tool
|
||||
from mcp.server.fastmcp.utilities.logging import get_logger
|
||||
from mcp.shared.context import LifespanContextT
|
||||
from mcp.shared.context import LifespanContextT, RequestT
|
||||
from mcp.types import ToolAnnotations
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -65,7 +65,7 @@ class ToolManager:
|
||||
self,
|
||||
name: str,
|
||||
arguments: dict[str, Any],
|
||||
context: Context[ServerSessionT, LifespanContextT] | None = None,
|
||||
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
|
||||
) -> Any:
|
||||
"""Call a tool by name with arguments."""
|
||||
tool = self.get_tool(name)
|
||||
|
||||
@@ -72,11 +72,12 @@ import logging
|
||||
import warnings
|
||||
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable
|
||||
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
|
||||
from typing import Any, Generic, TypeVar
|
||||
from typing import Any, Generic
|
||||
|
||||
import anyio
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import AnyUrl
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.server.lowlevel.helper_types import ReadResourceContents
|
||||
@@ -85,15 +86,16 @@ from mcp.server.session import ServerSession
|
||||
from mcp.server.stdio import stdio_server as stdio_server
|
||||
from mcp.shared.context import RequestContext
|
||||
from mcp.shared.exceptions import McpError
|
||||
from mcp.shared.message import SessionMessage
|
||||
from mcp.shared.message import ServerMessageMetadata, SessionMessage
|
||||
from mcp.shared.session import RequestResponder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
LifespanResultT = TypeVar("LifespanResultT")
|
||||
RequestT = TypeVar("RequestT", default=Any)
|
||||
|
||||
# This will be properly typed in each Server instance's context
|
||||
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any]] = (
|
||||
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = (
|
||||
contextvars.ContextVar("request_ctx")
|
||||
)
|
||||
|
||||
@@ -111,7 +113,7 @@ class NotificationOptions:
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(server: Server[LifespanResultT]) -> AsyncIterator[object]:
|
||||
async def lifespan(server: Server[LifespanResultT, RequestT]) -> AsyncIterator[object]:
|
||||
"""Default lifespan context manager that does nothing.
|
||||
|
||||
Args:
|
||||
@@ -123,14 +125,15 @@ async def lifespan(server: Server[LifespanResultT]) -> AsyncIterator[object]:
|
||||
yield {}
|
||||
|
||||
|
||||
class Server(Generic[LifespanResultT]):
|
||||
class Server(Generic[LifespanResultT, RequestT]):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
version: str | None = None,
|
||||
instructions: str | None = None,
|
||||
lifespan: Callable[
|
||||
[Server[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]
|
||||
[Server[LifespanResultT, RequestT]],
|
||||
AbstractAsyncContextManager[LifespanResultT],
|
||||
] = lifespan,
|
||||
):
|
||||
self.name = name
|
||||
@@ -215,7 +218,9 @@ class Server(Generic[LifespanResultT]):
|
||||
)
|
||||
|
||||
@property
|
||||
def request_context(self) -> RequestContext[ServerSession, LifespanResultT]:
|
||||
def request_context(
|
||||
self,
|
||||
) -> RequestContext[ServerSession, LifespanResultT, RequestT]:
|
||||
"""If called outside of a request context, this will raise a LookupError."""
|
||||
return request_ctx.get()
|
||||
|
||||
@@ -555,6 +560,13 @@ class Server(Generic[LifespanResultT]):
|
||||
|
||||
token = None
|
||||
try:
|
||||
# Extract request context from message metadata
|
||||
request_data = None
|
||||
if message.message_metadata is not None and isinstance(
|
||||
message.message_metadata, ServerMessageMetadata
|
||||
):
|
||||
request_data = message.message_metadata.request_context
|
||||
|
||||
# Set our global state that can be retrieved via
|
||||
# app.get_request_context()
|
||||
token = request_ctx.set(
|
||||
@@ -563,6 +575,7 @@ class Server(Generic[LifespanResultT]):
|
||||
message.request_meta,
|
||||
session,
|
||||
lifespan_context,
|
||||
request=request_data,
|
||||
)
|
||||
)
|
||||
response = await handler(req)
|
||||
|
||||
@@ -52,7 +52,7 @@ from starlette.responses import Response
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.shared.message import SessionMessage
|
||||
from mcp.shared.message import ServerMessageMetadata, SessionMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -203,7 +203,9 @@ class SseServerTransport:
|
||||
await writer.send(err)
|
||||
return
|
||||
|
||||
session_message = SessionMessage(message)
|
||||
# Pass the ASGI scope for framework-agnostic access to request data
|
||||
metadata = ServerMessageMetadata(request_context=request)
|
||||
session_message = SessionMessage(message, metadata=metadata)
|
||||
logger.debug(f"Sending session message to writer: {session_message}")
|
||||
response = Response("Accepted", status_code=202)
|
||||
await response(scope, receive, send)
|
||||
|
||||
@@ -56,7 +56,7 @@ class StreamableHTTPSessionManager:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: MCPServer[Any],
|
||||
app: MCPServer[Any, Any],
|
||||
event_store: EventStore | None = None,
|
||||
json_response: bool = False,
|
||||
stateless: bool = False,
|
||||
|
||||
@@ -8,11 +8,13 @@ from mcp.types import RequestId, RequestParams
|
||||
|
||||
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
|
||||
LifespanContextT = TypeVar("LifespanContextT")
|
||||
RequestT = TypeVar("RequestT", default=Any)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestContext(Generic[SessionT, LifespanContextT]):
|
||||
class RequestContext(Generic[SessionT, LifespanContextT, RequestT]):
|
||||
request_id: RequestId
|
||||
meta: RequestParams.Meta | None
|
||||
session: SessionT
|
||||
lifespan_context: LifespanContextT
|
||||
request: RequestT | None = None
|
||||
|
||||
@@ -30,6 +30,8 @@ class ServerMessageMetadata:
|
||||
"""Metadata specific to server messages."""
|
||||
|
||||
related_request_id: RequestId | None = None
|
||||
# Request-specific context (e.g., headers, auth info)
|
||||
request_context: object | None = None
|
||||
|
||||
|
||||
MessageMetadata = ClientMessageMetadata | ServerMessageMetadata | None
|
||||
|
||||
@@ -81,10 +81,12 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
||||
ReceiveNotificationT
|
||||
]""",
|
||||
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
|
||||
message_metadata: MessageMetadata = None,
|
||||
) -> None:
|
||||
self.request_id = request_id
|
||||
self.request_meta = request_meta
|
||||
self.request = request
|
||||
self.message_metadata = message_metadata
|
||||
self._session = session
|
||||
self._completed = False
|
||||
self._cancel_scope = anyio.CancelScope()
|
||||
@@ -365,6 +367,7 @@ class BaseSession(
|
||||
request=validated_request,
|
||||
session=self,
|
||||
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
|
||||
message_metadata=message.metadata,
|
||||
)
|
||||
|
||||
self._in_flight[responder.request_id] = responder
|
||||
|
||||
Reference in New Issue
Block a user