Fix async callable object tools (#568)

This commit is contained in:
Stephan Lensky
2025-05-23 12:21:53 -04:00
committed by GitHub
parent d1876433af
commit f2f4dbdcbd
2 changed files with 72 additions and 1 deletions

View File

@@ -1,5 +1,6 @@
from __future__ import annotations as _annotations from __future__ import annotations as _annotations
import functools
import inspect import inspect
from collections.abc import Callable from collections.abc import Callable
from typing import TYPE_CHECKING, Any, get_origin from typing import TYPE_CHECKING, Any, get_origin
@@ -53,7 +54,7 @@ class Tool(BaseModel):
raise ValueError("You must provide a name for lambda functions") raise ValueError("You must provide a name for lambda functions")
func_doc = description or fn.__doc__ or "" func_doc = description or fn.__doc__ or ""
is_async = inspect.iscoroutinefunction(fn) is_async = _is_async_callable(fn)
if context_kwarg is None: if context_kwarg is None:
sig = inspect.signature(fn) sig = inspect.signature(fn)
@@ -98,3 +99,12 @@ class Tool(BaseModel):
) )
except Exception as e: except Exception as e:
raise ToolError(f"Error executing tool {self.name}: {e}") from e raise ToolError(f"Error executing tool {self.name}: {e}") from e
def _is_async_callable(obj: Any) -> bool:
while isinstance(obj, functools.partial):
obj = obj.func
return inspect.iscoroutinefunction(obj) or (
callable(obj) and inspect.iscoroutinefunction(getattr(obj, "__call__", None))
)

View File

@@ -102,6 +102,39 @@ class TestAddTools:
assert "age" in tool.parameters["$defs"]["UserInput"]["properties"] assert "age" in tool.parameters["$defs"]["UserInput"]["properties"]
assert "flag" in tool.parameters["properties"] assert "flag" in tool.parameters["properties"]
def test_add_callable_object(self):
"""Test registering a callable object."""
class MyTool:
def __init__(self):
self.__name__ = "MyTool"
def __call__(self, x: int) -> int:
return x * 2
manager = ToolManager()
tool = manager.add_tool(MyTool())
assert tool.name == "MyTool"
assert tool.is_async is False
assert tool.parameters["properties"]["x"]["type"] == "integer"
@pytest.mark.anyio
async def test_add_async_callable_object(self):
"""Test registering an async callable object."""
class MyAsyncTool:
def __init__(self):
self.__name__ = "MyAsyncTool"
async def __call__(self, x: int) -> int:
return x * 2
manager = ToolManager()
tool = manager.add_tool(MyAsyncTool())
assert tool.name == "MyAsyncTool"
assert tool.is_async is True
assert tool.parameters["properties"]["x"]["type"] == "integer"
def test_add_invalid_tool(self): def test_add_invalid_tool(self):
manager = ToolManager() manager = ToolManager()
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
@@ -168,6 +201,34 @@ class TestCallTools:
result = await manager.call_tool("double", {"n": 5}) result = await manager.call_tool("double", {"n": 5})
assert result == 10 assert result == 10
@pytest.mark.anyio
async def test_call_object_tool(self):
class MyTool:
def __init__(self):
self.__name__ = "MyTool"
def __call__(self, x: int) -> int:
return x * 2
manager = ToolManager()
tool = manager.add_tool(MyTool())
result = await tool.run({"x": 5})
assert result == 10
@pytest.mark.anyio
async def test_call_async_object_tool(self):
class MyAsyncTool:
def __init__(self):
self.__name__ = "MyAsyncTool"
async def __call__(self, x: int) -> int:
return x * 2
manager = ToolManager()
tool = manager.add_tool(MyAsyncTool())
result = await tool.run({"x": 5})
assert result == 10
@pytest.mark.anyio @pytest.mark.anyio
async def test_call_tool_with_default_args(self): async def test_call_tool_with_default_args(self):
def add(a: int, b: int = 1) -> int: def add(a: int, b: int = 1) -> int: