Feature/llm data structs (#3486)

* Organize all the llm stuff into a subpackage

* Add structs for interacting with llms
This commit is contained in:
James Collins
2023-04-28 15:04:31 -07:00
committed by GitHub
parent c7d75643d3
commit b8478a96ae
9 changed files with 122 additions and 14 deletions

View File

@@ -4,9 +4,8 @@ from __future__ import annotations
from typing import List
from autogpt.config.config import Config
from autogpt.llm import create_chat_completion
from autogpt.llm import Message, create_chat_completion
from autogpt.singleton import Singleton
from autogpt.types.openai import Message
class AgentManager(metaclass=Singleton):

View File

@@ -1,4 +1,13 @@
from autogpt.llm.api_manager import ApiManager
from autogpt.llm.base import (
ChatModelInfo,
ChatModelResponse,
EmbeddingModelInfo,
EmbeddingModelResponse,
LLMResponse,
Message,
ModelInfo,
)
from autogpt.llm.chat import chat_with_ai, create_chat_message, generate_context
from autogpt.llm.llm_utils import (
call_ai_function,
@@ -10,6 +19,13 @@ from autogpt.llm.token_counter import count_message_tokens, count_string_tokens
__all__ = [
"ApiManager",
"Message",
"ModelInfo",
"ChatModelInfo",
"EmbeddingModelInfo",
"LLMResponse",
"ChatModelResponse",
"EmbeddingModelResponse",
"create_chat_message",
"generate_context",
"chat_with_ai",

65
autogpt/llm/base.py Normal file
View File

@@ -0,0 +1,65 @@
from dataclasses import dataclass, field
from typing import List, TypedDict
class Message(TypedDict):
"""OpenAI Message object containing a role and the message content"""
role: str
content: str
@dataclass
class ModelInfo:
"""Struct for model information.
Would be lovely to eventually get this directly from APIs, but needs to be scraped from
websites for now.
"""
name: str
prompt_token_cost: float
completion_token_cost: float
max_tokens: int
@dataclass
class ChatModelInfo(ModelInfo):
"""Struct for chat model information."""
pass
@dataclass
class EmbeddingModelInfo(ModelInfo):
"""Struct for embedding model information."""
embedding_dimensions: int
@dataclass
class LLMResponse:
"""Standard response struct for a response from an LLM model."""
model_info: ModelInfo
prompt_tokens_used: int = 0
completion_tokens_used: int = 0
@dataclass
class EmbeddingModelResponse(LLMResponse):
"""Standard response struct for a response from an embedding model."""
embedding: List[float] = field(default_factory=list)
def __post_init__(self):
if self.completion_tokens_used:
raise ValueError("Embeddings should not have completion tokens used.")
@dataclass
class ChatModelResponse(LLMResponse):
"""Standard response struct for a response from an LLM model."""
content: str = None

View File

@@ -5,13 +5,13 @@ from openai.error import RateLimitError
from autogpt.config import Config
from autogpt.llm.api_manager import ApiManager
from autogpt.llm.base import Message
from autogpt.llm.llm_utils import create_chat_completion
from autogpt.llm.token_counter import count_message_tokens
from autogpt.logs import logger
from autogpt.memory_management.store_memory import (
save_memory_trimmed_from_context_window,
)
from autogpt.types.openai import Message
cfg = Config()

View File

@@ -10,8 +10,8 @@ from openai.error import APIError, RateLimitError, Timeout
from autogpt.config import Config
from autogpt.llm.api_manager import ApiManager
from autogpt.llm.base import Message
from autogpt.logs import logger
from autogpt.types.openai import Message
def retry_openai_api(

View File

View File

@@ -0,0 +1,37 @@
from autogpt.llm.base import ChatModelInfo, EmbeddingModelInfo
OPEN_AI_CHAT_MODELS = {
"gpt-3.5-turbo": ChatModelInfo(
name="gpt-3.5-turbo",
prompt_token_cost=0.002,
completion_token_cost=0.002,
max_tokens=4096,
),
"gpt-4": ChatModelInfo(
name="gpt-4",
prompt_token_cost=0.03,
completion_token_cost=0.06,
max_tokens=8192,
),
"gpt-4-32k": ChatModelInfo(
name="gpt-4-32k",
prompt_token_cost=0.06,
completion_token_cost=0.12,
max_tokens=32768,
),
}
OPEN_AI_EMBEDDING_MODELS = {
"text-embedding-ada-002": EmbeddingModelInfo(
name="text-embedding-ada-002",
prompt_token_cost=0.0004,
completion_token_cost=0.0,
max_tokens=8191,
embedding_dimensions=1536,
),
}
OPEN_AI_MODELS = {
**OPEN_AI_CHAT_MODELS,
**OPEN_AI_EMBEDDING_MODELS,
}

View File

@@ -5,8 +5,8 @@ from typing import List
import tiktoken
from autogpt.llm.base import Message
from autogpt.logs import logger
from autogpt.types.openai import Message
def count_message_tokens(

View File

@@ -1,9 +0,0 @@
"""Type helpers for working with the OpenAI library"""
from typing import TypedDict
class Message(TypedDict):
"""OpenAI Message object containing a role and the message content"""
role: str
content: str