Merge pull request #6 from TaylorBeeston/type-fixes

Type fixes
BillSchumacher authored 2023-04-18 17:57:23 -05:00, committed via GitHub
8 changed files with 54 additions and 44 deletions

View File

@@ -1,8 +1,10 @@
"""Agent manager for managing GPT agents"""
from __future__ import annotations
+from typing import List
from autogpt.config.config import Config, Singleton
from autogpt.llm_utils import create_chat_completion
+from autogpt.types.openai import Message
class AgentManager(metaclass=Singleton):
@@ -27,17 +29,14 @@ class AgentManager(metaclass=Singleton):
Returns:
The key of the new agent
"""
-messages = [
+messages: List[Message] = [
{"role": "user", "content": prompt},
]
for plugin in self.cfg.plugins:
if not plugin.can_handle_pre_instruction():
continue
-plugin_messages = plugin.pre_instruction(messages)
-if plugin_messages:
-    for plugin_message in plugin_messages:
-        messages.append({"role": "system", "content": plugin_message})
+if plugin_messages := plugin.pre_instruction(messages):
+    messages.extend(iter(plugin_messages))
# Start GPT instance
agent_reply = create_chat_completion(
model=model,
@@ -50,9 +49,8 @@ class AgentManager(metaclass=Singleton):
for i, plugin in enumerate(self.cfg.plugins):
if not plugin.can_handle_on_instruction():
continue
-plugin_result = plugin.on_instruction(messages)
-if plugin_result:
-    sep = "" if not i else "\n"
+if plugin_result := plugin.on_instruction(messages):
+    sep = "\n" if i else ""
plugins_reply = f"{plugins_reply}{sep}{plugin_result}"
if plugins_reply and plugins_reply != "":
@@ -89,10 +87,9 @@ class AgentManager(metaclass=Singleton):
for plugin in self.cfg.plugins:
if not plugin.can_handle_pre_instruction():
continue
-plugin_messages = plugin.pre_instruction(messages)
-if plugin_messages:
+if plugin_messages := plugin.pre_instruction(messages):
     for plugin_message in plugin_messages:
-        messages.append({"role": "system", "content": plugin_message})
+        messages.append(plugin_message)
# Start GPT instance
agent_reply = create_chat_completion(
@@ -106,9 +103,8 @@ class AgentManager(metaclass=Singleton):
for i, plugin in enumerate(self.cfg.plugins):
if not plugin.can_handle_on_instruction():
continue
-plugin_result = plugin.on_instruction(messages)
-if plugin_result:
-    sep = "" if not i else "\n"
+if plugin_result := plugin.on_instruction(messages):
+    sep = "\n" if i else ""
plugins_reply = f"{plugins_reply}{sep}{plugin_result}"
# Update full message history
if plugins_reply and plugins_reply != "":
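
Most of the changes above replace a separate assign-then-test with an assignment expression (the walrus operator, Python 3.8+), which binds a value and tests its truthiness in one step. A minimal standalone sketch of the before/after shapes, using a hypothetical get_results() as a stand-in for the plugin hooks:

def get_results():
    """Stand-in for a call like plugin.pre_instruction(messages)."""
    return ["a", "b"]

# Before: assign, then test.
results = get_results()
if results:
    print(results)

# After: bind and test in a single expression (Python 3.8+).
if results := get_results():
    print(results)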

View File

@@ -6,11 +6,12 @@ from autogpt import token_counter
from autogpt.config import Config
from autogpt.llm_utils import create_chat_completion
from autogpt.logs import logger
+from autogpt.types.openai import Message
cfg = Config()
-def create_chat_message(role, content):
+def create_chat_message(role, content) -> Message:
"""
Create a chat message with the given role and content.
@@ -145,7 +146,7 @@ def chat_with_ai(
if not plugin_response or plugin_response == "":
continue
tokens_to_add = token_counter.count_message_tokens(
-[plugin_response], model
+[create_chat_message("system", plugin_response)], model
)
if current_tokens_used + tokens_to_add > send_token_limit:
if cfg.debug_mode:
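
The token-counting fix above matters because count_message_tokens expects a list of Message dicts rather than bare strings, so the plugin response is wrapped via create_chat_message first. The body of create_chat_message is elided from this diff; a minimal sketch consistent with its docstring and the new return annotation would be:

from autogpt.types.openai import Message

def create_chat_message(role, content) -> Message:
    """Create a chat message with the given role and content."""
    return {"role": role, "content": content}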

View File

@@ -1,7 +1,9 @@
"""Configuration class to store the state of bools for different scripts access."""
import os
+from typing import List
import openai
+from auto_gpt_plugin_template import AutoGPTPluginTemplate
import yaml
from colorama import Fore
from dotenv import load_dotenv
@@ -107,7 +109,7 @@ class Config(metaclass=Singleton):
# Initialize the OpenAI API client
openai.api_key = self.openai_api_key
-self.plugins = []
+self.plugins: List[AutoGPTPluginTemplate] = []
self.plugins_whitelist = []
self.plugins_blacklist = []
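
Annotating the attribute as List[AutoGPTPluginTemplate] costs nothing at runtime but lets a static checker reject entries that are not plugins. A hypothetical snippet showing what the annotation buys:

from typing import List
from auto_gpt_plugin_template import AutoGPTPluginTemplate

plugins: List[AutoGPTPluginTemplate] = []
plugins.append("not a plugin")  # flagged by mypy; runtime would not complain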

View File

@@ -1,12 +1,14 @@
from __future__ import annotations
import time
+from typing import List, Optional
import openai
from colorama import Fore
from openai.error import APIError, RateLimitError
from autogpt.config import Config
+from autogpt.types.openai import Message
CFG = Config()
@@ -35,8 +37,8 @@ def call_ai_function(
# For each arg, if any are None, convert to "None":
args = [str(arg) if arg is not None else "None" for arg in args]
# parse args to comma separated string
-args = ", ".join(args)
-messages = [
+args: str = ", ".join(args)
+messages: List[Message] = [
{
"role": "system",
"content": f"You are now the following python function: ```# {description}"
@@ -51,15 +53,15 @@ def call_ai_function(
# Overly simple abstraction until we create something better
# simple retry mechanism when getting a rate error or a bad gateway
def create_chat_completion(
-messages: list, # type: ignore
-model: str | None = None,
+messages: List[Message], # type: ignore
+model: Optional[str] = None,
temperature: float = CFG.temperature,
-max_tokens: int | None = None,
+max_tokens: Optional[int] = None,
) -> str:
"""Create a chat completion using the OpenAI API
Args:
-messages (list[dict[str, str]]): The messages to send to the chat completion
+messages (List[Message]): The messages to send to the chat completion
model (str, optional): The model to use. Defaults to None.
temperature (float, optional): The temperature to use. Defaults to 0.9.
max_tokens (int, optional): The max tokens to use. Defaults to None.
@@ -67,13 +69,10 @@ def create_chat_completion(
Returns:
str: The response from the chat completion
"""
-response = None
num_retries = 10
if CFG.debug_mode:
print(
-Fore.GREEN
-+ f"Creating chat completion with model {model}, temperature {temperature},"
-f" max_tokens {max_tokens}" + Fore.RESET
+f"{Fore.GREEN}Creating chat completion with model {model}, temperature {temperature}, max_tokens {max_tokens}{Fore.RESET}"
)
for plugin in CFG.plugins:
if plugin.can_handle_chat_completion(
@@ -82,13 +81,13 @@ def create_chat_completion(
temperature=temperature,
max_tokens=max_tokens,
):
-response = plugin.handle_chat_completion(
+return plugin.handle_chat_completion(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
)
-return response
+response = None
for attempt in range(num_retries):
backoff = 2 ** (attempt + 2)
try:
@@ -111,20 +110,17 @@ def create_chat_completion(
except RateLimitError:
if CFG.debug_mode:
print(
-Fore.RED + "Error: ",
-"Reached rate limit, passing..." + Fore.RESET,
+f"{Fore.RED}Error: ", f"Reached rate limit, passing...{Fore.RESET}"
)
except APIError as e:
-if e.http_status == 502:
-    pass
-else:
+if e.http_status != 502:
raise
if attempt == num_retries - 1:
raise
if CFG.debug_mode:
print(
-Fore.RED + "Error: ",
-f"API Bad gateway. Waiting {backoff} seconds..." + Fore.RESET,
+f"{Fore.RED}Error: ",
+f"API Bad gateway. Waiting {backoff} seconds...{Fore.RESET}",
)
time.sleep(backoff)
if response is None:
@@ -157,15 +153,13 @@ def create_embedding_with_ada(text) -> list:
except RateLimitError:
pass
except APIError as e:
-if e.http_status == 502:
-    pass
-else:
+if e.http_status != 502:
raise
if attempt == num_retries - 1:
raise
if CFG.debug_mode:
print(
-Fore.RED + "Error: ",
-f"API Bad gateway. Waiting {backoff} seconds..." + Fore.RESET,
+f"{Fore.RED}Error: ",
+f"API Bad gateway. Waiting {backoff} seconds...{Fore.RESET}",
)
time.sleep(backoff)
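
Both retry loops in this file share the same shape: wait 2 ** (attempt + 2) seconds (4, 8, 16, ...) after a rate limit or a 502, and re-raise on the final attempt. A self-contained sketch of that backoff pattern, with a placeholder TransientError standing in for RateLimitError and the 502 APIError case:

import time

class TransientError(Exception):
    """Placeholder for RateLimitError or an APIError with http_status 502."""

calls = {"count": 0}

def flaky_call() -> str:
    """Stand-in for openai.ChatCompletion.create(); succeeds on the third try."""
    calls["count"] += 1
    if calls["count"] < 3:
        raise TransientError
    return "ok"

num_retries = 10
response = None
for attempt in range(num_retries):
    backoff = 2 ** (attempt + 2)  # 4, 8, 16, ... seconds
    try:
        response = flaky_call()
        break
    except TransientError:
        if attempt == num_retries - 1:
            raise  # out of retries: surface the error
        time.sleep(backoff)
print(response)  # "ok" after two backoffs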

View File

@@ -6,6 +6,8 @@ from pathlib import Path
from typing import List, Optional, Tuple
from zipimport import zipimporter
+from auto_gpt_plugin_template import AutoGPTPluginTemplate
def inspect_zip_for_module(zip_path: str, debug: bool = False) -> Optional[str]:
"""
@@ -45,7 +47,9 @@ def scan_plugins(plugins_path: Path, debug: bool = False) -> List[Tuple[str, Pat
return plugins
-def load_plugins(plugins_path: Path, debug: bool = False) -> List[Module]:
+def load_plugins(
+    plugins_path: Path, debug: bool = False
+) -> List[AutoGPTPluginTemplate]:
"""Load plugins from the plugins directory.
Args:
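
With this change load_plugins is typed to return plugin instances rather than raw modules. The hooks this PR exercises (can_handle_pre_instruction, pre_instruction, can_handle_on_instruction, on_instruction) can be illustrated with a hypothetical duck-typed stub; the real AutoGPTPluginTemplate is an abstract base with a wider interface:

from typing import Optional

class EchoPlugin:
    """Hypothetical stand-in exposing only the hooks used in this PR."""

    def can_handle_pre_instruction(self) -> bool:
        return True

    def pre_instruction(self, messages: list) -> list:
        # Extra Message dicts to append before the prompt is sent.
        return [{"role": "system", "content": "echo: ready"}]

    def can_handle_on_instruction(self) -> bool:
        return False

    def on_instruction(self, messages: list) -> Optional[str]:
        return None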

View File

@@ -1,13 +1,15 @@
"""Functions for counting the number of tokens in a message or string."""
from __future__ import annotations
+from typing import List
import tiktoken
from autogpt.logs import logger
+from autogpt.types.openai import Message
def count_message_tokens(
-messages: list[dict[str, str]], model: str = "gpt-3.5-turbo-0301"
+messages: List[Message], model: str = "gpt-3.5-turbo-0301"
) -> int:
"""
Returns the number of tokens used by a list of messages.
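
With Message in the signature, the counter can assume every element carries role and content keys. A simplified sketch of per-message counting in the style of the OpenAI cookbook (the real function handles more models and edge cases):

from typing import List, TypedDict

import tiktoken

class Message(TypedDict):
    role: str
    content: str

def count_message_tokens_sketch(
    messages: List[Message], model: str = "gpt-3.5-turbo-0301"
) -> int:
    encoding = tiktoken.encoding_for_model(model)
    num_tokens = 0
    for message in messages:
        num_tokens += 4  # per-message overhead for gpt-3.5-turbo-0301
        for value in message.values():
            num_tokens += len(encoding.encode(value))
    return num_tokens + 3  # every reply is primed with <|start|>assistant<|message|>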

autogpt/types/openai.py (new file, 9 lines)
View File

@@ -0,0 +1,9 @@
"""Type helpers for working with the OpenAI library"""
from typing import TypedDict
class Message(TypedDict):
"""OpenAI Message object containing a role and the message content"""
role: str
content: str
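
A TypedDict is a plain dict at runtime, so Message adds no overhead; its value is that static checkers validate keys and value types wherever the annotation appears. For example, the bad assignment below is flagged by mypy but runs fine:

from typing import TypedDict

class Message(TypedDict):
    role: str
    content: str

msg: Message = {"role": "user", "content": "hello"}  # OK
bad: Message = {"role": "user", "contents": "oops"}  # mypy: extra key "contents"
print(isinstance(msg, dict))  # True: just a dict at runtime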

View File

@@ -29,6 +29,8 @@ black
sourcery
isort
gitpython==3.1.31
+abstract-singleton
+auto-gpt-plugin-template
# Testing dependencies
pytest