Mirror of https://github.com/aljazceru/Auto-GPT.git (synced 2026-01-31 20:04:28 +01:00)
Refactor prompts into package, make the prompt able to be stored with the AI config and changed. Fix settings file.
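Below is a minimal usage sketch (not part of the commit) of how the refactored pieces are meant to fit together, using only names that appear in the diff that follows: build_default_prompt_generator() from autogpt.prompts.prompt, and AIConfig.construct_full_prompt(), which now accepts an optional PromptGenerator and stores it on the config. The AIConfig constructor arguments and add_constraint() are assumed from the attribute assignments and constraint handling visible further down.

# Hedged usage sketch, not part of the commit. Assumes AIConfig's constructor
# takes ai_name, ai_role and ai_goals, as suggested by the assignments below.
from autogpt.config.ai_config import AIConfig
from autogpt.prompts.prompt import build_default_prompt_generator

config = AIConfig(
    ai_name="DemoGPT",
    ai_role="an AI that demonstrates the prompt refactor",
    ai_goals=["Summarize a web page"],
)

# Old behaviour still works: with no argument, a default generator is built.
default_prompt = config.construct_full_prompt()

# New behaviour: a customised generator can be passed in; it is stored on the
# config (config.prompt_generator) so it can be reused or changed later.
generator = build_default_prompt_generator()
generator.add_constraint("Respond in English only.")  # assumed helper on PromptGenerator
custom_prompt = config.construct_full_prompt(prompt_generator=generator)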
@@ -1,5 +1,7 @@
"""Main script for the autogpt package."""
import logging
import os
from pathlib import Path
from colorama import Fore
from autogpt.agent.agent import Agent
from autogpt.args import parse_arguments
@@ -8,7 +10,9 @@ from autogpt.config import Config, check_openai_api_key
from autogpt.logs import logger
from autogpt.memory import get_memory
-from autogpt.prompt import construct_prompt
+from autogpt.prompts.prompt import construct_prompt
from autogpt.plugins import load_plugins

# Load environment variables from .env file
@@ -36,8 +40,7 @@ def main() -> None:
loaded_plugins.append(plugin())

if loaded_plugins:
-print(f"\nPlugins found: {len(loaded_plugins)}\n"
-"--------------------")
+print(f"\nPlugins found: {len(loaded_plugins)}\n" "--------------------")
for plugin in loaded_plugins:
print(f"{plugin._name}: {plugin._version} - {plugin._description}")
@@ -186,7 +186,7 @@ def execute_command(command_name: str, arguments):
elif command_name == "generate_image":
return generate_image(arguments["prompt"])
elif command_name == "send_tweet":
-return send_tweet(arguments['text'])
+return send_tweet(arguments["text"])
elif command_name == "do_nothing":
return "No action performed."
elif command_name == "task_complete":
@@ -23,7 +23,9 @@ def read_audio(audio):
headers = {"Authorization": f"Bearer {api_token}"}

if api_token is None:
-raise ValueError("You need to set your Hugging Face API token in the config file.")
+raise ValueError(
+"You need to set your Hugging Face API token in the config file."
+)

response = requests.post(
api_url,
@@ -31,5 +33,5 @@ def read_audio(audio):
data=audio,
)

-text = json.loads(response.content.decode("utf-8"))['text']
+text = json.loads(response.content.decode("utf-8"))["text"]
return "The audio says: " + text
@@ -7,9 +7,9 @@ load_dotenv()

def send_tweet(tweet_text):
consumer_key = os.environ.get("TW_CONSUMER_KEY")
-consumer_secret= os.environ.get("TW_CONSUMER_SECRET")
-access_token= os.environ.get("TW_ACCESS_TOKEN")
-access_token_secret= os.environ.get("TW_ACCESS_TOKEN_SECRET")
+consumer_secret = os.environ.get("TW_CONSUMER_SECRET")
+access_token = os.environ.get("TW_ACCESS_TOKEN")
+access_token_secret = os.environ.get("TW_ACCESS_TOKEN_SECRET")
# Authenticate to Twitter
auth = tweepy.OAuthHandler(consumer_key, consumer_secret)
auth.set_access_token(access_token, access_token_secret)
@@ -3,9 +3,12 @@
A module that contains the AIConfig class object that contains the configuration
"""
import os
from pathlib import Path
from typing import List, Optional, Type
import yaml

+from autogpt.prompts.generator import PromptGenerator


class AIConfig:
"""
@@ -35,9 +38,10 @@ class AIConfig:
self.ai_name = ai_name
self.ai_role = ai_role
self.ai_goals = ai_goals
+self.prompt_generator = None

# Soon this will go in a folder where it remembers more stuff about the run(s)
-SAVE_FILE = os.path.join(os.path.dirname(__file__), "..", "ai_settings.yaml")
+SAVE_FILE = Path(os.getcwd()) / "ai_settings.yaml"

@staticmethod
def load(config_file: str = SAVE_FILE) -> "AIConfig":
@@ -86,7 +90,7 @@ class AIConfig:
with open(config_file, "w", encoding="utf-8") as file:
yaml.dump(config, file, allow_unicode=True)

-def construct_full_prompt(self) -> str:
+def construct_full_prompt(self, prompt_generator: Optional[PromptGenerator] = None) -> str:
"""
Returns a prompt to the user with the class information in an organized fashion.
@@ -105,14 +109,16 @@ class AIConfig:
""
)

-from autogpt.prompt import get_prompt
+from autogpt.prompts.prompt import build_default_prompt_generator

# Construct full prompt
full_prompt = (
f"You are {self.ai_name}, {self.ai_role}\n{prompt_start}\n\nGOALS:\n\n"
)
for i, goal in enumerate(self.ai_goals):
full_prompt += f"{i+1}. {goal}\n"

-full_prompt += f"\n\n{get_prompt()}"
+if prompt_generator is None:
+prompt_generator = build_default_prompt_generator()
+prompt_generator.goals = self.ai_goals
+self.prompt_generator = prompt_generator
+full_prompt += f"\n\n{prompt_generator.generate_prompt_string()}"
return full_prompt
@@ -18,7 +18,7 @@ def inspect_zip_for_module(zip_path: str, debug: bool = False) -> Optional[str]:
Returns:
Optional[str]: The name of the module if found, else None.
"""
-with zipfile.ZipFile(zip_path, 'r') as zfile:
+with zipfile.ZipFile(zip_path, "r") as zfile:
for name in zfile.namelist():
if name.endswith("__init__.py"):
if debug:
@@ -68,7 +68,6 @@ def load_plugins(plugins_path: Path, debug: bool = False) -> List[Module]:
continue
a_module = getattr(zipped_module, key)
a_keys = dir(a_module)
-if '_abc_impl' in a_keys and \
-a_module.__name__ != 'AutoGPTPluginTemplate':
+if "_abc_impl" in a_keys and a_module.__name__ != "AutoGPTPluginTemplate":
plugin_modules.append(a_module)
return plugin_modules
autogpt/prompts/__init__.py (new empty file, 0 lines)
@@ -1,6 +1,6 @@
""" A module for generating custom prompt strings."""
import json
-from typing import Any, Dict, List
+from typing import Any, Callable, Dict, List, Optional


class PromptGenerator:
@@ -18,6 +18,7 @@ class PromptGenerator:
self.commands = []
self.resources = []
self.performance_evaluation = []
+self.goals = []
self.response_format = {
"thoughts": {
"text": "thought",
@@ -38,7 +39,13 @@ class PromptGenerator:
"""
self.constraints.append(constraint)

-def add_command(self, command_label: str, command_name: str, args=None) -> None:
+def add_command(
+self,
+command_label: str,
+command_name: str,
+args=None,
+function: Optional[Callable] = None,
+) -> None:
"""
Add a command to the commands list with a label, name, and optional arguments.
@@ -47,6 +54,8 @@ class PromptGenerator:
command_name (str): The name of the command.
args (dict, optional): A dictionary containing argument names and their
values. Defaults to None.
+function (callable, optional): A callable function to be called when
+the command is executed. Defaults to None.
"""
if args is None:
args = {}
@@ -57,6 +66,7 @@ class PromptGenerator:
"label": command_label,
"name": command_name,
"args": command_args,
+"function": function,
}

self.commands.append(command)
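For reference, a hedged example of the extended add_command signature shown above; the echo helper and the label/name values are hypothetical, only the parameter names come from the diff.

# Hypothetical usage of PromptGenerator.add_command with the new optional
# "function" argument (sketch only, not part of the commit).
from autogpt.prompts.generator import PromptGenerator


def echo(text: str) -> str:  # hypothetical callable attached to the command
    return text


generator = PromptGenerator()
generator.add_command(
    "Echo Text",           # command_label
    "echo",                # command_name
    {"text": "<text>"},    # args
    function=echo,         # new: stored in the command entry under "function"
)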
@@ -2,15 +2,14 @@ from colorama import Fore
from autogpt.config.ai_config import AIConfig
-from autogpt.config.config import Config
+from autogpt.config import Config
from autogpt.logs import logger
-from autogpt.promptgenerator import PromptGenerator
+from autogpt.prompts.generator import PromptGenerator
from autogpt.setup import prompt_user
from autogpt.utils import clean_input

CFG = Config()


-def get_prompt() -> str:
+def build_default_prompt_generator() -> PromptGenerator:
"""
This function generates a prompt string that includes various constraints,
commands, resources, and performance evaluations.
@@ -19,9 +18,6 @@ def get_prompt() -> str:
str: The generated prompt string.
"""

-# Initialize the Config object
-cfg = Config()
-
# Initialize the PromptGenerator object
prompt_generator = PromptGenerator()
@@ -84,11 +80,10 @@ def get_prompt() -> str:
("Generate Image", "generate_image", {"prompt": "<prompt>"}),
("Convert Audio to text", "read_audio_from_file", {"file": "<file>"}),
("Send Tweet", "send_tweet", {"text": "<text>"}),
]

# Only add shell command to the prompt if the AI is allowed to execute it
-if cfg.execute_local_commands:
+if CFG.execute_local_commands:
commands.append(
(
"Execute Shell Command, non-interactive commands only",
@@ -135,8 +130,7 @@ def get_prompt() -> str:
" the least number of steps."
)

-# Generate the prompt string
-return prompt_generator.generate_prompt_string()
+return prompt_generator


def construct_prompt() -> str:
@@ -171,8 +165,4 @@ Continue (y/n): """
config = prompt_user()
config.save()

-# Get rid of this global:
-global ai_name
-ai_name = config.ai_name
-
return config.construct_full_prompt()
@@ -1,6 +1,6 @@
from unittest import TestCase

-from autogpt.promptgenerator import PromptGenerator
+from autogpt.prompts.generator import PromptGenerator


class TestPromptGenerator(TestCase):
@@ -50,7 +50,9 @@ class TestScrapeText:
# Tests that the function returns an error message when an invalid or unreachable url is provided.
def test_invalid_url(self, mocker):
# Mock the requests.get() method to raise an exception
-mocker.patch("requests.Session.get", side_effect=requests.exceptions.RequestException)
+mocker.patch(
+"requests.Session.get", side_effect=requests.exceptions.RequestException
+)

# Call the function with an invalid URL and assert that it returns an error message
url = "http://www.invalidurl.com"