🐛 Minor type fixes

This commit is contained in:
Taylor Beeston
2023-04-17 12:42:17 -07:00
parent f784049079
commit ea67b6772c
3 changed files with 17 additions and 8 deletions

View File

@@ -1,12 +1,14 @@
from __future__ import annotations from __future__ import annotations
import time import time
from typing import List, Optional
import openai import openai
from colorama import Fore from colorama import Fore
from openai.error import APIError, RateLimitError from openai.error import APIError, RateLimitError
from autogpt.config import Config from autogpt.config import Config
from plugin_template import Message
CFG = Config() CFG = Config()
@@ -35,8 +37,8 @@ def call_ai_function(
# For each arg, if any are None, convert to "None": # For each arg, if any are None, convert to "None":
args = [str(arg) if arg is not None else "None" for arg in args] args = [str(arg) if arg is not None else "None" for arg in args]
# parse args to comma separated string # parse args to comma separated string
args = ", ".join(args) args: str = ", ".join(args)
messages = [ messages: List[Message] = [
{ {
"role": "system", "role": "system",
"content": f"You are now the following python function: ```# {description}" "content": f"You are now the following python function: ```# {description}"
@@ -51,15 +53,15 @@ def call_ai_function(
# Overly simple abstraction until we create something better # Overly simple abstraction until we create something better
# simple retry mechanism when getting a rate error or a bad gateway # simple retry mechanism when getting a rate error or a bad gateway
def create_chat_completion( def create_chat_completion(
messages: list, # type: ignore messages: List[Message], # type: ignore
model: str | None = None, model: Optional[str] = None,
temperature: float = CFG.temperature, temperature: float = CFG.temperature,
max_tokens: int | None = None, max_tokens: Optional[int] = None,
) -> str: ) -> str:
"""Create a chat completion using the OpenAI API """Create a chat completion using the OpenAI API
Args: Args:
messages (list[dict[str, str]]): The messages to send to the chat completion messages (List[Message]): The messages to send to the chat completion
model (str, optional): The model to use. Defaults to None. model (str, optional): The model to use. Defaults to None.
temperature (float, optional): The temperature to use. Defaults to 0.9. temperature (float, optional): The temperature to use. Defaults to 0.9.
max_tokens (int, optional): The max tokens to use. Defaults to None. max_tokens (int, optional): The max tokens to use. Defaults to None.

View File

@@ -6,6 +6,8 @@ from pathlib import Path
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from zipimport import zipimporter from zipimport import zipimporter
from plugin_template import AutoGPTPluginTemplate
def inspect_zip_for_module(zip_path: str, debug: bool = False) -> Optional[str]: def inspect_zip_for_module(zip_path: str, debug: bool = False) -> Optional[str]:
""" """
@@ -45,7 +47,9 @@ def scan_plugins(plugins_path: Path, debug: bool = False) -> List[Tuple[str, Pat
return plugins return plugins
def load_plugins(plugins_path: Path, debug: bool = False) -> List[Module]: def load_plugins(
plugins_path: Path, debug: bool = False
) -> List[AutoGPTPluginTemplate]:
"""Load plugins from the plugins directory. """Load plugins from the plugins directory.
Args: Args:

View File

@@ -1,13 +1,16 @@
"""Functions for counting the number of tokens in a message or string.""" """Functions for counting the number of tokens in a message or string."""
from __future__ import annotations from __future__ import annotations
from typing import List
import tiktoken import tiktoken
from autogpt.logs import logger from autogpt.logs import logger
from plugin_template import Message
def count_message_tokens( def count_message_tokens(
messages: list[dict[str, str]], model: str = "gpt-3.5-turbo-0301" messages: List[Message], model: str = "gpt-3.5-turbo-0301"
) -> int: ) -> int:
""" """
Returns the number of tokens used by a list of messages. Returns the number of tokens used by a list of messages.