diff --git a/autogpt/agent/agent_manager.py b/autogpt/agent/agent_manager.py index d2648150..9f123eaa 100644 --- a/autogpt/agent/agent_manager.py +++ b/autogpt/agent/agent_manager.py @@ -1,8 +1,10 @@ """Agent manager for managing GPT agents""" from __future__ import annotations +from typing import List from autogpt.config.config import Config, Singleton from autogpt.llm_utils import create_chat_completion +from autogpt.types.openai import Message class AgentManager(metaclass=Singleton): @@ -27,17 +29,14 @@ class AgentManager(metaclass=Singleton): Returns: The key of the new agent """ - messages = [ + messages: List[Message] = [ {"role": "user", "content": prompt}, ] for plugin in self.cfg.plugins: if not plugin.can_handle_pre_instruction(): continue - plugin_messages = plugin.pre_instruction(messages) - if plugin_messages: - for plugin_message in plugin_messages: - messages.append({"role": "system", "content": plugin_message}) - + if plugin_messages := plugin.pre_instruction(messages): + messages.extend(iter(plugin_messages)) # Start GPT instance agent_reply = create_chat_completion( model=model, @@ -50,9 +49,8 @@ class AgentManager(metaclass=Singleton): for i, plugin in enumerate(self.cfg.plugins): if not plugin.can_handle_on_instruction(): continue - plugin_result = plugin.on_instruction(messages) - if plugin_result: - sep = "" if not i else "\n" + if plugin_result := plugin.on_instruction(messages): + sep = "\n" if i else "" plugins_reply = f"{plugins_reply}{sep}{plugin_result}" if plugins_reply and plugins_reply != "": @@ -89,10 +87,9 @@ class AgentManager(metaclass=Singleton): for plugin in self.cfg.plugins: if not plugin.can_handle_pre_instruction(): continue - plugin_messages = plugin.pre_instruction(messages) - if plugin_messages: + if plugin_messages := plugin.pre_instruction(messages): for plugin_message in plugin_messages: - messages.append({"role": "system", "content": plugin_message}) + messages.append(plugin_message) # Start GPT instance agent_reply = create_chat_completion( @@ -106,9 +103,8 @@ class AgentManager(metaclass=Singleton): for i, plugin in enumerate(self.cfg.plugins): if not plugin.can_handle_on_instruction(): continue - plugin_result = plugin.on_instruction(messages) - if plugin_result: - sep = "" if not i else "\n" + if plugin_result := plugin.on_instruction(messages): + sep = "\n" if i else "" plugins_reply = f"{plugins_reply}{sep}{plugin_result}" # Update full message history if plugins_reply and plugins_reply != "": diff --git a/autogpt/chat.py b/autogpt/chat.py index 22fe636c..f9fc9471 100644 --- a/autogpt/chat.py +++ b/autogpt/chat.py @@ -6,11 +6,12 @@ from autogpt import token_counter from autogpt.config import Config from autogpt.llm_utils import create_chat_completion from autogpt.logs import logger +from autogpt.types.openai import Message cfg = Config() -def create_chat_message(role, content): +def create_chat_message(role, content) -> Message: """ Create a chat message with the given role and content. @@ -145,7 +146,7 @@ def chat_with_ai( if not plugin_response or plugin_response == "": continue tokens_to_add = token_counter.count_message_tokens( - [plugin_response], model + [create_chat_message("system", plugin_response)], model ) if current_tokens_used + tokens_to_add > send_token_limit: if cfg.debug_mode: diff --git a/autogpt/config/config.py b/autogpt/config/config.py index c12eed2e..f93bf17a 100644 --- a/autogpt/config/config.py +++ b/autogpt/config/config.py @@ -1,7 +1,9 @@ """Configuration class to store the state of bools for different scripts access.""" import os +from typing import List import openai +from auto_gpt_plugin_template import AutoGPTPluginTemplate import yaml from colorama import Fore from dotenv import load_dotenv @@ -107,7 +109,7 @@ class Config(metaclass=Singleton): # Initialize the OpenAI API client openai.api_key = self.openai_api_key - self.plugins = [] + self.plugins: List[AutoGPTPluginTemplate] = [] self.plugins_whitelist = [] self.plugins_blacklist = [] diff --git a/autogpt/llm_utils.py b/autogpt/llm_utils.py index 4fb0e1f5..a6d87c30 100644 --- a/autogpt/llm_utils.py +++ b/autogpt/llm_utils.py @@ -1,12 +1,14 @@ from __future__ import annotations import time +from typing import List, Optional import openai from colorama import Fore from openai.error import APIError, RateLimitError from autogpt.config import Config +from autogpt.types.openai import Message CFG = Config() @@ -35,8 +37,8 @@ def call_ai_function( # For each arg, if any are None, convert to "None": args = [str(arg) if arg is not None else "None" for arg in args] # parse args to comma separated string - args = ", ".join(args) - messages = [ + args: str = ", ".join(args) + messages: List[Message] = [ { "role": "system", "content": f"You are now the following python function: ```# {description}" @@ -51,15 +53,15 @@ def call_ai_function( # Overly simple abstraction until we create something better # simple retry mechanism when getting a rate error or a bad gateway def create_chat_completion( - messages: list, # type: ignore - model: str | None = None, + messages: List[Message], # type: ignore + model: Optional[str] = None, temperature: float = CFG.temperature, - max_tokens: int | None = None, + max_tokens: Optional[int] = None, ) -> str: """Create a chat completion using the OpenAI API Args: - messages (list[dict[str, str]]): The messages to send to the chat completion + messages (List[Message]): The messages to send to the chat completion model (str, optional): The model to use. Defaults to None. temperature (float, optional): The temperature to use. Defaults to 0.9. max_tokens (int, optional): The max tokens to use. Defaults to None. @@ -67,13 +69,10 @@ def create_chat_completion( Returns: str: The response from the chat completion """ - response = None num_retries = 10 if CFG.debug_mode: print( - Fore.GREEN - + f"Creating chat completion with model {model}, temperature {temperature}," - f" max_tokens {max_tokens}" + Fore.RESET + f"{Fore.GREEN}Creating chat completion with model {model}, temperature {temperature}, max_tokens {max_tokens}{Fore.RESET}" ) for plugin in CFG.plugins: if plugin.can_handle_chat_completion( @@ -82,13 +81,13 @@ def create_chat_completion( temperature=temperature, max_tokens=max_tokens, ): - response = plugin.handle_chat_completion( + return plugin.handle_chat_completion( messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, ) - return response + response = None for attempt in range(num_retries): backoff = 2 ** (attempt + 2) try: @@ -111,20 +110,17 @@ def create_chat_completion( except RateLimitError: if CFG.debug_mode: print( - Fore.RED + "Error: ", - "Reached rate limit, passing..." + Fore.RESET, + f"{Fore.RED}Error: ", f"Reached rate limit, passing...{Fore.RESET}" ) except APIError as e: - if e.http_status == 502: - pass - else: + if e.http_status != 502: raise if attempt == num_retries - 1: raise if CFG.debug_mode: print( - Fore.RED + "Error: ", - f"API Bad gateway. Waiting {backoff} seconds..." + Fore.RESET, + f"{Fore.RED}Error: ", + f"API Bad gateway. Waiting {backoff} seconds...{Fore.RESET}", ) time.sleep(backoff) if response is None: @@ -157,15 +153,13 @@ def create_embedding_with_ada(text) -> list: except RateLimitError: pass except APIError as e: - if e.http_status == 502: - pass - else: + if e.http_status != 502: raise if attempt == num_retries - 1: raise if CFG.debug_mode: print( - Fore.RED + "Error: ", - f"API Bad gateway. Waiting {backoff} seconds..." + Fore.RESET, + f"{Fore.RED}Error: ", + f"API Bad gateway. Waiting {backoff} seconds...{Fore.RESET}", ) time.sleep(backoff) diff --git a/autogpt/plugins.py b/autogpt/plugins.py index a00b989e..a4d9c17c 100644 --- a/autogpt/plugins.py +++ b/autogpt/plugins.py @@ -6,6 +6,8 @@ from pathlib import Path from typing import List, Optional, Tuple from zipimport import zipimporter +from auto_gpt_plugin_template import AutoGPTPluginTemplate + def inspect_zip_for_module(zip_path: str, debug: bool = False) -> Optional[str]: """ @@ -45,7 +47,9 @@ def scan_plugins(plugins_path: Path, debug: bool = False) -> List[Tuple[str, Pat return plugins -def load_plugins(plugins_path: Path, debug: bool = False) -> List[Module]: +def load_plugins( + plugins_path: Path, debug: bool = False +) -> List[AutoGPTPluginTemplate]: """Load plugins from the plugins directory. Args: diff --git a/autogpt/token_counter.py b/autogpt/token_counter.py index 338fe6be..b1e59d86 100644 --- a/autogpt/token_counter.py +++ b/autogpt/token_counter.py @@ -1,13 +1,15 @@ """Functions for counting the number of tokens in a message or string.""" from __future__ import annotations +from typing import List import tiktoken from autogpt.logs import logger +from autogpt.types.openai import Message def count_message_tokens( - messages: list[dict[str, str]], model: str = "gpt-3.5-turbo-0301" + messages: List[Message], model: str = "gpt-3.5-turbo-0301" ) -> int: """ Returns the number of tokens used by a list of messages. diff --git a/autogpt/types/openai.py b/autogpt/types/openai.py new file mode 100644 index 00000000..2af85785 --- /dev/null +++ b/autogpt/types/openai.py @@ -0,0 +1,9 @@ +"""Type helpers for working with the OpenAI library""" +from typing import TypedDict + + +class Message(TypedDict): + """OpenAI Message object containing a role and the message content""" + + role: str + content: str diff --git a/requirements.txt b/requirements.txt index 843b66bf..86d24b5a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -29,6 +29,8 @@ black sourcery isort gitpython==3.1.31 +abstract-singleton +auto-gpt-plugin-template # Testing dependencies pytest