mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-18 22:44:20 +01:00
Fix async callable object tools (#568)
This commit is contained in:
@@ -102,6 +102,39 @@ class TestAddTools:
|
||||
assert "age" in tool.parameters["$defs"]["UserInput"]["properties"]
|
||||
assert "flag" in tool.parameters["properties"]
|
||||
|
||||
def test_add_callable_object(self):
|
||||
"""Test registering a callable object."""
|
||||
|
||||
class MyTool:
|
||||
def __init__(self):
|
||||
self.__name__ = "MyTool"
|
||||
|
||||
def __call__(self, x: int) -> int:
|
||||
return x * 2
|
||||
|
||||
manager = ToolManager()
|
||||
tool = manager.add_tool(MyTool())
|
||||
assert tool.name == "MyTool"
|
||||
assert tool.is_async is False
|
||||
assert tool.parameters["properties"]["x"]["type"] == "integer"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_add_async_callable_object(self):
|
||||
"""Test registering an async callable object."""
|
||||
|
||||
class MyAsyncTool:
|
||||
def __init__(self):
|
||||
self.__name__ = "MyAsyncTool"
|
||||
|
||||
async def __call__(self, x: int) -> int:
|
||||
return x * 2
|
||||
|
||||
manager = ToolManager()
|
||||
tool = manager.add_tool(MyAsyncTool())
|
||||
assert tool.name == "MyAsyncTool"
|
||||
assert tool.is_async is True
|
||||
assert tool.parameters["properties"]["x"]["type"] == "integer"
|
||||
|
||||
def test_add_invalid_tool(self):
|
||||
manager = ToolManager()
|
||||
with pytest.raises(AttributeError):
|
||||
@@ -168,6 +201,34 @@ class TestCallTools:
|
||||
result = await manager.call_tool("double", {"n": 5})
|
||||
assert result == 10
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_call_object_tool(self):
|
||||
class MyTool:
|
||||
def __init__(self):
|
||||
self.__name__ = "MyTool"
|
||||
|
||||
def __call__(self, x: int) -> int:
|
||||
return x * 2
|
||||
|
||||
manager = ToolManager()
|
||||
tool = manager.add_tool(MyTool())
|
||||
result = await tool.run({"x": 5})
|
||||
assert result == 10
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_call_async_object_tool(self):
|
||||
class MyAsyncTool:
|
||||
def __init__(self):
|
||||
self.__name__ = "MyAsyncTool"
|
||||
|
||||
async def __call__(self, x: int) -> int:
|
||||
return x * 2
|
||||
|
||||
manager = ToolManager()
|
||||
tool = manager.add_tool(MyAsyncTool())
|
||||
result = await tool.run({"x": 5})
|
||||
assert result == 10
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_call_tool_with_default_args(self):
|
||||
def add(a: int, b: int = 1) -> int:
|
||||
|
||||
Reference in New Issue
Block a user