Fix async callable object tools (#568)

This commit is contained in:
Stephan Lensky
2025-05-23 12:21:53 -04:00
committed by GitHub
parent d1876433af
commit f2f4dbdcbd
2 changed files with 72 additions and 1 deletions

View File

@@ -102,6 +102,39 @@ class TestAddTools:
assert "age" in tool.parameters["$defs"]["UserInput"]["properties"]
assert "flag" in tool.parameters["properties"]
def test_add_callable_object(self):
"""Test registering a callable object."""
class MyTool:
def __init__(self):
self.__name__ = "MyTool"
def __call__(self, x: int) -> int:
return x * 2
manager = ToolManager()
tool = manager.add_tool(MyTool())
assert tool.name == "MyTool"
assert tool.is_async is False
assert tool.parameters["properties"]["x"]["type"] == "integer"
@pytest.mark.anyio
async def test_add_async_callable_object(self):
"""Test registering an async callable object."""
class MyAsyncTool:
def __init__(self):
self.__name__ = "MyAsyncTool"
async def __call__(self, x: int) -> int:
return x * 2
manager = ToolManager()
tool = manager.add_tool(MyAsyncTool())
assert tool.name == "MyAsyncTool"
assert tool.is_async is True
assert tool.parameters["properties"]["x"]["type"] == "integer"
def test_add_invalid_tool(self):
manager = ToolManager()
with pytest.raises(AttributeError):
@@ -168,6 +201,34 @@ class TestCallTools:
result = await manager.call_tool("double", {"n": 5})
assert result == 10
@pytest.mark.anyio
async def test_call_object_tool(self):
class MyTool:
def __init__(self):
self.__name__ = "MyTool"
def __call__(self, x: int) -> int:
return x * 2
manager = ToolManager()
tool = manager.add_tool(MyTool())
result = await tool.run({"x": 5})
assert result == 10
@pytest.mark.anyio
async def test_call_async_object_tool(self):
class MyAsyncTool:
def __init__(self):
self.__name__ = "MyAsyncTool"
async def __call__(self, x: int) -> int:
return x * 2
manager = ToolManager()
tool = manager.add_tool(MyAsyncTool())
result = await tool.run({"x": 5})
assert result == 10
@pytest.mark.anyio
async def test_call_tool_with_default_args(self):
def add(a: int, b: int = 1) -> int: