Auto-GPT/autogpt/llm/base.py
Reinier van der Leer 51d8b43fbf Improve token counting; account for functions (#4919)
* Improvements to token counting, including functions
---------

Co-authored-by: James Collins <collijk@uw.edu>
2023-07-09 11:31:18 -07:00


from __future__ import annotations

from copy import deepcopy
from dataclasses import dataclass, field
from math import ceil, floor
from typing import TYPE_CHECKING, Literal, Optional, Type, TypedDict, TypeVar, overload

if TYPE_CHECKING:
    from autogpt.llm.providers.openai import OpenAIFunctionCall

MessageRole = Literal["system", "user", "assistant", "function"]
MessageType = Literal["ai_response", "action_result"]

TText = list[int]
"""Token array representing tokenized text"""
class MessageDict(TypedDict):
    role: MessageRole
    content: str


class ResponseMessageDict(TypedDict):
    role: Literal["assistant"]
    content: Optional[str]
    function_call: Optional[FunctionCallDict]


class FunctionCallDict(TypedDict):
    name: str
    arguments: str
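# Note (added comment): ResponseMessageDict refers to FunctionCallDict before it
# is defined; this is valid because `from __future__ import annotations` turns
# all annotations into lazily evaluated strings.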
@dataclass
class Message:
    """OpenAI Message object containing a role and the message content"""

    role: MessageRole
    content: str
    type: MessageType | None = None

    def raw(self) -> MessageDict:
        return {"role": self.role, "content": self.content}
@dataclass
class ModelInfo:
    """Struct for model information.

    Would be lovely to eventually get this directly from APIs, but needs to be
    scraped from websites for now.
    """

    name: str
    max_tokens: int
    prompt_token_cost: float


@dataclass
class CompletionModelInfo(ModelInfo):
    """Struct for generic completion model information."""

    completion_token_cost: float


@dataclass
class ChatModelInfo(CompletionModelInfo):
    """Struct for chat model information."""


@dataclass
class TextModelInfo(CompletionModelInfo):
    """Struct for text completion model information."""


@dataclass
class EmbeddingModelInfo(ModelInfo):
    """Struct for embedding model information."""

    embedding_dimensions: int
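# Illustrative declaration (added sketch; the field values here are
# hypothetical, not actual pricing):
#
#     gpt35 = ChatModelInfo(
#         name="gpt-3.5-turbo",
#         max_tokens=4096,
#         prompt_token_cost=0.0015,
#         completion_token_cost=0.002,
#     )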
# Can be replaced by Self in Python 3.11
TChatSequence = TypeVar("TChatSequence", bound="ChatSequence")


@dataclass
class ChatSequence:
    """Utility container for a chat sequence"""

    model: ChatModelInfo
    messages: list[Message] = field(default_factory=list[Message])

    @overload
    def __getitem__(self, key: int) -> Message:
        ...

    @overload
    def __getitem__(self: TChatSequence, key: slice) -> TChatSequence:
        ...

    def __getitem__(self: TChatSequence, key: int | slice) -> Message | TChatSequence:
        if isinstance(key, slice):
            copy = deepcopy(self)
            copy.messages = self.messages[key]
            return copy
        return self.messages[key]
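    # Note (added comment): slicing deep-copies the container but then points it
    # at a plain slice of the original list, so the returned sequence still
    # shares its Message objects with the source sequence.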
    def __iter__(self):
        return iter(self.messages)

    def __len__(self):
        return len(self.messages)

    def add(
        self,
        message_role: MessageRole,
        content: str,
        type: MessageType | None = None,
    ) -> None:
        self.append(Message(message_role, content, type))

    def append(self, message: Message):
        return self.messages.append(message)

    def extend(self, messages: list[Message] | ChatSequence):
        return self.messages.extend(messages)

    def insert(self, index: int, *messages: Message):
        for message in reversed(messages):
            self.messages.insert(index, message)
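    # Note (added comment): `insert` iterates in reverse so that inserting
    # several messages at one index preserves their given order, e.g.
    # seq.insert(0, a, b) yields [a, b, ...] rather than [b, a, ...].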
    @classmethod
    def for_model(
        cls: Type[TChatSequence],
        model_name: str,
        messages: list[Message] | ChatSequence = [],
        **kwargs,
    ) -> TChatSequence:
        from autogpt.llm.providers.openai import OPEN_AI_CHAT_MODELS

        if model_name not in OPEN_AI_CHAT_MODELS:
            raise ValueError(f"Unknown chat model '{model_name}'")

        return cls(
            model=OPEN_AI_CHAT_MODELS[model_name], messages=list(messages), **kwargs
        )
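    # Usage sketch (added note, illustrative; assumes "gpt-3.5-turbo" is a key
    # in OPEN_AI_CHAT_MODELS):
    #
    #     prompt = ChatSequence.for_model(
    #         "gpt-3.5-turbo",
    #         [Message("system", "You are a helpful assistant.")],
    #     )
    #     prompt.add("user", "What is 2 + 2?")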
    @property
    def token_length(self) -> int:
        from autogpt.llm.utils import count_message_tokens

        return count_message_tokens(self.messages, self.model.name)

    def raw(self) -> list[MessageDict]:
        return [m.raw() for m in self.messages]

    def dump(self) -> str:
        SEPARATOR_LENGTH = 42

        def separator(text: str):
            half_sep_len = (SEPARATOR_LENGTH - 2 - len(text)) / 2
            return f"{floor(half_sep_len)*'-'} {text.upper()} {ceil(half_sep_len)*'-'}"
        formatted_messages = "\n".join(
            [f"{separator(m.role)}\n{m.content}" for m in self.messages]
        )
        return f"""
============== {__class__.__name__} ==============
Length: {self.token_length} tokens; {len(self.messages)} messages
{formatted_messages}
==========================================
"""
@dataclass
class LLMResponse:
    """Standard response struct for a response from an LLM model."""

    model_info: ModelInfo


@dataclass
class EmbeddingModelResponse(LLMResponse):
    """Standard response struct for a response from an embedding model."""

    embedding: list[float] = field(default_factory=list)


@dataclass
class ChatModelResponse(LLMResponse):
    """Standard response struct for a response from a chat LLM."""

    content: Optional[str]
    function_call: Optional[OpenAIFunctionCall]
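# Illustrative shape of a function-call response (added sketch; assumes
# OpenAIFunctionCall exposes `name` and `arguments` fields mirroring
# FunctionCallDict above):
#
#     response = ChatModelResponse(
#         model_info=OPEN_AI_CHAT_MODELS["gpt-3.5-turbo"],
#         content=None,
#         function_call=OpenAIFunctionCall(
#             name="get_weather", arguments='{"city": "Paris"}'
#         ),
#     )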