mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 06:54:18 +01:00
Fix async callable object tools (#568)
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations as _annotations
|
from __future__ import annotations as _annotations
|
||||||
|
|
||||||
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import TYPE_CHECKING, Any, get_origin
|
from typing import TYPE_CHECKING, Any, get_origin
|
||||||
@@ -53,7 +54,7 @@ class Tool(BaseModel):
|
|||||||
raise ValueError("You must provide a name for lambda functions")
|
raise ValueError("You must provide a name for lambda functions")
|
||||||
|
|
||||||
func_doc = description or fn.__doc__ or ""
|
func_doc = description or fn.__doc__ or ""
|
||||||
is_async = inspect.iscoroutinefunction(fn)
|
is_async = _is_async_callable(fn)
|
||||||
|
|
||||||
if context_kwarg is None:
|
if context_kwarg is None:
|
||||||
sig = inspect.signature(fn)
|
sig = inspect.signature(fn)
|
||||||
@@ -98,3 +99,12 @@ class Tool(BaseModel):
|
|||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ToolError(f"Error executing tool {self.name}: {e}") from e
|
raise ToolError(f"Error executing tool {self.name}: {e}") from e
|
||||||
|
|
||||||
|
|
||||||
|
def _is_async_callable(obj: Any) -> bool:
|
||||||
|
while isinstance(obj, functools.partial):
|
||||||
|
obj = obj.func
|
||||||
|
|
||||||
|
return inspect.iscoroutinefunction(obj) or (
|
||||||
|
callable(obj) and inspect.iscoroutinefunction(getattr(obj, "__call__", None))
|
||||||
|
)
|
||||||
|
|||||||
@@ -102,6 +102,39 @@ class TestAddTools:
|
|||||||
assert "age" in tool.parameters["$defs"]["UserInput"]["properties"]
|
assert "age" in tool.parameters["$defs"]["UserInput"]["properties"]
|
||||||
assert "flag" in tool.parameters["properties"]
|
assert "flag" in tool.parameters["properties"]
|
||||||
|
|
||||||
|
def test_add_callable_object(self):
|
||||||
|
"""Test registering a callable object."""
|
||||||
|
|
||||||
|
class MyTool:
|
||||||
|
def __init__(self):
|
||||||
|
self.__name__ = "MyTool"
|
||||||
|
|
||||||
|
def __call__(self, x: int) -> int:
|
||||||
|
return x * 2
|
||||||
|
|
||||||
|
manager = ToolManager()
|
||||||
|
tool = manager.add_tool(MyTool())
|
||||||
|
assert tool.name == "MyTool"
|
||||||
|
assert tool.is_async is False
|
||||||
|
assert tool.parameters["properties"]["x"]["type"] == "integer"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_add_async_callable_object(self):
|
||||||
|
"""Test registering an async callable object."""
|
||||||
|
|
||||||
|
class MyAsyncTool:
|
||||||
|
def __init__(self):
|
||||||
|
self.__name__ = "MyAsyncTool"
|
||||||
|
|
||||||
|
async def __call__(self, x: int) -> int:
|
||||||
|
return x * 2
|
||||||
|
|
||||||
|
manager = ToolManager()
|
||||||
|
tool = manager.add_tool(MyAsyncTool())
|
||||||
|
assert tool.name == "MyAsyncTool"
|
||||||
|
assert tool.is_async is True
|
||||||
|
assert tool.parameters["properties"]["x"]["type"] == "integer"
|
||||||
|
|
||||||
def test_add_invalid_tool(self):
|
def test_add_invalid_tool(self):
|
||||||
manager = ToolManager()
|
manager = ToolManager()
|
||||||
with pytest.raises(AttributeError):
|
with pytest.raises(AttributeError):
|
||||||
@@ -168,6 +201,34 @@ class TestCallTools:
|
|||||||
result = await manager.call_tool("double", {"n": 5})
|
result = await manager.call_tool("double", {"n": 5})
|
||||||
assert result == 10
|
assert result == 10
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_call_object_tool(self):
|
||||||
|
class MyTool:
|
||||||
|
def __init__(self):
|
||||||
|
self.__name__ = "MyTool"
|
||||||
|
|
||||||
|
def __call__(self, x: int) -> int:
|
||||||
|
return x * 2
|
||||||
|
|
||||||
|
manager = ToolManager()
|
||||||
|
tool = manager.add_tool(MyTool())
|
||||||
|
result = await tool.run({"x": 5})
|
||||||
|
assert result == 10
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_call_async_object_tool(self):
|
||||||
|
class MyAsyncTool:
|
||||||
|
def __init__(self):
|
||||||
|
self.__name__ = "MyAsyncTool"
|
||||||
|
|
||||||
|
async def __call__(self, x: int) -> int:
|
||||||
|
return x * 2
|
||||||
|
|
||||||
|
manager = ToolManager()
|
||||||
|
tool = manager.add_tool(MyAsyncTool())
|
||||||
|
result = await tool.run({"x": 5})
|
||||||
|
assert result == 10
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_call_tool_with_default_args(self):
|
async def test_call_tool_with_default_args(self):
|
||||||
def add(a: int, b: int = 1) -> int:
|
def add(a: int, b: int = 1) -> int:
|
||||||
|
|||||||
Reference in New Issue
Block a user