mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
Allow generic parameters to be passed onto Context on FastMCP tools
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
This commit is contained in:
@@ -2,7 +2,7 @@ from __future__ import annotations as _annotations
|
|||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any, get_origin
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@@ -53,7 +53,9 @@ class Tool(BaseModel):
|
|||||||
if context_kwarg is None:
|
if context_kwarg is None:
|
||||||
sig = inspect.signature(fn)
|
sig = inspect.signature(fn)
|
||||||
for param_name, param in sig.parameters.items():
|
for param_name, param in sig.parameters.items():
|
||||||
if param.annotation is Context:
|
if get_origin(param.annotation) is not None:
|
||||||
|
continue
|
||||||
|
if issubclass(param.annotation, Context):
|
||||||
context_kwarg = param_name
|
context_kwarg = param_name
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|||||||
@@ -4,8 +4,11 @@ import logging
|
|||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from mcp.server.fastmcp import Context, FastMCP
|
||||||
from mcp.server.fastmcp.exceptions import ToolError
|
from mcp.server.fastmcp.exceptions import ToolError
|
||||||
from mcp.server.fastmcp.tools import ToolManager
|
from mcp.server.fastmcp.tools import ToolManager
|
||||||
|
from mcp.server.session import ServerSessionT
|
||||||
|
from mcp.shared.context import LifespanContextT
|
||||||
|
|
||||||
|
|
||||||
class TestAddTools:
|
class TestAddTools:
|
||||||
@@ -194,8 +197,6 @@ class TestCallTools:
|
|||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_call_tool_with_complex_model(self):
|
async def test_call_tool_with_complex_model(self):
|
||||||
from mcp.server.fastmcp import Context
|
|
||||||
|
|
||||||
class MyShrimpTank(BaseModel):
|
class MyShrimpTank(BaseModel):
|
||||||
class Shrimp(BaseModel):
|
class Shrimp(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
@@ -223,8 +224,6 @@ class TestCallTools:
|
|||||||
class TestToolSchema:
|
class TestToolSchema:
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_context_arg_excluded_from_schema(self):
|
async def test_context_arg_excluded_from_schema(self):
|
||||||
from mcp.server.fastmcp import Context
|
|
||||||
|
|
||||||
def something(a: int, ctx: Context) -> int:
|
def something(a: int, ctx: Context) -> int:
|
||||||
return a
|
return a
|
||||||
|
|
||||||
@@ -241,7 +240,6 @@ class TestContextHandling:
|
|||||||
def test_context_parameter_detection(self):
|
def test_context_parameter_detection(self):
|
||||||
"""Test that context parameters are properly detected in
|
"""Test that context parameters are properly detected in
|
||||||
Tool.from_function()."""
|
Tool.from_function()."""
|
||||||
from mcp.server.fastmcp import Context
|
|
||||||
|
|
||||||
def tool_with_context(x: int, ctx: Context) -> str:
|
def tool_with_context(x: int, ctx: Context) -> str:
|
||||||
return str(x)
|
return str(x)
|
||||||
@@ -256,10 +254,17 @@ class TestContextHandling:
|
|||||||
tool = manager.add_tool(tool_without_context)
|
tool = manager.add_tool(tool_without_context)
|
||||||
assert tool.context_kwarg is None
|
assert tool.context_kwarg is None
|
||||||
|
|
||||||
|
def tool_with_parametrized_context(
|
||||||
|
x: int, ctx: Context[ServerSessionT, LifespanContextT]
|
||||||
|
) -> str:
|
||||||
|
return str(x)
|
||||||
|
|
||||||
|
tool = manager.add_tool(tool_with_parametrized_context)
|
||||||
|
assert tool.context_kwarg == "ctx"
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_context_injection(self):
|
async def test_context_injection(self):
|
||||||
"""Test that context is properly injected during tool execution."""
|
"""Test that context is properly injected during tool execution."""
|
||||||
from mcp.server.fastmcp import Context, FastMCP
|
|
||||||
|
|
||||||
def tool_with_context(x: int, ctx: Context) -> str:
|
def tool_with_context(x: int, ctx: Context) -> str:
|
||||||
assert isinstance(ctx, Context)
|
assert isinstance(ctx, Context)
|
||||||
@@ -276,7 +281,6 @@ class TestContextHandling:
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_context_injection_async(self):
|
async def test_context_injection_async(self):
|
||||||
"""Test that context is properly injected in async tools."""
|
"""Test that context is properly injected in async tools."""
|
||||||
from mcp.server.fastmcp import Context, FastMCP
|
|
||||||
|
|
||||||
async def async_tool(x: int, ctx: Context) -> str:
|
async def async_tool(x: int, ctx: Context) -> str:
|
||||||
assert isinstance(ctx, Context)
|
assert isinstance(ctx, Context)
|
||||||
@@ -293,7 +297,6 @@ class TestContextHandling:
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_context_optional(self):
|
async def test_context_optional(self):
|
||||||
"""Test that context is optional when calling tools."""
|
"""Test that context is optional when calling tools."""
|
||||||
from mcp.server.fastmcp import Context
|
|
||||||
|
|
||||||
def tool_with_context(x: int, ctx: Context | None = None) -> str:
|
def tool_with_context(x: int, ctx: Context | None = None) -> str:
|
||||||
return str(x)
|
return str(x)
|
||||||
@@ -307,7 +310,6 @@ class TestContextHandling:
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_context_error_handling(self):
|
async def test_context_error_handling(self):
|
||||||
"""Test error handling when context injection fails."""
|
"""Test error handling when context injection fails."""
|
||||||
from mcp.server.fastmcp import Context, FastMCP
|
|
||||||
|
|
||||||
def tool_with_context(x: int, ctx: Context) -> str:
|
def tool_with_context(x: int, ctx: Context) -> str:
|
||||||
raise ValueError("Test error")
|
raise ValueError("Test error")
|
||||||
|
|||||||
Reference in New Issue
Block a user