mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
Support for http request injection propagation to tools (#816)
This commit is contained in:
@@ -49,7 +49,7 @@ from mcp.server.sse import SseServerTransport
|
|||||||
from mcp.server.stdio import stdio_server
|
from mcp.server.stdio import stdio_server
|
||||||
from mcp.server.streamable_http import EventStore
|
from mcp.server.streamable_http import EventStore
|
||||||
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
|
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
|
||||||
from mcp.shared.context import LifespanContextT, RequestContext
|
from mcp.shared.context import LifespanContextT, RequestContext, RequestT
|
||||||
from mcp.types import (
|
from mcp.types import (
|
||||||
AnyFunction,
|
AnyFunction,
|
||||||
EmbeddedResource,
|
EmbeddedResource,
|
||||||
@@ -124,9 +124,11 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
|
|||||||
def lifespan_wrapper(
|
def lifespan_wrapper(
|
||||||
app: FastMCP,
|
app: FastMCP,
|
||||||
lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]],
|
lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]],
|
||||||
) -> Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[object]]:
|
) -> Callable[
|
||||||
|
[MCPServer[LifespanResultT, Request]], AbstractAsyncContextManager[object]
|
||||||
|
]:
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def wrap(s: MCPServer[LifespanResultT]) -> AsyncIterator[object]:
|
async def wrap(s: MCPServer[LifespanResultT, Request]) -> AsyncIterator[object]:
|
||||||
async with lifespan(app) as context:
|
async with lifespan(app) as context:
|
||||||
yield context
|
yield context
|
||||||
|
|
||||||
@@ -260,7 +262,7 @@ class FastMCP:
|
|||||||
for info in tools
|
for info in tools
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_context(self) -> Context[ServerSession, object]:
|
def get_context(self) -> Context[ServerSession, object, Request]:
|
||||||
"""
|
"""
|
||||||
Returns a Context object. Note that the context will only be valid
|
Returns a Context object. Note that the context will only be valid
|
||||||
during a request; outside a request, most methods will error.
|
during a request; outside a request, most methods will error.
|
||||||
@@ -893,7 +895,7 @@ def _convert_to_content(
|
|||||||
return [TextContent(type="text", text=result)]
|
return [TextContent(type="text", text=result)]
|
||||||
|
|
||||||
|
|
||||||
class Context(BaseModel, Generic[ServerSessionT, LifespanContextT]):
|
class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]):
|
||||||
"""Context object providing access to MCP capabilities.
|
"""Context object providing access to MCP capabilities.
|
||||||
|
|
||||||
This provides a cleaner interface to MCP's RequestContext functionality.
|
This provides a cleaner interface to MCP's RequestContext functionality.
|
||||||
@@ -927,13 +929,15 @@ class Context(BaseModel, Generic[ServerSessionT, LifespanContextT]):
|
|||||||
The context is optional - tools that don't need it can omit the parameter.
|
The context is optional - tools that don't need it can omit the parameter.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_request_context: RequestContext[ServerSessionT, LifespanContextT] | None
|
_request_context: RequestContext[ServerSessionT, LifespanContextT, RequestT] | None
|
||||||
_fastmcp: FastMCP | None
|
_fastmcp: FastMCP | None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
request_context: RequestContext[ServerSessionT, LifespanContextT] | None = None,
|
request_context: (
|
||||||
|
RequestContext[ServerSessionT, LifespanContextT, RequestT] | None
|
||||||
|
) = None,
|
||||||
fastmcp: FastMCP | None = None,
|
fastmcp: FastMCP | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
@@ -949,7 +953,9 @@ class Context(BaseModel, Generic[ServerSessionT, LifespanContextT]):
|
|||||||
return self._fastmcp
|
return self._fastmcp
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def request_context(self) -> RequestContext[ServerSessionT, LifespanContextT]:
|
def request_context(
|
||||||
|
self,
|
||||||
|
) -> RequestContext[ServerSessionT, LifespanContextT, RequestT]:
|
||||||
"""Access to the underlying request context."""
|
"""Access to the underlying request context."""
|
||||||
if self._request_context is None:
|
if self._request_context is None:
|
||||||
raise ValueError("Context is not available outside of a request")
|
raise ValueError("Context is not available outside of a request")
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from mcp.types import ToolAnnotations
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from mcp.server.fastmcp.server import Context
|
from mcp.server.fastmcp.server import Context
|
||||||
from mcp.server.session import ServerSessionT
|
from mcp.server.session import ServerSessionT
|
||||||
from mcp.shared.context import LifespanContextT
|
from mcp.shared.context import LifespanContextT, RequestT
|
||||||
|
|
||||||
|
|
||||||
class Tool(BaseModel):
|
class Tool(BaseModel):
|
||||||
@@ -85,7 +85,7 @@ class Tool(BaseModel):
|
|||||||
async def run(
|
async def run(
|
||||||
self,
|
self,
|
||||||
arguments: dict[str, Any],
|
arguments: dict[str, Any],
|
||||||
context: Context[ServerSessionT, LifespanContextT] | None = None,
|
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Run the tool with arguments."""
|
"""Run the tool with arguments."""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any
|
|||||||
from mcp.server.fastmcp.exceptions import ToolError
|
from mcp.server.fastmcp.exceptions import ToolError
|
||||||
from mcp.server.fastmcp.tools.base import Tool
|
from mcp.server.fastmcp.tools.base import Tool
|
||||||
from mcp.server.fastmcp.utilities.logging import get_logger
|
from mcp.server.fastmcp.utilities.logging import get_logger
|
||||||
from mcp.shared.context import LifespanContextT
|
from mcp.shared.context import LifespanContextT, RequestT
|
||||||
from mcp.types import ToolAnnotations
|
from mcp.types import ToolAnnotations
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -65,7 +65,7 @@ class ToolManager:
|
|||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
arguments: dict[str, Any],
|
arguments: dict[str, Any],
|
||||||
context: Context[ServerSessionT, LifespanContextT] | None = None,
|
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Call a tool by name with arguments."""
|
"""Call a tool by name with arguments."""
|
||||||
tool = self.get_tool(name)
|
tool = self.get_tool(name)
|
||||||
|
|||||||
@@ -72,11 +72,12 @@ import logging
|
|||||||
import warnings
|
import warnings
|
||||||
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable
|
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable
|
||||||
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
|
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
|
||||||
from typing import Any, Generic, TypeVar
|
from typing import Any, Generic
|
||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||||
from pydantic import AnyUrl
|
from pydantic import AnyUrl
|
||||||
|
from typing_extensions import TypeVar
|
||||||
|
|
||||||
import mcp.types as types
|
import mcp.types as types
|
||||||
from mcp.server.lowlevel.helper_types import ReadResourceContents
|
from mcp.server.lowlevel.helper_types import ReadResourceContents
|
||||||
@@ -85,15 +86,16 @@ from mcp.server.session import ServerSession
|
|||||||
from mcp.server.stdio import stdio_server as stdio_server
|
from mcp.server.stdio import stdio_server as stdio_server
|
||||||
from mcp.shared.context import RequestContext
|
from mcp.shared.context import RequestContext
|
||||||
from mcp.shared.exceptions import McpError
|
from mcp.shared.exceptions import McpError
|
||||||
from mcp.shared.message import SessionMessage
|
from mcp.shared.message import ServerMessageMetadata, SessionMessage
|
||||||
from mcp.shared.session import RequestResponder
|
from mcp.shared.session import RequestResponder
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
LifespanResultT = TypeVar("LifespanResultT")
|
LifespanResultT = TypeVar("LifespanResultT")
|
||||||
|
RequestT = TypeVar("RequestT", default=Any)
|
||||||
|
|
||||||
# This will be properly typed in each Server instance's context
|
# This will be properly typed in each Server instance's context
|
||||||
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any]] = (
|
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = (
|
||||||
contextvars.ContextVar("request_ctx")
|
contextvars.ContextVar("request_ctx")
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -111,7 +113,7 @@ class NotificationOptions:
|
|||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(server: Server[LifespanResultT]) -> AsyncIterator[object]:
|
async def lifespan(server: Server[LifespanResultT, RequestT]) -> AsyncIterator[object]:
|
||||||
"""Default lifespan context manager that does nothing.
|
"""Default lifespan context manager that does nothing.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -123,14 +125,15 @@ async def lifespan(server: Server[LifespanResultT]) -> AsyncIterator[object]:
|
|||||||
yield {}
|
yield {}
|
||||||
|
|
||||||
|
|
||||||
class Server(Generic[LifespanResultT]):
|
class Server(Generic[LifespanResultT, RequestT]):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
version: str | None = None,
|
version: str | None = None,
|
||||||
instructions: str | None = None,
|
instructions: str | None = None,
|
||||||
lifespan: Callable[
|
lifespan: Callable[
|
||||||
[Server[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]
|
[Server[LifespanResultT, RequestT]],
|
||||||
|
AbstractAsyncContextManager[LifespanResultT],
|
||||||
] = lifespan,
|
] = lifespan,
|
||||||
):
|
):
|
||||||
self.name = name
|
self.name = name
|
||||||
@@ -215,7 +218,9 @@ class Server(Generic[LifespanResultT]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def request_context(self) -> RequestContext[ServerSession, LifespanResultT]:
|
def request_context(
|
||||||
|
self,
|
||||||
|
) -> RequestContext[ServerSession, LifespanResultT, RequestT]:
|
||||||
"""If called outside of a request context, this will raise a LookupError."""
|
"""If called outside of a request context, this will raise a LookupError."""
|
||||||
return request_ctx.get()
|
return request_ctx.get()
|
||||||
|
|
||||||
@@ -555,6 +560,13 @@ class Server(Generic[LifespanResultT]):
|
|||||||
|
|
||||||
token = None
|
token = None
|
||||||
try:
|
try:
|
||||||
|
# Extract request context from message metadata
|
||||||
|
request_data = None
|
||||||
|
if message.message_metadata is not None and isinstance(
|
||||||
|
message.message_metadata, ServerMessageMetadata
|
||||||
|
):
|
||||||
|
request_data = message.message_metadata.request_context
|
||||||
|
|
||||||
# Set our global state that can be retrieved via
|
# Set our global state that can be retrieved via
|
||||||
# app.get_request_context()
|
# app.get_request_context()
|
||||||
token = request_ctx.set(
|
token = request_ctx.set(
|
||||||
@@ -563,6 +575,7 @@ class Server(Generic[LifespanResultT]):
|
|||||||
message.request_meta,
|
message.request_meta,
|
||||||
session,
|
session,
|
||||||
lifespan_context,
|
lifespan_context,
|
||||||
|
request=request_data,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
response = await handler(req)
|
response = await handler(req)
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ from starlette.responses import Response
|
|||||||
from starlette.types import Receive, Scope, Send
|
from starlette.types import Receive, Scope, Send
|
||||||
|
|
||||||
import mcp.types as types
|
import mcp.types as types
|
||||||
from mcp.shared.message import SessionMessage
|
from mcp.shared.message import ServerMessageMetadata, SessionMessage
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -203,7 +203,9 @@ class SseServerTransport:
|
|||||||
await writer.send(err)
|
await writer.send(err)
|
||||||
return
|
return
|
||||||
|
|
||||||
session_message = SessionMessage(message)
|
# Pass the ASGI scope for framework-agnostic access to request data
|
||||||
|
metadata = ServerMessageMetadata(request_context=request)
|
||||||
|
session_message = SessionMessage(message, metadata=metadata)
|
||||||
logger.debug(f"Sending session message to writer: {session_message}")
|
logger.debug(f"Sending session message to writer: {session_message}")
|
||||||
response = Response("Accepted", status_code=202)
|
response = Response("Accepted", status_code=202)
|
||||||
await response(scope, receive, send)
|
await response(scope, receive, send)
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ class StreamableHTTPSessionManager:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
app: MCPServer[Any],
|
app: MCPServer[Any, Any],
|
||||||
event_store: EventStore | None = None,
|
event_store: EventStore | None = None,
|
||||||
json_response: bool = False,
|
json_response: bool = False,
|
||||||
stateless: bool = False,
|
stateless: bool = False,
|
||||||
|
|||||||
@@ -8,11 +8,13 @@ from mcp.types import RequestId, RequestParams
|
|||||||
|
|
||||||
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
|
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
|
||||||
LifespanContextT = TypeVar("LifespanContextT")
|
LifespanContextT = TypeVar("LifespanContextT")
|
||||||
|
RequestT = TypeVar("RequestT", default=Any)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RequestContext(Generic[SessionT, LifespanContextT]):
|
class RequestContext(Generic[SessionT, LifespanContextT, RequestT]):
|
||||||
request_id: RequestId
|
request_id: RequestId
|
||||||
meta: RequestParams.Meta | None
|
meta: RequestParams.Meta | None
|
||||||
session: SessionT
|
session: SessionT
|
||||||
lifespan_context: LifespanContextT
|
lifespan_context: LifespanContextT
|
||||||
|
request: RequestT | None = None
|
||||||
|
|||||||
@@ -30,6 +30,8 @@ class ServerMessageMetadata:
|
|||||||
"""Metadata specific to server messages."""
|
"""Metadata specific to server messages."""
|
||||||
|
|
||||||
related_request_id: RequestId | None = None
|
related_request_id: RequestId | None = None
|
||||||
|
# Request-specific context (e.g., headers, auth info)
|
||||||
|
request_context: object | None = None
|
||||||
|
|
||||||
|
|
||||||
MessageMetadata = ClientMessageMetadata | ServerMessageMetadata | None
|
MessageMetadata = ClientMessageMetadata | ServerMessageMetadata | None
|
||||||
|
|||||||
@@ -81,10 +81,12 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
|||||||
ReceiveNotificationT
|
ReceiveNotificationT
|
||||||
]""",
|
]""",
|
||||||
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
|
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
|
||||||
|
message_metadata: MessageMetadata = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.request_id = request_id
|
self.request_id = request_id
|
||||||
self.request_meta = request_meta
|
self.request_meta = request_meta
|
||||||
self.request = request
|
self.request = request
|
||||||
|
self.message_metadata = message_metadata
|
||||||
self._session = session
|
self._session = session
|
||||||
self._completed = False
|
self._completed = False
|
||||||
self._cancel_scope = anyio.CancelScope()
|
self._cancel_scope = anyio.CancelScope()
|
||||||
@@ -365,6 +367,7 @@ class BaseSession(
|
|||||||
request=validated_request,
|
request=validated_request,
|
||||||
session=self,
|
session=self,
|
||||||
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
|
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
|
||||||
|
message_metadata=message.metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._in_flight[responder.request_id] = responder
|
self._in_flight[responder.request_id] = responder
|
||||||
|
|||||||
@@ -5,14 +5,18 @@ These tests validate the proper functioning of FastMCP in various configurations
|
|||||||
including with and without authentication.
|
including with and without authentication.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import socket
|
import socket
|
||||||
import time
|
import time
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from pydantic import AnyUrl
|
from pydantic import AnyUrl
|
||||||
|
from starlette.applications import Starlette
|
||||||
|
from starlette.requests import Request
|
||||||
|
|
||||||
import mcp.types as types
|
import mcp.types as types
|
||||||
from mcp.client.session import ClientSession
|
from mcp.client.session import ClientSession
|
||||||
@@ -20,6 +24,7 @@ from mcp.client.sse import sse_client
|
|||||||
from mcp.client.streamable_http import streamablehttp_client
|
from mcp.client.streamable_http import streamablehttp_client
|
||||||
from mcp.server.fastmcp import FastMCP
|
from mcp.server.fastmcp import FastMCP
|
||||||
from mcp.server.fastmcp.resources import FunctionResource
|
from mcp.server.fastmcp.resources import FunctionResource
|
||||||
|
from mcp.server.fastmcp.server import Context
|
||||||
from mcp.shared.context import RequestContext
|
from mcp.shared.context import RequestContext
|
||||||
from mcp.types import (
|
from mcp.types import (
|
||||||
CreateMessageRequestParams,
|
CreateMessageRequestParams,
|
||||||
@@ -78,8 +83,6 @@ def stateless_http_server_url(stateless_http_server_port: int) -> str:
|
|||||||
# Create a function to make the FastMCP server app
|
# Create a function to make the FastMCP server app
|
||||||
def make_fastmcp_app():
|
def make_fastmcp_app():
|
||||||
"""Create a FastMCP server without auth settings."""
|
"""Create a FastMCP server without auth settings."""
|
||||||
from starlette.applications import Starlette
|
|
||||||
|
|
||||||
mcp = FastMCP(name="NoAuthServer")
|
mcp = FastMCP(name="NoAuthServer")
|
||||||
|
|
||||||
# Add a simple tool
|
# Add a simple tool
|
||||||
@@ -88,7 +91,7 @@ def make_fastmcp_app():
|
|||||||
return f"Echo: {message}"
|
return f"Echo: {message}"
|
||||||
|
|
||||||
# Create the SSE app
|
# Create the SSE app
|
||||||
app: Starlette = mcp.sse_app()
|
app = mcp.sse_app()
|
||||||
|
|
||||||
return mcp, app
|
return mcp, app
|
||||||
|
|
||||||
@@ -198,17 +201,14 @@ def make_everything_fastmcp() -> FastMCP:
|
|||||||
|
|
||||||
def make_everything_fastmcp_app():
|
def make_everything_fastmcp_app():
|
||||||
"""Create a comprehensive FastMCP server with SSE transport."""
|
"""Create a comprehensive FastMCP server with SSE transport."""
|
||||||
from starlette.applications import Starlette
|
|
||||||
|
|
||||||
mcp = make_everything_fastmcp()
|
mcp = make_everything_fastmcp()
|
||||||
# Create the SSE app
|
# Create the SSE app
|
||||||
app: Starlette = mcp.sse_app()
|
app = mcp.sse_app()
|
||||||
return mcp, app
|
return mcp, app
|
||||||
|
|
||||||
|
|
||||||
def make_fastmcp_streamable_http_app():
|
def make_fastmcp_streamable_http_app():
|
||||||
"""Create a FastMCP server with StreamableHTTP transport."""
|
"""Create a FastMCP server with StreamableHTTP transport."""
|
||||||
from starlette.applications import Starlette
|
|
||||||
|
|
||||||
mcp = FastMCP(name="NoAuthServer")
|
mcp = FastMCP(name="NoAuthServer")
|
||||||
|
|
||||||
@@ -225,8 +225,6 @@ def make_fastmcp_streamable_http_app():
|
|||||||
|
|
||||||
def make_everything_fastmcp_streamable_http_app():
|
def make_everything_fastmcp_streamable_http_app():
|
||||||
"""Create a comprehensive FastMCP server with StreamableHTTP transport."""
|
"""Create a comprehensive FastMCP server with StreamableHTTP transport."""
|
||||||
from starlette.applications import Starlette
|
|
||||||
|
|
||||||
# Create a new instance with different name for HTTP transport
|
# Create a new instance with different name for HTTP transport
|
||||||
mcp = make_everything_fastmcp()
|
mcp = make_everything_fastmcp()
|
||||||
# We can't change the name after creation, so we'll use the same name
|
# We can't change the name after creation, so we'll use the same name
|
||||||
@@ -237,7 +235,6 @@ def make_everything_fastmcp_streamable_http_app():
|
|||||||
|
|
||||||
def make_fastmcp_stateless_http_app():
|
def make_fastmcp_stateless_http_app():
|
||||||
"""Create a FastMCP server with stateless StreamableHTTP transport."""
|
"""Create a FastMCP server with stateless StreamableHTTP transport."""
|
||||||
from starlette.applications import Starlette
|
|
||||||
|
|
||||||
mcp = FastMCP(name="StatelessServer", stateless_http=True)
|
mcp = FastMCP(name="StatelessServer", stateless_http=True)
|
||||||
|
|
||||||
@@ -435,6 +432,174 @@ async def test_fastmcp_without_auth(server: None, server_url: str) -> None:
|
|||||||
assert tool_result.content[0].text == "Echo: hello"
|
assert tool_result.content[0].text == "Echo: hello"
|
||||||
|
|
||||||
|
|
||||||
|
def make_fastmcp_with_context_app():
|
||||||
|
"""Create a FastMCP server that can access request context."""
|
||||||
|
|
||||||
|
mcp = FastMCP(name="ContextServer")
|
||||||
|
|
||||||
|
# Tool that echoes request headers
|
||||||
|
@mcp.tool(description="Echo request headers from context")
|
||||||
|
def echo_headers(ctx: Context[Any, Any, Request]) -> str:
|
||||||
|
"""Returns the request headers as JSON."""
|
||||||
|
headers_info = {}
|
||||||
|
if ctx.request_context.request:
|
||||||
|
# Now the type system knows request is a Starlette Request object
|
||||||
|
headers_info = dict(ctx.request_context.request.headers)
|
||||||
|
return json.dumps(headers_info)
|
||||||
|
|
||||||
|
# Tool that returns full request context
|
||||||
|
@mcp.tool(description="Echo request context with custom data")
|
||||||
|
def echo_context(custom_request_id: str, ctx: Context[Any, Any, Request]) -> str:
|
||||||
|
"""Returns request context including headers and custom data."""
|
||||||
|
context_data = {
|
||||||
|
"custom_request_id": custom_request_id,
|
||||||
|
"headers": {},
|
||||||
|
"method": None,
|
||||||
|
"path": None,
|
||||||
|
}
|
||||||
|
if ctx.request_context.request:
|
||||||
|
request = ctx.request_context.request
|
||||||
|
context_data["headers"] = dict(request.headers)
|
||||||
|
context_data["method"] = request.method
|
||||||
|
context_data["path"] = request.url.path
|
||||||
|
return json.dumps(context_data)
|
||||||
|
|
||||||
|
# Create the SSE app
|
||||||
|
app = mcp.sse_app()
|
||||||
|
return mcp, app
|
||||||
|
|
||||||
|
|
||||||
|
def run_context_server(server_port: int) -> None:
|
||||||
|
"""Run the context-aware FastMCP server."""
|
||||||
|
_, app = make_fastmcp_with_context_app()
|
||||||
|
server = uvicorn.Server(
|
||||||
|
config=uvicorn.Config(
|
||||||
|
app=app, host="127.0.0.1", port=server_port, log_level="error"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
print(f"Starting context server on port {server_port}")
|
||||||
|
server.run()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def context_aware_server(server_port: int) -> Generator[None, None, None]:
|
||||||
|
"""Start the context-aware server in a separate process."""
|
||||||
|
proc = multiprocessing.Process(
|
||||||
|
target=run_context_server, args=(server_port,), daemon=True
|
||||||
|
)
|
||||||
|
print("Starting context-aware server process")
|
||||||
|
proc.start()
|
||||||
|
|
||||||
|
# Wait for server to be running
|
||||||
|
max_attempts = 20
|
||||||
|
attempt = 0
|
||||||
|
print("Waiting for context-aware server to start")
|
||||||
|
while attempt < max_attempts:
|
||||||
|
try:
|
||||||
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
|
s.connect(("127.0.0.1", server_port))
|
||||||
|
break
|
||||||
|
except ConnectionRefusedError:
|
||||||
|
time.sleep(0.1)
|
||||||
|
attempt += 1
|
||||||
|
else:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Context server failed to start after {max_attempts} attempts"
|
||||||
|
)
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
print("Killing context-aware server")
|
||||||
|
proc.kill()
|
||||||
|
proc.join(timeout=2)
|
||||||
|
if proc.is_alive():
|
||||||
|
print("Context server process failed to terminate")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_fast_mcp_with_request_context(
|
||||||
|
context_aware_server: None, server_url: str
|
||||||
|
) -> None:
|
||||||
|
"""Test that FastMCP properly propagates request context to tools."""
|
||||||
|
# Test with custom headers
|
||||||
|
custom_headers = {
|
||||||
|
"Authorization": "Bearer fastmcp-test-token",
|
||||||
|
"X-Custom-Header": "fastmcp-value",
|
||||||
|
"X-Request-Id": "req-123",
|
||||||
|
}
|
||||||
|
|
||||||
|
async with sse_client(server_url + "/sse", headers=custom_headers) as streams:
|
||||||
|
async with ClientSession(*streams) as session:
|
||||||
|
# Initialize the session
|
||||||
|
result = await session.initialize()
|
||||||
|
assert isinstance(result, InitializeResult)
|
||||||
|
assert result.serverInfo.name == "ContextServer"
|
||||||
|
|
||||||
|
# Test 1: Call tool that echoes headers
|
||||||
|
headers_result = await session.call_tool("echo_headers", {})
|
||||||
|
assert len(headers_result.content) == 1
|
||||||
|
assert isinstance(headers_result.content[0], TextContent)
|
||||||
|
|
||||||
|
headers_data = json.loads(headers_result.content[0].text)
|
||||||
|
assert headers_data.get("authorization") == "Bearer fastmcp-test-token"
|
||||||
|
assert headers_data.get("x-custom-header") == "fastmcp-value"
|
||||||
|
assert headers_data.get("x-request-id") == "req-123"
|
||||||
|
|
||||||
|
# Test 2: Call tool that returns full context
|
||||||
|
context_result = await session.call_tool(
|
||||||
|
"echo_context", {"custom_request_id": "test-123"}
|
||||||
|
)
|
||||||
|
assert len(context_result.content) == 1
|
||||||
|
assert isinstance(context_result.content[0], TextContent)
|
||||||
|
|
||||||
|
context_data = json.loads(context_result.content[0].text)
|
||||||
|
assert context_data["custom_request_id"] == "test-123"
|
||||||
|
assert (
|
||||||
|
context_data["headers"].get("authorization")
|
||||||
|
== "Bearer fastmcp-test-token"
|
||||||
|
)
|
||||||
|
assert context_data["method"] == "POST" #
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_fast_mcp_request_context_isolation(
|
||||||
|
context_aware_server: None, server_url: str
|
||||||
|
) -> None:
|
||||||
|
"""Test that request contexts are isolated between different FastMCP clients."""
|
||||||
|
contexts = []
|
||||||
|
|
||||||
|
# Create multiple clients with different headers
|
||||||
|
for i in range(3):
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer token-{i}",
|
||||||
|
"X-Request-Id": f"fastmcp-req-{i}",
|
||||||
|
"X-Custom-Value": f"value-{i}",
|
||||||
|
}
|
||||||
|
|
||||||
|
async with sse_client(server_url + "/sse", headers=headers) as streams:
|
||||||
|
async with ClientSession(*streams) as session:
|
||||||
|
await session.initialize()
|
||||||
|
|
||||||
|
# Call the tool that returns context
|
||||||
|
tool_result = await session.call_tool(
|
||||||
|
"echo_context", {"custom_request_id": f"test-req-{i}"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse and store the result
|
||||||
|
assert len(tool_result.content) == 1
|
||||||
|
assert isinstance(tool_result.content[0], TextContent)
|
||||||
|
context_data = json.loads(tool_result.content[0].text)
|
||||||
|
contexts.append(context_data)
|
||||||
|
|
||||||
|
# Verify each request had its own isolated context
|
||||||
|
assert len(contexts) == 3
|
||||||
|
for i, ctx in enumerate(contexts):
|
||||||
|
assert ctx["custom_request_id"] == f"test-req-{i}"
|
||||||
|
assert ctx["headers"].get("authorization") == f"Bearer token-{i}"
|
||||||
|
assert ctx["headers"].get("x-request-id") == f"fastmcp-req-{i}"
|
||||||
|
assert ctx["headers"].get("x-custom-value") == f"value-{i}"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_fastmcp_streamable_http(
|
async def test_fastmcp_streamable_http(
|
||||||
streamable_http_server: None, http_server_url: str
|
streamable_http_server: None, http_server_url: str
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from mcp.server.fastmcp.exceptions import ToolError
|
|||||||
from mcp.server.fastmcp.tools import Tool, ToolManager
|
from mcp.server.fastmcp.tools import Tool, ToolManager
|
||||||
from mcp.server.fastmcp.utilities.func_metadata import ArgModelBase, FuncMetadata
|
from mcp.server.fastmcp.utilities.func_metadata import ArgModelBase, FuncMetadata
|
||||||
from mcp.server.session import ServerSessionT
|
from mcp.server.session import ServerSessionT
|
||||||
from mcp.shared.context import LifespanContextT
|
from mcp.shared.context import LifespanContextT, RequestT
|
||||||
from mcp.types import ToolAnnotations
|
from mcp.types import ToolAnnotations
|
||||||
|
|
||||||
|
|
||||||
@@ -347,7 +347,7 @@ class TestContextHandling:
|
|||||||
assert tool.context_kwarg is None
|
assert tool.context_kwarg is None
|
||||||
|
|
||||||
def tool_with_parametrized_context(
|
def tool_with_parametrized_context(
|
||||||
x: int, ctx: Context[ServerSessionT, LifespanContextT]
|
x: int, ctx: Context[ServerSessionT, LifespanContextT, RequestT]
|
||||||
) -> str:
|
) -> str:
|
||||||
return str(x)
|
return str(x)
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import json
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import socket
|
import socket
|
||||||
import time
|
import time
|
||||||
@@ -318,3 +319,187 @@ async def test_sse_client_basic_connection_mounted_app(
|
|||||||
# Test ping
|
# Test ping
|
||||||
ping_result = await session.send_ping()
|
ping_result = await session.send_ping()
|
||||||
assert isinstance(ping_result, EmptyResult)
|
assert isinstance(ping_result, EmptyResult)
|
||||||
|
|
||||||
|
|
||||||
|
# Test server with request context that returns headers in the response
|
||||||
|
class RequestContextServer(Server[object, Request]):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__("request_context_server")
|
||||||
|
|
||||||
|
@self.call_tool()
|
||||||
|
async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
|
||||||
|
headers_info = {}
|
||||||
|
context = self.request_context
|
||||||
|
if context.request:
|
||||||
|
headers_info = dict(context.request.headers)
|
||||||
|
|
||||||
|
if name == "echo_headers":
|
||||||
|
return [TextContent(type="text", text=json.dumps(headers_info))]
|
||||||
|
elif name == "echo_context":
|
||||||
|
context_data = {
|
||||||
|
"request_id": args.get("request_id"),
|
||||||
|
"headers": headers_info,
|
||||||
|
}
|
||||||
|
return [TextContent(type="text", text=json.dumps(context_data))]
|
||||||
|
|
||||||
|
return [TextContent(type="text", text=f"Called {name}")]
|
||||||
|
|
||||||
|
@self.list_tools()
|
||||||
|
async def handle_list_tools() -> list[Tool]:
|
||||||
|
return [
|
||||||
|
Tool(
|
||||||
|
name="echo_headers",
|
||||||
|
description="Echoes request headers",
|
||||||
|
inputSchema={"type": "object", "properties": {}},
|
||||||
|
),
|
||||||
|
Tool(
|
||||||
|
name="echo_context",
|
||||||
|
description="Echoes request context",
|
||||||
|
inputSchema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"request_id": {"type": "string"}},
|
||||||
|
"required": ["request_id"],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def run_context_server(server_port: int) -> None:
|
||||||
|
"""Run a server that captures request context"""
|
||||||
|
sse = SseServerTransport("/messages/")
|
||||||
|
context_server = RequestContextServer()
|
||||||
|
|
||||||
|
async def handle_sse(request: Request) -> Response:
|
||||||
|
async with sse.connect_sse(
|
||||||
|
request.scope, request.receive, request._send
|
||||||
|
) as streams:
|
||||||
|
await context_server.run(
|
||||||
|
streams[0], streams[1], context_server.create_initialization_options()
|
||||||
|
)
|
||||||
|
return Response()
|
||||||
|
|
||||||
|
app = Starlette(
|
||||||
|
routes=[
|
||||||
|
Route("/sse", endpoint=handle_sse),
|
||||||
|
Mount("/messages/", app=sse.handle_post_message),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
server = uvicorn.Server(
|
||||||
|
config=uvicorn.Config(
|
||||||
|
app=app, host="127.0.0.1", port=server_port, log_level="error"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
print(f"starting context server on {server_port}")
|
||||||
|
server.run()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def context_server(server_port: int) -> Generator[None, None, None]:
|
||||||
|
"""Fixture that provides a server with request context capture"""
|
||||||
|
proc = multiprocessing.Process(
|
||||||
|
target=run_context_server, kwargs={"server_port": server_port}, daemon=True
|
||||||
|
)
|
||||||
|
print("starting context server process")
|
||||||
|
proc.start()
|
||||||
|
|
||||||
|
# Wait for server to be running
|
||||||
|
max_attempts = 20
|
||||||
|
attempt = 0
|
||||||
|
print("waiting for context server to start")
|
||||||
|
while attempt < max_attempts:
|
||||||
|
try:
|
||||||
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
|
s.connect(("127.0.0.1", server_port))
|
||||||
|
break
|
||||||
|
except ConnectionRefusedError:
|
||||||
|
time.sleep(0.1)
|
||||||
|
attempt += 1
|
||||||
|
else:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Context server failed to start after {max_attempts} attempts"
|
||||||
|
)
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
print("killing context server")
|
||||||
|
proc.kill()
|
||||||
|
proc.join(timeout=2)
|
||||||
|
if proc.is_alive():
|
||||||
|
print("context server process failed to terminate")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_request_context_propagation(
|
||||||
|
context_server: None, server_url: str
|
||||||
|
) -> None:
|
||||||
|
"""Test that request context is properly propagated through SSE transport."""
|
||||||
|
# Test with custom headers
|
||||||
|
custom_headers = {
|
||||||
|
"Authorization": "Bearer test-token",
|
||||||
|
"X-Custom-Header": "test-value",
|
||||||
|
"X-Trace-Id": "trace-123",
|
||||||
|
}
|
||||||
|
|
||||||
|
async with sse_client(server_url + "/sse", headers=custom_headers) as (
|
||||||
|
read_stream,
|
||||||
|
write_stream,
|
||||||
|
):
|
||||||
|
async with ClientSession(read_stream, write_stream) as session:
|
||||||
|
# Initialize the session
|
||||||
|
result = await session.initialize()
|
||||||
|
assert isinstance(result, InitializeResult)
|
||||||
|
|
||||||
|
# Call the tool that echoes headers back
|
||||||
|
tool_result = await session.call_tool("echo_headers", {})
|
||||||
|
|
||||||
|
# Parse the JSON response
|
||||||
|
|
||||||
|
assert len(tool_result.content) == 1
|
||||||
|
headers_data = json.loads(
|
||||||
|
tool_result.content[0].text
|
||||||
|
if tool_result.content[0].type == "text"
|
||||||
|
else "{}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify headers were propagated
|
||||||
|
assert headers_data.get("authorization") == "Bearer test-token"
|
||||||
|
assert headers_data.get("x-custom-header") == "test-value"
|
||||||
|
assert headers_data.get("x-trace-id") == "trace-123"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_request_context_isolation(context_server: None, server_url: str) -> None:
|
||||||
|
"""Test that request contexts are isolated between different SSE clients."""
|
||||||
|
contexts = []
|
||||||
|
|
||||||
|
# Create multiple clients with different headers
|
||||||
|
for i in range(3):
|
||||||
|
headers = {"X-Request-Id": f"request-{i}", "X-Custom-Value": f"value-{i}"}
|
||||||
|
|
||||||
|
async with sse_client(server_url + "/sse", headers=headers) as (
|
||||||
|
read_stream,
|
||||||
|
write_stream,
|
||||||
|
):
|
||||||
|
async with ClientSession(read_stream, write_stream) as session:
|
||||||
|
await session.initialize()
|
||||||
|
|
||||||
|
# Call the tool that echoes context
|
||||||
|
tool_result = await session.call_tool(
|
||||||
|
"echo_context", {"request_id": f"request-{i}"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(tool_result.content) == 1
|
||||||
|
context_data = json.loads(
|
||||||
|
tool_result.content[0].text
|
||||||
|
if tool_result.content[0].type == "text"
|
||||||
|
else "{}"
|
||||||
|
)
|
||||||
|
contexts.append(context_data)
|
||||||
|
|
||||||
|
# Verify each request had its own context
|
||||||
|
assert len(contexts) == 3
|
||||||
|
for i, ctx in enumerate(contexts):
|
||||||
|
assert ctx["request_id"] == f"request-{i}"
|
||||||
|
assert ctx["headers"].get("x-request-id") == f"request-{i}"
|
||||||
|
assert ctx["headers"].get("x-custom-value") == f"value-{i}"
|
||||||
|
|||||||
Reference in New Issue
Block a user