Mirror of https://github.com/aljazceru/Auto-GPT.git
🐛 Minor type fixes
@@ -1,12 +1,14 @@
 from __future__ import annotations

 import time
+from typing import List, Optional

 import openai
 from colorama import Fore
 from openai.error import APIError, RateLimitError

 from autogpt.config import Config
+from plugin_template import Message

 CFG = Config()

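The first two hunks appear to come from the chat-completion utilities: the new imports bring in typing helpers and a shared `Message` type from `plugin_template`. That package is not part of this diff, so the following is only a plausible sketch of such a type, assuming it is essentially a `TypedDict` with the two fields the OpenAI chat API expects; the real definition may differ.

```python
from typing import TypedDict


class Message(TypedDict):
    """One chat message in the shape the OpenAI chat API expects (sketch only)."""

    role: str     # "system", "user", or "assistant"
    content: str  # the message text
```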
@@ -35,8 +37,8 @@ def call_ai_function(
     # For each arg, if any are None, convert to "None":
     args = [str(arg) if arg is not None else "None" for arg in args]
     # parse args to comma separated string
-    args = ", ".join(args)
-    messages = [
+    args: str = ", ".join(args)
+    messages: List[Message] = [
         {
             "role": "system",
             "content": f"You are now the following python function: ```# {description}"
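The two new annotations (`args: str` and `messages: List[Message]`) make explicit what flows into the chat call. Below is a minimal sketch of the same pattern, with `build_function_messages` as a hypothetical helper name, a `Dict[str, str]` stand-in for `plugin_template.Message`, and a simplified prompt; the real `call_ai_function` wraps the function description in a Markdown code fence and, in this sketch, the joined arguments become the user message.

```python
from typing import Dict, List

Message = Dict[str, str]  # stand-in for plugin_template.Message in this sketch


def build_function_messages(description: str, args: list) -> List[Message]:
    # Normalise the arguments exactly as the hunk above does, then join them.
    normalised = [str(arg) if arg is not None else "None" for arg in args]
    arg_string: str = ", ".join(normalised)
    return [
        {
            "role": "system",
            "content": f"You are now the following python function: {description}",
        },
        {"role": "user", "content": arg_string},
    ]
```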
@@ -51,15 +53,15 @@ def call_ai_function(
 # Overly simple abstraction until we create something better
 # simple retry mechanism when getting a rate error or a bad gateway
 def create_chat_completion(
-    messages: list,  # type: ignore
-    model: str | None = None,
+    messages: List[Message],  # type: ignore
+    model: Optional[str] = None,
     temperature: float = CFG.temperature,
-    max_tokens: int | None = None,
+    max_tokens: Optional[int] = None,
 ) -> str:
     """Create a chat completion using the OpenAI API

     Args:
-        messages (list[dict[str, str]]): The messages to send to the chat completion
+        messages (List[Message]): The messages to send to the chat completion
         model (str, optional): The model to use. Defaults to None.
         temperature (float, optional): The temperature to use. Defaults to 0.9.
         max_tokens (int, optional): The max tokens to use. Defaults to None.
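The comment kept by this hunk ("simple retry mechanism when getting a rate error or a bad gateway") describes what `create_chat_completion` wraps around the OpenAI call. Here is a hedged sketch of that retry pattern, reusing the `openai`, `RateLimitError`, and `APIError` imports from the first hunk; `chat_with_retries`, its backoff parameters, and the `Message` stand-in are hypothetical, and the real function adds logging, config defaults, and more.

```python
import time
from typing import Dict, List, Optional

import openai
from openai.error import APIError, RateLimitError

Message = Dict[str, str]  # stand-in for plugin_template.Message in this sketch


def chat_with_retries(
    messages: List[Message],
    model: str,
    temperature: float = 0.0,
    max_tokens: Optional[int] = None,
    num_retries: int = 5,
) -> str:
    """Sketch of the retry idea: back off and retry on rate limits or bad gateways."""
    backoff = 1.0
    for attempt in range(num_retries):
        kwargs = {"model": model, "messages": messages, "temperature": temperature}
        if max_tokens is not None:
            kwargs["max_tokens"] = max_tokens
        try:
            response = openai.ChatCompletion.create(**kwargs)
            return response.choices[0].message["content"]
        except (RateLimitError, APIError):
            # RateLimitError: rate limited; APIError: e.g. a bad gateway.
            if attempt == num_retries - 1:
                raise
            time.sleep(backoff)
            backoff *= 2
    raise RuntimeError("unreachable")
```

Switching from `str | None` to `Optional[str]` also keeps the annotations valid on interpreters older than Python 3.10 whenever they are evaluated at runtime, which is likely why the change was made.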
@@ -6,6 +6,8 @@ from pathlib import Path
 from typing import List, Optional, Tuple
 from zipimport import zipimporter

+from plugin_template import AutoGPTPluginTemplate
+

 def inspect_zip_for_module(zip_path: str, debug: bool = False) -> Optional[str]:
     """
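This hunk appears to be from the plugin loader and only adds the `AutoGPTPluginTemplate` import; `inspect_zip_for_module` itself is unchanged. Purely as an illustration of what such an inspection could do (a hypothetical `find_module_in_zip`, not the project's implementation), the standard-library `zipfile` module is enough to look for a package `__init__.py` inside a plugin archive.

```python
import zipfile
from typing import Optional


def find_module_in_zip(zip_path: str, debug: bool = False) -> Optional[str]:
    """Return the archive path of a package __init__.py, or None if none is found."""
    with zipfile.ZipFile(zip_path, "r") as archive:
        for name in archive.namelist():
            if name.endswith("__init__.py"):
                if debug:
                    print(f"Found module '{name}' in zip '{zip_path}'")
                return name
    if debug:
        print(f"No module found in zip '{zip_path}'")
    return None
```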
@@ -45,7 +47,9 @@ def scan_plugins(plugins_path: Path, debug: bool = False) -> List[Tuple[str, Pat
     return plugins


-def load_plugins(plugins_path: Path, debug: bool = False) -> List[Module]:
+def load_plugins(
+    plugins_path: Path, debug: bool = False
+) -> List[AutoGPTPluginTemplate]:
     """Load plugins from the plugins directory.

     Args:
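`load_plugins` now promises instantiated plugins (`List[AutoGPTPluginTemplate]`) instead of a `List[Module]`. A minimal sketch of what that return type implies, assuming the `plugin_template` package is importable; `instantiate_plugins` is a hypothetical helper, and the real loader also handles zip imports, debug output, and error cases.

```python
import inspect
from types import ModuleType
from typing import List

from plugin_template import AutoGPTPluginTemplate  # same import as in the hunk above


def instantiate_plugins(module: ModuleType) -> List[AutoGPTPluginTemplate]:
    """Collect an instance of every AutoGPTPluginTemplate subclass defined in a module."""
    plugins: List[AutoGPTPluginTemplate] = []
    for _, obj in inspect.getmembers(module, inspect.isclass):
        if issubclass(obj, AutoGPTPluginTemplate) and obj is not AutoGPTPluginTemplate:
            plugins.append(obj())
    return plugins
```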
@@ -1,13 +1,16 @@
 """Functions for counting the number of tokens in a message or string."""
 from __future__ import annotations
+from typing import List

 import tiktoken

 from autogpt.logs import logger

+from plugin_template import Message
+

 def count_message_tokens(
-    messages: list[dict[str, str]], model: str = "gpt-3.5-turbo-0301"
+    messages: List[Message], model: str = "gpt-3.5-turbo-0301"
 ) -> int:
     """
     Returns the number of tokens used by a list of messages.
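The token counter gets the same treatment: `count_message_tokens` now takes `List[Message]`. Below is a hedged sketch of the counting logic behind it using `tiktoken`; `rough_message_token_count` is a hypothetical name, the overhead constants follow OpenAI's published guidance for `gpt-3.5-turbo-0301`, and the real function chooses them per model.

```python
from typing import Dict, List

import tiktoken

Message = Dict[str, str]  # stand-in for plugin_template.Message in this sketch


def rough_message_token_count(
    messages: List[Message], model: str = "gpt-3.5-turbo-0301"
) -> int:
    """Approximate the tokens a list of chat messages consumes for the given model."""
    encoding = tiktoken.encoding_for_model(model)
    tokens_per_message = 4  # every message carries a small fixed overhead
    total = 0
    for message in messages:
        total += tokens_per_message
        for value in message.values():
            total += len(encoding.encode(value))
    return total + 3  # replies are primed with a few extra tokens
```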