diff --git a/autogpt/agents/agent.py b/autogpt/agents/agent.py index 29686c09..16e5f163 100644 --- a/autogpt/agents/agent.py +++ b/autogpt/agents/agent.py @@ -347,11 +347,8 @@ def execute_command( raise CommandExecutionError(str(e)) # Handle non-native commands (e.g. from plugins) - for command in agent.prompt_generator.commands: - if ( - command_name == command.label.lower() - or command_name == command.name.lower() - ): + for name, command in agent.prompt_generator.commands.items(): + if command_name == name or command_name.lower() == command.description.lower(): try: return command.function(**arguments) except AgentException: diff --git a/autogpt/agents/base.py b/autogpt/agents/base.py index bef16cb0..4b682ff7 100644 --- a/autogpt/agents/base.py +++ b/autogpt/agents/base.py @@ -322,9 +322,7 @@ class BaseAgent(metaclass=ABCMeta): for i, plugin in enumerate(self.config.plugins): if not plugin.can_handle_on_planning(): continue - plugin_response = plugin.on_planning( - self.ai_config.prompt_generator, prompt.raw() - ) + plugin_response = plugin.on_planning(self.prompt_generator, prompt.raw()) if not plugin_response or plugin_response == "": continue message_to_add = Message("system", plugin_response) diff --git a/autogpt/prompts/generator.py b/autogpt/prompts/generator.py index c5e1c524..8bc2f378 100644 --- a/autogpt/prompts/generator.py +++ b/autogpt/prompts/generator.py @@ -1,6 +1,7 @@ """ A module for generating custom prompt strings.""" from __future__ import annotations +import logging import platform from dataclasses import dataclass from typing import TYPE_CHECKING, Callable, Optional @@ -12,6 +13,8 @@ if TYPE_CHECKING: from autogpt.config import AIConfig, AIDirectives, Config from autogpt.models.command_registry import CommandRegistry +logger = logging.getLogger(__name__) + class PromptGenerator: """ @@ -25,7 +28,7 @@ class PromptGenerator: constraints: list[str] resources: list[str] - commands: list[Command] + commands: dict[str, Command] command_registry: CommandRegistry def __init__( @@ -38,13 +41,13 @@ class PromptGenerator: self.best_practices = ai_directives.best_practices self.constraints = ai_directives.constraints self.resources = ai_directives.resources - self.commands = [] + self.commands = {} self.command_registry = command_registry @dataclass class Command: - label: str name: str + description: str params: dict[str, str] function: Optional[Callable] @@ -53,7 +56,7 @@ class PromptGenerator: params_string = ", ".join( f'"{key}": "{value}"' for key, value in self.params.items() ) - return f'{self.label}: "{self.name.rstrip(".")}". Params: ({params_string})' + return f'{self.name}: "{self.description.rstrip(".")}". Params: ({params_string})' def add_constraint(self, constraint: str) -> None: """ @@ -62,39 +65,45 @@ class PromptGenerator: Params: constraint (str): The constraint to be added. """ - self.constraints.append(constraint) + if constraint not in self.constraints: + self.constraints.append(constraint) def add_command( self, - command_label: str, - command_name: str, + name: str, + description: str, params: dict[str, str] = {}, function: Optional[Callable] = None, ) -> None: """ - Add a command to the commands list with a label, name, and optional arguments. + Registers a command. *Should only be used by plugins.* Native commands should be added directly to the CommandRegistry. Params: - command_label (str): The label of the command. - command_name (str): The name of the command. + name (str): The name of the command (e.g. `command_name`). + description (str): The description of the command. params (dict, optional): A dictionary containing argument names and their - values. Defaults to None. + types. Defaults to an empty dictionary. function (callable, optional): A callable function to be called when the command is executed. Defaults to None. """ - - self.commands.append( - PromptGenerator.Command( - label=command_label, - name=command_name, - params={name: type for name, type in params.items()}, - function=function, - ) + command = PromptGenerator.Command( + name=name, + description=description, + params={name: type for name, type in params.items()}, + function=function, ) + if name in self.commands: + if description == self.commands[name].description: + return + logger.warning( + f"Replacing command {self.commands[name]} with conflicting {command}" + ) + self.commands[name] = command + def add_resource(self, resource: str) -> None: """ Add a resource to the resources list. @@ -102,7 +111,8 @@ class PromptGenerator: Params: resource (str): The resource to be added. """ - self.resources.append(resource) + if resource not in self.resources: + self.resources.append(resource) def add_best_practice(self, best_practice: str) -> None: """ @@ -111,7 +121,8 @@ class PromptGenerator: Params: best_practice (str): The best practice item to be added. """ - self.best_practices.append(best_practice) + if best_practice not in self.best_practices: + self.best_practices.append(best_practice) def _generate_numbered_list(self, items: list[str], start_at: int = 1) -> str: """ diff --git a/tests/integration/agent_factory.py b/tests/integration/agent_factory.py index 620721a8..c63eaf80 100644 --- a/tests/integration/agent_factory.py +++ b/tests/integration/agent_factory.py @@ -29,7 +29,6 @@ def dummy_agent(config: Config, memory_json_file): "Dummy Task", ], ) - ai_config.command_registry = command_registry agent = Agent( memory=memory_json_file,