From da54ea003eef5926ecfb619ae47f38d7bd794cad Mon Sep 17 00:00:00 2001 From: Junpei Kawamoto Date: Thu, 10 Apr 2025 03:36:46 -0600 Subject: [PATCH] Allow generic parameters to be passed onto `Context` on FastMCP tools Co-authored-by: Marcelo Trylesinski --- src/mcp/server/fastmcp/tools/base.py | 6 ++++-- tests/server/fastmcp/test_tool_manager.py | 20 +++++++++++--------- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/src/mcp/server/fastmcp/tools/base.py b/src/mcp/server/fastmcp/tools/base.py index e137e84..92a216f 100644 --- a/src/mcp/server/fastmcp/tools/base.py +++ b/src/mcp/server/fastmcp/tools/base.py @@ -2,7 +2,7 @@ from __future__ import annotations as _annotations import inspect from collections.abc import Callable -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, get_origin from pydantic import BaseModel, Field @@ -53,7 +53,9 @@ class Tool(BaseModel): if context_kwarg is None: sig = inspect.signature(fn) for param_name, param in sig.parameters.items(): - if param.annotation is Context: + if get_origin(param.annotation) is not None: + continue + if issubclass(param.annotation, Context): context_kwarg = param_name break diff --git a/tests/server/fastmcp/test_tool_manager.py b/tests/server/fastmcp/test_tool_manager.py index d206758..8f52e3d 100644 --- a/tests/server/fastmcp/test_tool_manager.py +++ b/tests/server/fastmcp/test_tool_manager.py @@ -4,8 +4,11 @@ import logging import pytest from pydantic import BaseModel +from mcp.server.fastmcp import Context, FastMCP from mcp.server.fastmcp.exceptions import ToolError from mcp.server.fastmcp.tools import ToolManager +from mcp.server.session import ServerSessionT +from mcp.shared.context import LifespanContextT class TestAddTools: @@ -194,8 +197,6 @@ class TestCallTools: @pytest.mark.anyio async def test_call_tool_with_complex_model(self): - from mcp.server.fastmcp import Context - class MyShrimpTank(BaseModel): class Shrimp(BaseModel): name: str @@ -223,8 +224,6 @@ class TestCallTools: class TestToolSchema: @pytest.mark.anyio async def test_context_arg_excluded_from_schema(self): - from mcp.server.fastmcp import Context - def something(a: int, ctx: Context) -> int: return a @@ -241,7 +240,6 @@ class TestContextHandling: def test_context_parameter_detection(self): """Test that context parameters are properly detected in Tool.from_function().""" - from mcp.server.fastmcp import Context def tool_with_context(x: int, ctx: Context) -> str: return str(x) @@ -256,10 +254,17 @@ class TestContextHandling: tool = manager.add_tool(tool_without_context) assert tool.context_kwarg is None + def tool_with_parametrized_context( + x: int, ctx: Context[ServerSessionT, LifespanContextT] + ) -> str: + return str(x) + + tool = manager.add_tool(tool_with_parametrized_context) + assert tool.context_kwarg == "ctx" + @pytest.mark.anyio async def test_context_injection(self): """Test that context is properly injected during tool execution.""" - from mcp.server.fastmcp import Context, FastMCP def tool_with_context(x: int, ctx: Context) -> str: assert isinstance(ctx, Context) @@ -276,7 +281,6 @@ class TestContextHandling: @pytest.mark.anyio async def test_context_injection_async(self): """Test that context is properly injected in async tools.""" - from mcp.server.fastmcp import Context, FastMCP async def async_tool(x: int, ctx: Context) -> str: assert isinstance(ctx, Context) @@ -293,7 +297,6 @@ class TestContextHandling: @pytest.mark.anyio async def test_context_optional(self): """Test that context is optional when calling tools.""" - from mcp.server.fastmcp import Context def tool_with_context(x: int, ctx: Context | None = None) -> str: return str(x) @@ -307,7 +310,6 @@ class TestContextHandling: @pytest.mark.anyio async def test_context_error_handling(self): """Test error handling when context injection fails.""" - from mcp.server.fastmcp import Context, FastMCP def tool_with_context(x: int, ctx: Context) -> str: raise ValueError("Test error")