mirror of
https://github.com/aljazceru/Auto-GPT.git
synced 2026-01-03 06:14:32 +01:00
Improve command system; add aliases for commands (#2635)
* Command name supports multiple names * Separate CommandRegistry.commands and .command_aliases * Update test_commands.py * Add __contains__ operator to CommandRegistry * Update error message for unknown commands --------- Co-authored-by: Reinier van der Leer <github@pwuts.nl>
This commit is contained in:
@@ -85,7 +85,7 @@ class Agent:
|
||||
|
||||
def start_interaction_loop(self):
|
||||
# Avoid circular imports
|
||||
from autogpt.app import execute_command, get_command
|
||||
from autogpt.app import execute_command, extract_command
|
||||
|
||||
# Interaction Loop
|
||||
self.cycle_count = 0
|
||||
@@ -161,7 +161,7 @@ class Agent:
|
||||
print_assistant_thoughts(
|
||||
self.ai_name, assistant_reply_json, self.config
|
||||
)
|
||||
command_name, arguments = get_command(
|
||||
command_name, arguments = extract_command(
|
||||
assistant_reply_json, assistant_reply, self.config
|
||||
)
|
||||
if self.config.speak_mode:
|
||||
|
||||
@@ -23,7 +23,7 @@ def is_valid_int(value: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def get_command(
|
||||
def extract_command(
|
||||
assistant_reply_json: Dict, assistant_reply: ChatModelResponse, config: Config
|
||||
):
|
||||
"""Parse the response and return the command name and arguments
|
||||
@@ -78,21 +78,6 @@ def get_command(
|
||||
return "Error:", str(e)
|
||||
|
||||
|
||||
def map_command_synonyms(command_name: str):
|
||||
"""Takes the original command name given by the AI, and checks if the
|
||||
string matches a list of common/known hallucinations
|
||||
"""
|
||||
synonyms = [
|
||||
("write_file", "write_to_file"),
|
||||
("create_file", "write_to_file"),
|
||||
("search", "google"),
|
||||
]
|
||||
for seen_command, actual_command_name in synonyms:
|
||||
if command_name == seen_command:
|
||||
return actual_command_name
|
||||
return command_name
|
||||
|
||||
|
||||
def execute_command(
|
||||
command_name: str,
|
||||
arguments: dict[str, str],
|
||||
@@ -109,28 +94,21 @@ def execute_command(
|
||||
str: The result of the command
|
||||
"""
|
||||
try:
|
||||
cmd = agent.command_registry.commands.get(command_name)
|
||||
# Execute a native command with the same name or alias, if it exists
|
||||
if command := agent.command_registry.get_command(command_name):
|
||||
return command(**arguments, agent=agent)
|
||||
|
||||
# If the command is found, call it with the provided arguments
|
||||
if cmd:
|
||||
return cmd(**arguments, agent=agent)
|
||||
|
||||
# TODO: Remove commands below after they are moved to the command registry.
|
||||
command_name = map_command_synonyms(command_name.lower())
|
||||
|
||||
# TODO: Change these to take in a file rather than pasted code, if
|
||||
# non-file is given, return instructions "Input should be a python
|
||||
# filepath, write your code to file and try again
|
||||
# Handle non-native commands (e.g. from plugins)
|
||||
for command in agent.ai_config.prompt_generator.commands:
|
||||
if (
|
||||
command_name == command["label"].lower()
|
||||
or command_name == command["name"].lower()
|
||||
):
|
||||
return command["function"](**arguments)
|
||||
return (
|
||||
f"Unknown command '{command_name}'. Please refer to the 'COMMANDS'"
|
||||
" list for available commands and only respond in the specified JSON"
|
||||
" format."
|
||||
|
||||
raise RuntimeError(
|
||||
f"Cannot execute '{command_name}': unknown command."
|
||||
" Do not try to use this command again."
|
||||
)
|
||||
except Exception as e:
|
||||
return f"Error: {str(e)}"
|
||||
|
||||
@@ -20,6 +20,7 @@ def command(
|
||||
parameters: dict[str, CommandParameterSpec],
|
||||
enabled: bool | Callable[[Config], bool] = True,
|
||||
disabled_reason: Optional[str] = None,
|
||||
aliases: list[str] = [],
|
||||
) -> Callable[..., Any]:
|
||||
"""The command decorator is used to create Command objects from ordinary functions."""
|
||||
|
||||
@@ -40,6 +41,7 @@ def command(
|
||||
parameters=typed_parameters,
|
||||
enabled=enabled,
|
||||
disabled_reason=disabled_reason,
|
||||
aliases=aliases,
|
||||
)
|
||||
|
||||
@functools.wraps(func)
|
||||
|
||||
@@ -189,6 +189,7 @@ def ingest_file(
|
||||
"required": True,
|
||||
},
|
||||
},
|
||||
aliases=["write_file", "create_file"],
|
||||
)
|
||||
def write_to_file(filename: str, text: str, agent: Agent) -> str:
|
||||
"""Write text to a file
|
||||
|
||||
@@ -23,6 +23,7 @@ DUCKDUCKGO_MAX_ATTEMPTS = 3
|
||||
"required": True,
|
||||
}
|
||||
},
|
||||
aliases=["search"],
|
||||
)
|
||||
def web_search(query: str, agent: Agent, num_results: int = 8) -> str:
|
||||
"""Return the results of a Google search
|
||||
@@ -67,6 +68,7 @@ def web_search(query: str, agent: Agent, num_results: int = 8) -> str:
|
||||
lambda config: bool(config.google_api_key)
|
||||
and bool(config.google_custom_search_engine_id),
|
||||
"Configure google_api_key and custom_search_engine_id.",
|
||||
aliases=["search"],
|
||||
)
|
||||
def google(query: str, agent: Agent, num_results: int = 8) -> str | list[str]:
|
||||
"""Return the results of a Google search using the official Google API
|
||||
@@ -124,7 +126,7 @@ def google(query: str, agent: Agent, num_results: int = 8) -> str | list[str]:
|
||||
|
||||
def safe_google_results(results: str | list) -> str:
|
||||
"""
|
||||
Return the results of a google search in a safe format.
|
||||
Return the results of a Google search in a safe format.
|
||||
|
||||
Args:
|
||||
results (str | list): The search results.
|
||||
|
||||
@@ -154,7 +154,7 @@ def run_auto_gpt(
|
||||
incompatible_commands.append(command)
|
||||
|
||||
for command in incompatible_commands:
|
||||
command_registry.unregister(command.name)
|
||||
command_registry.unregister(command)
|
||||
logger.debug(
|
||||
f"Unregistering incompatible command: {command.name}, "
|
||||
f"reason - {command.disabled_reason or 'Disabled by current config.'}"
|
||||
|
||||
@@ -22,6 +22,7 @@ class Command:
|
||||
parameters: list[CommandParameter],
|
||||
enabled: bool | Callable[[Config], bool] = True,
|
||||
disabled_reason: Optional[str] = None,
|
||||
aliases: list[str] = [],
|
||||
):
|
||||
self.name = name
|
||||
self.description = description
|
||||
@@ -29,6 +30,7 @@ class Command:
|
||||
self.parameters = parameters
|
||||
self.enabled = enabled
|
||||
self.disabled_reason = disabled_reason
|
||||
self.aliases = aliases
|
||||
|
||||
def __call__(self, *args, **kwargs) -> Any:
|
||||
if hasattr(kwargs, "config") and callable(self.enabled):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import importlib
|
||||
import inspect
|
||||
from typing import Any, Callable
|
||||
from typing import Any
|
||||
|
||||
from autogpt.command_decorator import AUTO_GPT_COMMAND_IDENTIFIER
|
||||
from autogpt.logs import logger
|
||||
@@ -15,10 +15,11 @@ class CommandRegistry:
|
||||
directory.
|
||||
"""
|
||||
|
||||
commands: dict[str, Command]
|
||||
commands: dict[str, Command] = {}
|
||||
commands_aliases: dict[str, Command] = {}
|
||||
|
||||
def __init__(self):
|
||||
self.commands = {}
|
||||
def __contains__(self, command_name: str):
|
||||
return command_name in self.commands or command_name in self.commands_aliases
|
||||
|
||||
def _import_module(self, module_name: str) -> Any:
|
||||
return importlib.import_module(module_name)
|
||||
@@ -33,11 +34,21 @@ class CommandRegistry:
|
||||
)
|
||||
self.commands[cmd.name] = cmd
|
||||
|
||||
def unregister(self, command_name: str):
|
||||
if command_name in self.commands:
|
||||
del self.commands[command_name]
|
||||
if cmd.name in self.commands_aliases:
|
||||
logger.warn(
|
||||
f"Command '{cmd.name}' will overwrite alias with the same name of "
|
||||
f"'{self.commands_aliases[cmd.name]}'!"
|
||||
)
|
||||
for alias in cmd.aliases:
|
||||
self.commands_aliases[alias] = cmd
|
||||
|
||||
def unregister(self, command: Command) -> None:
|
||||
if command.name in self.commands:
|
||||
del self.commands[command.name]
|
||||
for alias in command.aliases:
|
||||
del self.commands_aliases[alias]
|
||||
else:
|
||||
raise KeyError(f"Command '{command_name}' not found in registry.")
|
||||
raise KeyError(f"Command '{command.name}' not found in registry.")
|
||||
|
||||
def reload_commands(self) -> None:
|
||||
"""Reloads all loaded command plugins."""
|
||||
@@ -48,14 +59,17 @@ class CommandRegistry:
|
||||
if hasattr(reloaded_module, "register"):
|
||||
reloaded_module.register(self)
|
||||
|
||||
def get_command(self, name: str) -> Callable[..., Any]:
|
||||
return self.commands[name]
|
||||
def get_command(self, name: str) -> Command | None:
|
||||
if name in self.commands:
|
||||
return self.commands[name]
|
||||
|
||||
if name in self.commands_aliases:
|
||||
return self.commands_aliases[name]
|
||||
|
||||
def call(self, command_name: str, **kwargs) -> Any:
|
||||
if command_name not in self.commands:
|
||||
raise KeyError(f"Command '{command_name}' not found in registry.")
|
||||
command = self.commands[command_name]
|
||||
return command(**kwargs)
|
||||
if command := self.get_command(command_name):
|
||||
return command(**kwargs)
|
||||
raise KeyError(f"Command '{command_name}' not found in registry")
|
||||
|
||||
def command_prompt(self) -> str:
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
""" A module for generating custom prompt strings."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypedDict
|
||||
|
||||
from autogpt.config import Config
|
||||
from autogpt.json_utils.utilities import llm_response_schema
|
||||
@@ -15,19 +17,33 @@ class PromptGenerator:
|
||||
resources, and performance evaluations.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""
|
||||
Initialize the PromptGenerator object with empty lists of constraints,
|
||||
commands, resources, and performance evaluations.
|
||||
"""
|
||||
class Command(TypedDict):
|
||||
label: str
|
||||
name: str
|
||||
params: dict[str, str]
|
||||
function: Optional[Callable]
|
||||
|
||||
constraints: list[str]
|
||||
commands: list[Command]
|
||||
resources: list[str]
|
||||
performance_evaluation: list[str]
|
||||
command_registry: CommandRegistry | None
|
||||
|
||||
# TODO: replace with AIConfig
|
||||
name: str
|
||||
role: str
|
||||
goals: list[str]
|
||||
|
||||
def __init__(self):
|
||||
self.constraints = []
|
||||
self.commands = []
|
||||
self.resources = []
|
||||
self.performance_evaluation = []
|
||||
self.goals = []
|
||||
self.command_registry: CommandRegistry | None = None
|
||||
self.command_registry = None
|
||||
|
||||
self.name = "Bob"
|
||||
self.role = "AI"
|
||||
self.goals = []
|
||||
|
||||
def add_constraint(self, constraint: str) -> None:
|
||||
"""
|
||||
@@ -42,29 +58,29 @@ class PromptGenerator:
|
||||
self,
|
||||
command_label: str,
|
||||
command_name: str,
|
||||
args=None,
|
||||
params: dict[str, str] = {},
|
||||
function: Optional[Callable] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Add a command to the commands list with a label, name, and optional arguments.
|
||||
|
||||
*Should only be used by plugins.* Native commands should be added
|
||||
directly to the CommandRegistry.
|
||||
|
||||
Args:
|
||||
command_label (str): The label of the command.
|
||||
command_name (str): The name of the command.
|
||||
args (dict, optional): A dictionary containing argument names and their
|
||||
params (dict, optional): A dictionary containing argument names and their
|
||||
values. Defaults to None.
|
||||
function (callable, optional): A callable function to be called when
|
||||
the command is executed. Defaults to None.
|
||||
"""
|
||||
if args is None:
|
||||
args = {}
|
||||
command_params = {name: type for name, type in params.items()}
|
||||
|
||||
command_args = {arg_key: arg_value for arg_key, arg_value in args.items()}
|
||||
|
||||
command = {
|
||||
command: PromptGenerator.Command = {
|
||||
"label": command_label,
|
||||
"name": command_name,
|
||||
"args": command_args,
|
||||
"params": command_params,
|
||||
"function": function,
|
||||
}
|
||||
|
||||
@@ -80,10 +96,10 @@ class PromptGenerator:
|
||||
Returns:
|
||||
str: The formatted command string.
|
||||
"""
|
||||
args_string = ", ".join(
|
||||
f'"{key}": "{value}"' for key, value in command["args"].items()
|
||||
params_string = ", ".join(
|
||||
f'"{key}": "{value}"' for key, value in command["params"].items()
|
||||
)
|
||||
return f'{command["label"]}: "{command["name"]}", args: {args_string}'
|
||||
return f'{command["label"]}: "{command["name"]}", params: {params_string}'
|
||||
|
||||
def add_resource(self, resource: str) -> None:
|
||||
"""
|
||||
|
||||
@@ -14,196 +14,218 @@ PARAMETERS = [
|
||||
]
|
||||
|
||||
|
||||
class TestCommand:
|
||||
"""Test cases for the Command class."""
|
||||
|
||||
@staticmethod
|
||||
def example_command_method(arg1: int, arg2: str) -> str:
|
||||
"""Example function for testing the Command class."""
|
||||
# This function is static because it is not used by any other test cases.
|
||||
return f"{arg1} - {arg2}"
|
||||
|
||||
def test_command_creation(self):
|
||||
"""Test that a Command object can be created with the correct attributes."""
|
||||
cmd = Command(
|
||||
name="example",
|
||||
description="Example command",
|
||||
method=self.example_command_method,
|
||||
parameters=PARAMETERS,
|
||||
)
|
||||
|
||||
assert cmd.name == "example"
|
||||
assert cmd.description == "Example command"
|
||||
assert cmd.method == self.example_command_method
|
||||
assert (
|
||||
str(cmd)
|
||||
== "example: Example command, params: (arg1: int, arg2: Optional[str])"
|
||||
)
|
||||
|
||||
def test_command_call(self):
|
||||
"""Test that Command(*args) calls and returns the result of method(*args)."""
|
||||
# Create a Command object with the example_command_method.
|
||||
cmd = Command(
|
||||
name="example",
|
||||
description="Example command",
|
||||
method=self.example_command_method,
|
||||
parameters=[
|
||||
CommandParameter(
|
||||
name="prompt",
|
||||
type="string",
|
||||
description="The prompt used to generate the image",
|
||||
required=True,
|
||||
),
|
||||
],
|
||||
)
|
||||
result = cmd(arg1=1, arg2="test")
|
||||
assert result == "1 - test"
|
||||
|
||||
def test_command_call_with_invalid_arguments(self):
|
||||
"""Test that calling a Command object with invalid arguments raises a TypeError."""
|
||||
cmd = Command(
|
||||
name="example",
|
||||
description="Example command",
|
||||
method=self.example_command_method,
|
||||
parameters=PARAMETERS,
|
||||
)
|
||||
with pytest.raises(TypeError):
|
||||
cmd(arg1="invalid", does_not_exist="test")
|
||||
def example_command_method(arg1: int, arg2: str) -> str:
|
||||
"""Example function for testing the Command class."""
|
||||
# This function is static because it is not used by any other test cases.
|
||||
return f"{arg1} - {arg2}"
|
||||
|
||||
|
||||
class TestCommandRegistry:
|
||||
@staticmethod
|
||||
def example_command_method(arg1: int, arg2: str) -> str:
|
||||
return f"{arg1} - {arg2}"
|
||||
def test_command_creation():
|
||||
"""Test that a Command object can be created with the correct attributes."""
|
||||
cmd = Command(
|
||||
name="example",
|
||||
description="Example command",
|
||||
method=example_command_method,
|
||||
parameters=PARAMETERS,
|
||||
)
|
||||
|
||||
def test_register_command(self):
|
||||
"""Test that a command can be registered to the registry."""
|
||||
registry = CommandRegistry()
|
||||
cmd = Command(
|
||||
name="example",
|
||||
description="Example command",
|
||||
method=self.example_command_method,
|
||||
parameters=PARAMETERS,
|
||||
)
|
||||
assert cmd.name == "example"
|
||||
assert cmd.description == "Example command"
|
||||
assert cmd.method == example_command_method
|
||||
assert (
|
||||
str(cmd) == "example: Example command, params: (arg1: int, arg2: Optional[str])"
|
||||
)
|
||||
|
||||
registry.register(cmd)
|
||||
|
||||
assert cmd.name in registry.commands
|
||||
assert registry.commands[cmd.name] == cmd
|
||||
@pytest.fixture
|
||||
def example_command():
|
||||
yield Command(
|
||||
name="example",
|
||||
description="Example command",
|
||||
method=example_command_method,
|
||||
parameters=PARAMETERS,
|
||||
)
|
||||
|
||||
def test_unregister_command(self):
|
||||
"""Test that a command can be unregistered from the registry."""
|
||||
registry = CommandRegistry()
|
||||
cmd = Command(
|
||||
name="example",
|
||||
description="Example command",
|
||||
method=self.example_command_method,
|
||||
parameters=PARAMETERS,
|
||||
)
|
||||
|
||||
registry.register(cmd)
|
||||
registry.unregister(cmd.name)
|
||||
def test_command_call(example_command: Command):
|
||||
"""Test that Command(*args) calls and returns the result of method(*args)."""
|
||||
result = example_command(arg1=1, arg2="test")
|
||||
assert result == "1 - test"
|
||||
|
||||
assert cmd.name not in registry.commands
|
||||
|
||||
def test_get_command(self):
|
||||
"""Test that a command can be retrieved from the registry."""
|
||||
registry = CommandRegistry()
|
||||
cmd = Command(
|
||||
name="example",
|
||||
description="Example command",
|
||||
method=self.example_command_method,
|
||||
parameters=PARAMETERS,
|
||||
)
|
||||
def test_command_call_with_invalid_arguments(example_command: Command):
|
||||
"""Test that calling a Command object with invalid arguments raises a TypeError."""
|
||||
with pytest.raises(TypeError):
|
||||
example_command(arg1="invalid", does_not_exist="test")
|
||||
|
||||
registry.register(cmd)
|
||||
retrieved_cmd = registry.get_command(cmd.name)
|
||||
|
||||
assert retrieved_cmd == cmd
|
||||
def test_register_command(example_command: Command):
|
||||
"""Test that a command can be registered to the registry."""
|
||||
registry = CommandRegistry()
|
||||
|
||||
def test_get_nonexistent_command(self):
|
||||
"""Test that attempting to get a nonexistent command raises a KeyError."""
|
||||
registry = CommandRegistry()
|
||||
registry.register(example_command)
|
||||
|
||||
with pytest.raises(KeyError):
|
||||
registry.get_command("nonexistent_command")
|
||||
assert registry.get_command(example_command.name) == example_command
|
||||
assert len(registry.commands) == 1
|
||||
|
||||
def test_call_command(self):
|
||||
"""Test that a command can be called through the registry."""
|
||||
registry = CommandRegistry()
|
||||
cmd = Command(
|
||||
name="example",
|
||||
description="Example command",
|
||||
method=self.example_command_method,
|
||||
parameters=PARAMETERS,
|
||||
)
|
||||
|
||||
registry.register(cmd)
|
||||
result = registry.call("example", arg1=1, arg2="test")
|
||||
def test_unregister_command(example_command: Command):
|
||||
"""Test that a command can be unregistered from the registry."""
|
||||
registry = CommandRegistry()
|
||||
|
||||
assert result == "1 - test"
|
||||
registry.register(example_command)
|
||||
registry.unregister(example_command)
|
||||
|
||||
def test_call_nonexistent_command(self):
|
||||
"""Test that attempting to call a nonexistent command raises a KeyError."""
|
||||
registry = CommandRegistry()
|
||||
assert len(registry.commands) == 0
|
||||
assert example_command.name not in registry
|
||||
|
||||
with pytest.raises(KeyError):
|
||||
registry.call("nonexistent_command", arg1=1, arg2="test")
|
||||
|
||||
def test_get_command_prompt(self):
|
||||
"""Test that the command prompt is correctly formatted."""
|
||||
registry = CommandRegistry()
|
||||
cmd = Command(
|
||||
name="example",
|
||||
description="Example command",
|
||||
method=self.example_command_method,
|
||||
parameters=PARAMETERS,
|
||||
)
|
||||
@pytest.fixture
|
||||
def example_command_with_aliases(example_command: Command):
|
||||
example_command.aliases = ["example_alias", "example_alias_2"]
|
||||
return example_command
|
||||
|
||||
registry.register(cmd)
|
||||
command_prompt = registry.command_prompt()
|
||||
|
||||
assert f"(arg1: int, arg2: Optional[str])" in command_prompt
|
||||
def test_register_command_aliases(example_command_with_aliases: Command):
|
||||
"""Test that a command can be registered to the registry."""
|
||||
registry = CommandRegistry()
|
||||
command = example_command_with_aliases
|
||||
|
||||
def test_import_mock_commands_module(self):
|
||||
"""Test that the registry can import a module with mock command plugins."""
|
||||
registry = CommandRegistry()
|
||||
mock_commands_module = "tests.mocks.mock_commands"
|
||||
registry.register(command)
|
||||
|
||||
registry.import_commands(mock_commands_module)
|
||||
assert command.name in registry
|
||||
assert registry.get_command(command.name) == command
|
||||
for alias in command.aliases:
|
||||
assert registry.get_command(alias) == command
|
||||
assert len(registry.commands) == 1
|
||||
|
||||
assert "function_based" in registry.commands
|
||||
assert registry.commands["function_based"].name == "function_based"
|
||||
assert (
|
||||
registry.commands["function_based"].description
|
||||
== "Function-based test command"
|
||||
)
|
||||
|
||||
def test_import_temp_command_file_module(self, tmp_path):
|
||||
"""
|
||||
Test that the registry can import a command plugins module from a temp file.
|
||||
Args:
|
||||
tmp_path (pathlib.Path): Path to a temporary directory.
|
||||
"""
|
||||
registry = CommandRegistry()
|
||||
def test_unregister_command_aliases(example_command_with_aliases: Command):
|
||||
"""Test that a command can be unregistered from the registry."""
|
||||
registry = CommandRegistry()
|
||||
command = example_command_with_aliases
|
||||
|
||||
# Create a temp command file
|
||||
src = Path(os.getcwd()) / "tests/mocks/mock_commands.py"
|
||||
temp_commands_file = tmp_path / "mock_commands.py"
|
||||
shutil.copyfile(src, temp_commands_file)
|
||||
registry.register(command)
|
||||
registry.unregister(command)
|
||||
|
||||
# Add the temp directory to sys.path to make the module importable
|
||||
sys.path.append(str(tmp_path))
|
||||
assert len(registry.commands) == 0
|
||||
assert command.name not in registry
|
||||
for alias in command.aliases:
|
||||
assert alias not in registry
|
||||
|
||||
temp_commands_module = "mock_commands"
|
||||
registry.import_commands(temp_commands_module)
|
||||
|
||||
# Remove the temp directory from sys.path
|
||||
sys.path.remove(str(tmp_path))
|
||||
def test_command_in_registry(example_command_with_aliases: Command):
|
||||
"""Test that `command_name in registry` works."""
|
||||
registry = CommandRegistry()
|
||||
command = example_command_with_aliases
|
||||
|
||||
assert "function_based" in registry.commands
|
||||
assert registry.commands["function_based"].name == "function_based"
|
||||
assert (
|
||||
registry.commands["function_based"].description
|
||||
== "Function-based test command"
|
||||
)
|
||||
assert command.name not in registry
|
||||
assert "nonexistent_command" not in registry
|
||||
|
||||
registry.register(command)
|
||||
|
||||
assert command.name in registry
|
||||
assert "nonexistent_command" not in registry
|
||||
for alias in command.aliases:
|
||||
assert alias in registry
|
||||
|
||||
|
||||
def test_get_command(example_command: Command):
|
||||
"""Test that a command can be retrieved from the registry."""
|
||||
registry = CommandRegistry()
|
||||
|
||||
registry.register(example_command)
|
||||
retrieved_cmd = registry.get_command(example_command.name)
|
||||
|
||||
assert retrieved_cmd == example_command
|
||||
|
||||
|
||||
def test_get_nonexistent_command():
|
||||
"""Test that attempting to get a nonexistent command raises a KeyError."""
|
||||
registry = CommandRegistry()
|
||||
|
||||
assert registry.get_command("nonexistent_command") is None
|
||||
assert "nonexistent_command" not in registry
|
||||
|
||||
|
||||
def test_call_command():
|
||||
"""Test that a command can be called through the registry."""
|
||||
registry = CommandRegistry()
|
||||
cmd = Command(
|
||||
name="example",
|
||||
description="Example command",
|
||||
method=example_command_method,
|
||||
parameters=PARAMETERS,
|
||||
)
|
||||
|
||||
registry.register(cmd)
|
||||
result = registry.call("example", arg1=1, arg2="test")
|
||||
|
||||
assert result == "1 - test"
|
||||
|
||||
|
||||
def test_call_nonexistent_command():
|
||||
"""Test that attempting to call a nonexistent command raises a KeyError."""
|
||||
registry = CommandRegistry()
|
||||
|
||||
with pytest.raises(KeyError):
|
||||
registry.call("nonexistent_command", arg1=1, arg2="test")
|
||||
|
||||
|
||||
def test_get_command_prompt():
|
||||
"""Test that the command prompt is correctly formatted."""
|
||||
registry = CommandRegistry()
|
||||
cmd = Command(
|
||||
name="example",
|
||||
description="Example command",
|
||||
method=example_command_method,
|
||||
parameters=PARAMETERS,
|
||||
)
|
||||
|
||||
registry.register(cmd)
|
||||
command_prompt = registry.command_prompt()
|
||||
|
||||
assert f"(arg1: int, arg2: Optional[str])" in command_prompt
|
||||
|
||||
|
||||
def test_import_mock_commands_module():
|
||||
"""Test that the registry can import a module with mock command plugins."""
|
||||
registry = CommandRegistry()
|
||||
mock_commands_module = "tests.mocks.mock_commands"
|
||||
|
||||
registry.import_commands(mock_commands_module)
|
||||
|
||||
assert "function_based" in registry
|
||||
assert registry.commands["function_based"].name == "function_based"
|
||||
assert (
|
||||
registry.commands["function_based"].description == "Function-based test command"
|
||||
)
|
||||
|
||||
|
||||
def test_import_temp_command_file_module(tmp_path: Path):
|
||||
"""
|
||||
Test that the registry can import a command plugins module from a temp file.
|
||||
Args:
|
||||
tmp_path (pathlib.Path): Path to a temporary directory.
|
||||
"""
|
||||
registry = CommandRegistry()
|
||||
|
||||
# Create a temp command file
|
||||
src = Path(os.getcwd()) / "tests/mocks/mock_commands.py"
|
||||
temp_commands_file = tmp_path / "mock_commands.py"
|
||||
shutil.copyfile(src, temp_commands_file)
|
||||
|
||||
# Add the temp directory to sys.path to make the module importable
|
||||
sys.path.append(str(tmp_path))
|
||||
|
||||
temp_commands_module = "mock_commands"
|
||||
registry.import_commands(temp_commands_module)
|
||||
|
||||
# Remove the temp directory from sys.path
|
||||
sys.path.remove(str(tmp_path))
|
||||
|
||||
assert "function_based" in registry
|
||||
assert registry.commands["function_based"].name == "function_based"
|
||||
assert (
|
||||
registry.commands["function_based"].description == "Function-based test command"
|
||||
)
|
||||
|
||||
@@ -8,17 +8,16 @@ def check_plan():
|
||||
|
||||
def test_execute_command_plugin(agent: Agent):
|
||||
"""Test that executing a command that came from a plugin works as expected"""
|
||||
command_name = "check_plan"
|
||||
agent.ai_config.prompt_generator.add_command(
|
||||
"check_plan",
|
||||
command_name,
|
||||
"Read the plan.md with the next goals to achieve",
|
||||
{},
|
||||
check_plan,
|
||||
)
|
||||
command_name = "check_plan"
|
||||
arguments = {}
|
||||
command_result = execute_command(
|
||||
command_name=command_name,
|
||||
arguments=arguments,
|
||||
arguments={},
|
||||
agent=agent,
|
||||
)
|
||||
assert command_result == "hi"
|
||||
|
||||
@@ -17,13 +17,13 @@ def test_add_command():
|
||||
"""
|
||||
command_label = "Command Label"
|
||||
command_name = "command_name"
|
||||
args = {"arg1": "value1", "arg2": "value2"}
|
||||
params = {"arg1": "value1", "arg2": "value2"}
|
||||
generator = PromptGenerator()
|
||||
generator.add_command(command_label, command_name, args)
|
||||
generator.add_command(command_label, command_name, params)
|
||||
command = {
|
||||
"label": command_label,
|
||||
"name": command_name,
|
||||
"args": args,
|
||||
"params": params,
|
||||
"function": None,
|
||||
}
|
||||
assert command in generator.commands
|
||||
@@ -62,12 +62,12 @@ def test_generate_prompt_string(config):
|
||||
{
|
||||
"label": "Command1",
|
||||
"name": "command_name1",
|
||||
"args": {"arg1": "value1"},
|
||||
"params": {"arg1": "value1"},
|
||||
},
|
||||
{
|
||||
"label": "Command2",
|
||||
"name": "command_name2",
|
||||
"args": {},
|
||||
"params": {},
|
||||
},
|
||||
]
|
||||
resources = ["Resource1", "Resource2"]
|
||||
@@ -78,7 +78,7 @@ def test_generate_prompt_string(config):
|
||||
for constraint in constraints:
|
||||
generator.add_constraint(constraint)
|
||||
for command in commands:
|
||||
generator.add_command(command["label"], command["name"], command["args"])
|
||||
generator.add_command(command["label"], command["name"], command["params"])
|
||||
for resource in resources:
|
||||
generator.add_resource(resource)
|
||||
for evaluation in evaluations:
|
||||
@@ -93,58 +93,7 @@ def test_generate_prompt_string(config):
|
||||
assert constraint in prompt_string
|
||||
for command in commands:
|
||||
assert command["name"] in prompt_string
|
||||
for key, value in command["args"].items():
|
||||
assert f'"{key}": "{value}"' in prompt_string
|
||||
for resource in resources:
|
||||
assert resource in prompt_string
|
||||
for evaluation in evaluations:
|
||||
assert evaluation in prompt_string
|
||||
|
||||
|
||||
def test_generate_prompt_string(config):
|
||||
"""
|
||||
Test if the generate_prompt_string() method generates a prompt string with all the added
|
||||
constraints, commands, resources, and evaluations.
|
||||
"""
|
||||
|
||||
# Define the test data
|
||||
constraints = ["Constraint1", "Constraint2"]
|
||||
commands = [
|
||||
{
|
||||
"label": "Command1",
|
||||
"name": "command_name1",
|
||||
"args": {"arg1": "value1"},
|
||||
},
|
||||
{
|
||||
"label": "Command2",
|
||||
"name": "command_name2",
|
||||
"args": {},
|
||||
},
|
||||
]
|
||||
resources = ["Resource1", "Resource2"]
|
||||
evaluations = ["Evaluation1", "Evaluation2"]
|
||||
|
||||
# Add test data to the generator
|
||||
generator = PromptGenerator()
|
||||
for constraint in constraints:
|
||||
generator.add_constraint(constraint)
|
||||
for command in commands:
|
||||
generator.add_command(command["label"], command["name"], command["args"])
|
||||
for resource in resources:
|
||||
generator.add_resource(resource)
|
||||
for evaluation in evaluations:
|
||||
generator.add_performance_evaluation(evaluation)
|
||||
|
||||
# Generate the prompt string and verify its correctness
|
||||
prompt_string = generator.generate_prompt_string(config)
|
||||
assert prompt_string is not None
|
||||
|
||||
# Check if all constraints, commands, resources, and evaluations are present in the prompt string
|
||||
for constraint in constraints:
|
||||
assert constraint in prompt_string
|
||||
for command in commands:
|
||||
assert command["name"] in prompt_string
|
||||
for key, value in command["args"].items():
|
||||
for key, value in command["params"].items():
|
||||
assert f'"{key}": "{value}"' in prompt_string
|
||||
for resource in resources:
|
||||
assert resource in prompt_string
|
||||
|
||||
Reference in New Issue
Block a user