mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 06:54:18 +01:00
Fix async callable object tools (#568)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any, get_origin
|
||||
@@ -53,7 +54,7 @@ class Tool(BaseModel):
|
||||
raise ValueError("You must provide a name for lambda functions")
|
||||
|
||||
func_doc = description or fn.__doc__ or ""
|
||||
is_async = inspect.iscoroutinefunction(fn)
|
||||
is_async = _is_async_callable(fn)
|
||||
|
||||
if context_kwarg is None:
|
||||
sig = inspect.signature(fn)
|
||||
@@ -98,3 +99,12 @@ class Tool(BaseModel):
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolError(f"Error executing tool {self.name}: {e}") from e
|
||||
|
||||
|
||||
def _is_async_callable(obj: Any) -> bool:
|
||||
while isinstance(obj, functools.partial):
|
||||
obj = obj.func
|
||||
|
||||
return inspect.iscoroutinefunction(obj) or (
|
||||
callable(obj) and inspect.iscoroutinefunction(getattr(obj, "__call__", None))
|
||||
)
|
||||
|
||||
@@ -102,6 +102,39 @@ class TestAddTools:
|
||||
assert "age" in tool.parameters["$defs"]["UserInput"]["properties"]
|
||||
assert "flag" in tool.parameters["properties"]
|
||||
|
||||
def test_add_callable_object(self):
|
||||
"""Test registering a callable object."""
|
||||
|
||||
class MyTool:
|
||||
def __init__(self):
|
||||
self.__name__ = "MyTool"
|
||||
|
||||
def __call__(self, x: int) -> int:
|
||||
return x * 2
|
||||
|
||||
manager = ToolManager()
|
||||
tool = manager.add_tool(MyTool())
|
||||
assert tool.name == "MyTool"
|
||||
assert tool.is_async is False
|
||||
assert tool.parameters["properties"]["x"]["type"] == "integer"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_add_async_callable_object(self):
|
||||
"""Test registering an async callable object."""
|
||||
|
||||
class MyAsyncTool:
|
||||
def __init__(self):
|
||||
self.__name__ = "MyAsyncTool"
|
||||
|
||||
async def __call__(self, x: int) -> int:
|
||||
return x * 2
|
||||
|
||||
manager = ToolManager()
|
||||
tool = manager.add_tool(MyAsyncTool())
|
||||
assert tool.name == "MyAsyncTool"
|
||||
assert tool.is_async is True
|
||||
assert tool.parameters["properties"]["x"]["type"] == "integer"
|
||||
|
||||
def test_add_invalid_tool(self):
|
||||
manager = ToolManager()
|
||||
with pytest.raises(AttributeError):
|
||||
@@ -168,6 +201,34 @@ class TestCallTools:
|
||||
result = await manager.call_tool("double", {"n": 5})
|
||||
assert result == 10
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_call_object_tool(self):
|
||||
class MyTool:
|
||||
def __init__(self):
|
||||
self.__name__ = "MyTool"
|
||||
|
||||
def __call__(self, x: int) -> int:
|
||||
return x * 2
|
||||
|
||||
manager = ToolManager()
|
||||
tool = manager.add_tool(MyTool())
|
||||
result = await tool.run({"x": 5})
|
||||
assert result == 10
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_call_async_object_tool(self):
|
||||
class MyAsyncTool:
|
||||
def __init__(self):
|
||||
self.__name__ = "MyAsyncTool"
|
||||
|
||||
async def __call__(self, x: int) -> int:
|
||||
return x * 2
|
||||
|
||||
manager = ToolManager()
|
||||
tool = manager.add_tool(MyAsyncTool())
|
||||
result = await tool.run({"x": 5})
|
||||
assert result == 10
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_call_tool_with_default_args(self):
|
||||
def add(a: int, b: int = 1) -> int:
|
||||
|
||||
Reference in New Issue
Block a user