mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
Allow generic parameters to be passed onto Context on FastMCP tools
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
This commit is contained in:
@@ -2,7 +2,7 @@ from __future__ import annotations as _annotations
|
||||
|
||||
import inspect
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, get_origin
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -53,7 +53,9 @@ class Tool(BaseModel):
|
||||
if context_kwarg is None:
|
||||
sig = inspect.signature(fn)
|
||||
for param_name, param in sig.parameters.items():
|
||||
if param.annotation is Context:
|
||||
if get_origin(param.annotation) is not None:
|
||||
continue
|
||||
if issubclass(param.annotation, Context):
|
||||
context_kwarg = param_name
|
||||
break
|
||||
|
||||
|
||||
@@ -4,8 +4,11 @@ import logging
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from mcp.server.fastmcp import Context, FastMCP
|
||||
from mcp.server.fastmcp.exceptions import ToolError
|
||||
from mcp.server.fastmcp.tools import ToolManager
|
||||
from mcp.server.session import ServerSessionT
|
||||
from mcp.shared.context import LifespanContextT
|
||||
|
||||
|
||||
class TestAddTools:
|
||||
@@ -194,8 +197,6 @@ class TestCallTools:
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_call_tool_with_complex_model(self):
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
class MyShrimpTank(BaseModel):
|
||||
class Shrimp(BaseModel):
|
||||
name: str
|
||||
@@ -223,8 +224,6 @@ class TestCallTools:
|
||||
class TestToolSchema:
|
||||
@pytest.mark.anyio
|
||||
async def test_context_arg_excluded_from_schema(self):
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
def something(a: int, ctx: Context) -> int:
|
||||
return a
|
||||
|
||||
@@ -241,7 +240,6 @@ class TestContextHandling:
|
||||
def test_context_parameter_detection(self):
|
||||
"""Test that context parameters are properly detected in
|
||||
Tool.from_function()."""
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
def tool_with_context(x: int, ctx: Context) -> str:
|
||||
return str(x)
|
||||
@@ -256,10 +254,17 @@ class TestContextHandling:
|
||||
tool = manager.add_tool(tool_without_context)
|
||||
assert tool.context_kwarg is None
|
||||
|
||||
def tool_with_parametrized_context(
|
||||
x: int, ctx: Context[ServerSessionT, LifespanContextT]
|
||||
) -> str:
|
||||
return str(x)
|
||||
|
||||
tool = manager.add_tool(tool_with_parametrized_context)
|
||||
assert tool.context_kwarg == "ctx"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_context_injection(self):
|
||||
"""Test that context is properly injected during tool execution."""
|
||||
from mcp.server.fastmcp import Context, FastMCP
|
||||
|
||||
def tool_with_context(x: int, ctx: Context) -> str:
|
||||
assert isinstance(ctx, Context)
|
||||
@@ -276,7 +281,6 @@ class TestContextHandling:
|
||||
@pytest.mark.anyio
|
||||
async def test_context_injection_async(self):
|
||||
"""Test that context is properly injected in async tools."""
|
||||
from mcp.server.fastmcp import Context, FastMCP
|
||||
|
||||
async def async_tool(x: int, ctx: Context) -> str:
|
||||
assert isinstance(ctx, Context)
|
||||
@@ -293,7 +297,6 @@ class TestContextHandling:
|
||||
@pytest.mark.anyio
|
||||
async def test_context_optional(self):
|
||||
"""Test that context is optional when calling tools."""
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
def tool_with_context(x: int, ctx: Context | None = None) -> str:
|
||||
return str(x)
|
||||
@@ -307,7 +310,6 @@ class TestContextHandling:
|
||||
@pytest.mark.anyio
|
||||
async def test_context_error_handling(self):
|
||||
"""Test error handling when context injection fails."""
|
||||
from mcp.server.fastmcp import Context, FastMCP
|
||||
|
||||
def tool_with_context(x: int, ctx: Context) -> str:
|
||||
raise ValueError("Test error")
|
||||
|
||||
Reference in New Issue
Block a user