From b8478a96aea94f9f9f34cf966a0f7ef27c6d04c4 Mon Sep 17 00:00:00 2001 From: James Collins Date: Fri, 28 Apr 2023 15:04:31 -0700 Subject: [PATCH] Feature/llm data structs (#3486) * Organize all the llm stuff into a subpackage * Add structs for interacting with llms --- autogpt/agent/agent_manager.py | 3 +- autogpt/llm/__init__.py | 16 ++++++++ autogpt/llm/base.py | 65 +++++++++++++++++++++++++++++++ autogpt/llm/chat.py | 2 +- autogpt/llm/llm_utils.py | 2 +- autogpt/llm/providers/__init__.py | 0 autogpt/llm/providers/openai.py | 37 ++++++++++++++++++ autogpt/llm/token_counter.py | 2 +- autogpt/types/openai.py | 9 ----- 9 files changed, 122 insertions(+), 14 deletions(-) create mode 100644 autogpt/llm/base.py create mode 100644 autogpt/llm/providers/__init__.py create mode 100644 autogpt/llm/providers/openai.py delete mode 100644 autogpt/types/openai.py diff --git a/autogpt/agent/agent_manager.py b/autogpt/agent/agent_manager.py index 1283fdae..17fb35d8 100644 --- a/autogpt/agent/agent_manager.py +++ b/autogpt/agent/agent_manager.py @@ -4,9 +4,8 @@ from __future__ import annotations from typing import List from autogpt.config.config import Config -from autogpt.llm import create_chat_completion +from autogpt.llm import Message, create_chat_completion from autogpt.singleton import Singleton -from autogpt.types.openai import Message class AgentManager(metaclass=Singleton): diff --git a/autogpt/llm/__init__.py b/autogpt/llm/__init__.py index 3a958285..2a6f0b8f 100644 --- a/autogpt/llm/__init__.py +++ b/autogpt/llm/__init__.py @@ -1,4 +1,13 @@ from autogpt.llm.api_manager import ApiManager +from autogpt.llm.base import ( + ChatModelInfo, + ChatModelResponse, + EmbeddingModelInfo, + EmbeddingModelResponse, + LLMResponse, + Message, + ModelInfo, +) from autogpt.llm.chat import chat_with_ai, create_chat_message, generate_context from autogpt.llm.llm_utils import ( call_ai_function, @@ -10,6 +19,13 @@ from autogpt.llm.token_counter import count_message_tokens, count_string_tokens __all__ = [ "ApiManager", + "Message", + "ModelInfo", + "ChatModelInfo", + "EmbeddingModelInfo", + "LLMResponse", + "ChatModelResponse", + "EmbeddingModelResponse", "create_chat_message", "generate_context", "chat_with_ai", diff --git a/autogpt/llm/base.py b/autogpt/llm/base.py new file mode 100644 index 00000000..722e0f0f --- /dev/null +++ b/autogpt/llm/base.py @@ -0,0 +1,65 @@ +from dataclasses import dataclass, field +from typing import List, TypedDict + + +class Message(TypedDict): + """OpenAI Message object containing a role and the message content""" + + role: str + content: str + + +@dataclass +class ModelInfo: + """Struct for model information. + + Would be lovely to eventually get this directly from APIs, but needs to be scraped from + websites for now. + + """ + + name: str + prompt_token_cost: float + completion_token_cost: float + max_tokens: int + + +@dataclass +class ChatModelInfo(ModelInfo): + """Struct for chat model information.""" + + pass + + +@dataclass +class EmbeddingModelInfo(ModelInfo): + """Struct for embedding model information.""" + + embedding_dimensions: int + + +@dataclass +class LLMResponse: + """Standard response struct for a response from an LLM model.""" + + model_info: ModelInfo + prompt_tokens_used: int = 0 + completion_tokens_used: int = 0 + + +@dataclass +class EmbeddingModelResponse(LLMResponse): + """Standard response struct for a response from an embedding model.""" + + embedding: List[float] = field(default_factory=list) + + def __post_init__(self): + if self.completion_tokens_used: + raise ValueError("Embeddings should not have completion tokens used.") + + +@dataclass +class ChatModelResponse(LLMResponse): + """Standard response struct for a response from an LLM model.""" + + content: str = None diff --git a/autogpt/llm/chat.py b/autogpt/llm/chat.py index 119468c3..e0f0226d 100644 --- a/autogpt/llm/chat.py +++ b/autogpt/llm/chat.py @@ -5,13 +5,13 @@ from openai.error import RateLimitError from autogpt.config import Config from autogpt.llm.api_manager import ApiManager +from autogpt.llm.base import Message from autogpt.llm.llm_utils import create_chat_completion from autogpt.llm.token_counter import count_message_tokens from autogpt.logs import logger from autogpt.memory_management.store_memory import ( save_memory_trimmed_from_context_window, ) -from autogpt.types.openai import Message cfg = Config() diff --git a/autogpt/llm/llm_utils.py b/autogpt/llm/llm_utils.py index c1ba5fa5..9a2400c7 100644 --- a/autogpt/llm/llm_utils.py +++ b/autogpt/llm/llm_utils.py @@ -10,8 +10,8 @@ from openai.error import APIError, RateLimitError, Timeout from autogpt.config import Config from autogpt.llm.api_manager import ApiManager +from autogpt.llm.base import Message from autogpt.logs import logger -from autogpt.types.openai import Message def retry_openai_api( diff --git a/autogpt/llm/providers/__init__.py b/autogpt/llm/providers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/autogpt/llm/providers/openai.py b/autogpt/llm/providers/openai.py new file mode 100644 index 00000000..188d5cf7 --- /dev/null +++ b/autogpt/llm/providers/openai.py @@ -0,0 +1,37 @@ +from autogpt.llm.base import ChatModelInfo, EmbeddingModelInfo + +OPEN_AI_CHAT_MODELS = { + "gpt-3.5-turbo": ChatModelInfo( + name="gpt-3.5-turbo", + prompt_token_cost=0.002, + completion_token_cost=0.002, + max_tokens=4096, + ), + "gpt-4": ChatModelInfo( + name="gpt-4", + prompt_token_cost=0.03, + completion_token_cost=0.06, + max_tokens=8192, + ), + "gpt-4-32k": ChatModelInfo( + name="gpt-4-32k", + prompt_token_cost=0.06, + completion_token_cost=0.12, + max_tokens=32768, + ), +} + +OPEN_AI_EMBEDDING_MODELS = { + "text-embedding-ada-002": EmbeddingModelInfo( + name="text-embedding-ada-002", + prompt_token_cost=0.0004, + completion_token_cost=0.0, + max_tokens=8191, + embedding_dimensions=1536, + ), +} + +OPEN_AI_MODELS = { + **OPEN_AI_CHAT_MODELS, + **OPEN_AI_EMBEDDING_MODELS, +} diff --git a/autogpt/llm/token_counter.py b/autogpt/llm/token_counter.py index 2d50547b..5e13920e 100644 --- a/autogpt/llm/token_counter.py +++ b/autogpt/llm/token_counter.py @@ -5,8 +5,8 @@ from typing import List import tiktoken +from autogpt.llm.base import Message from autogpt.logs import logger -from autogpt.types.openai import Message def count_message_tokens( diff --git a/autogpt/types/openai.py b/autogpt/types/openai.py deleted file mode 100644 index 2af85785..00000000 --- a/autogpt/types/openai.py +++ /dev/null @@ -1,9 +0,0 @@ -"""Type helpers for working with the OpenAI library""" -from typing import TypedDict - - -class Message(TypedDict): - """OpenAI Message object containing a role and the message content""" - - role: str - content: str