Improve token counting; account for functions (#4919)

* Improvements to token counting, including functions
---------

Co-authored-by: James Collins <collijk@uw.edu>
Authored by Reinier van der Leer on 2023-07-09 20:31:18 +02:00, committed by GitHub
parent e8b6676b22
commit 51d8b43fbf
4 changed files with 138 additions and 46 deletions

View File

@@ -20,6 +20,17 @@ class MessageDict(TypedDict):
content: str
class ResponseMessageDict(TypedDict):
role: Literal["assistant"]
content: Optional[str]
function_call: Optional[FunctionCallDict]
class FunctionCallDict(TypedDict):
name: str
arguments: str
@dataclass
class Message:
"""OpenAI Message object containing a role and the message content"""
@@ -167,8 +178,6 @@ class LLMResponse:
"""Standard response struct for a response from an LLM model."""
model_info: ModelInfo
prompt_tokens_used: int = 0
completion_tokens_used: int = 0
@dataclass
@@ -177,14 +186,10 @@ class EmbeddingModelResponse(LLMResponse):
embedding: list[float] = field(default_factory=list)
def __post_init__(self):
if self.completion_tokens_used:
raise ValueError("Embeddings should not have completion tokens used.")
@dataclass
class ChatModelResponse(LLMResponse):
"""Standard response struct for a response from an LLM model."""
"""Standard response struct for a response from a chat LLM."""
content: Optional[str] = None
function_call: Optional[OpenAIFunctionCall] = None
content: Optional[str]
function_call: Optional[OpenAIFunctionCall]
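For reference, a minimal standalone sketch (not part of the diff) of the message shape the new `ResponseMessageDict` and `FunctionCallDict` typed dicts describe; the weather payload is made up for illustration.

```python
from typing import Literal, Optional, TypedDict


class FunctionCallDict(TypedDict):
    name: str
    arguments: str  # JSON-encoded arguments, exactly as the API returns them


class ResponseMessageDict(TypedDict):
    role: Literal["assistant"]
    content: Optional[str]
    function_call: Optional[FunctionCallDict]


# When the model decides to call a function, `content` is typically None and
# `function_call` carries the function name and its JSON-encoded arguments.
reply: ResponseMessageDict = {
    "role": "assistant",
    "content": None,
    "function_call": {
        "name": "get_current_weather",
        "arguments": '{"location": "San Francisco, CA", "unit": "celsius"}',
    },
}
```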

View File

@@ -3,14 +3,16 @@ from __future__ import annotations
import time
from typing import TYPE_CHECKING
-from autogpt.llm.providers.openai import get_openai_command_specs
if TYPE_CHECKING:
from autogpt.agent.agent import Agent
from autogpt.config import Config
from autogpt.llm.api_manager import ApiManager
from autogpt.llm.base import ChatSequence, Message
+from autogpt.llm.providers.openai import (
+    count_openai_functions_tokens,
+    get_openai_command_specs,
+)
from autogpt.llm.utils import count_message_tokens, create_chat_completion
from autogpt.logs import CURRENT_CONTEXT_FILE_NAME, logger
@@ -72,23 +74,17 @@ def chat_with_ai(
],
)
-# Add messages from the full message history until we reach the token limit
-next_message_to_add_index = len(agent.history) - 1
-insertion_index = len(message_sequence)
# Count the currently used tokens
current_tokens_used = message_sequence.token_length
+insertion_index = len(message_sequence)
-# while current_tokens_used > 2500:
-#     # remove memories until we are under 2500 tokens
-#     relevant_memory = relevant_memory[:-1]
-#     (
-#         next_message_to_add_index,
-#         current_tokens_used,
-#         insertion_index,
-#         current_context,
-#     ) = generate_context(
-#         prompt, relevant_memory, agent.history, model
-#     )
+# Account for tokens used by OpenAI functions
+openai_functions = None
+if agent.config.openai_functions:
+    openai_functions = get_openai_command_specs(agent.command_registry)
+    functions_tlength = count_openai_functions_tokens(openai_functions, model)
+    current_tokens_used += functions_tlength
+    logger.debug(f"OpenAI Functions take up {functions_tlength} tokens in API call")
# Account for user input (appended later)
user_input_msg = Message("user", triggering_prompt)
@@ -97,7 +93,8 @@ def chat_with_ai(
current_tokens_used += agent.history.max_summary_tlength # Reserve space
current_tokens_used += 500 # Reserve space for the openai functions TODO improve
-# Add Messages until the token limit is reached or there are no more messages to add.
+# Add historical Messages until the token limit is reached
+# or there are no more messages to add.
for cycle in reversed(list(agent.history.per_cycle())):
messages_to_add = [msg for msg in cycle if msg is not None]
tokens_to_add = count_message_tokens(messages_to_add, model)
@@ -162,6 +159,8 @@ def chat_with_ai(
logger.debug(f"Plugins remaining at stop: {plugin_count - i}")
break
message_sequence.add("system", plugin_response)
+current_tokens_used += tokens_to_add
# Calculate remaining tokens
tokens_remaining = token_limit - current_tokens_used
# assert tokens_remaining >= 0, "Tokens remaining is negative.
@@ -193,7 +192,7 @@ def chat_with_ai(
assistant_reply = create_chat_completion(
prompt=message_sequence,
config=agent.config,
-functions=get_openai_command_specs(agent),
+functions=openai_functions,
max_tokens=tokens_remaining,
)
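To make the budgeting above concrete, here is a small self-contained sketch of the same bookkeeping: reserve tokens for the prompt, the function definitions and the user input first, then fill what remains with history, newest cycles first. The counter and the numbers are stand-ins, not the real count_message_tokens / count_openai_functions_tokens.

```python
def count_tokens(messages: list[str]) -> int:
    # crude stand-in for count_message_tokens()
    return sum(len(m.split()) for m in messages)


token_limit = 100
current_tokens_used = 25   # system prompt etc.
current_tokens_used += 12  # stand-in for count_openai_functions_tokens(...)
current_tokens_used += 5   # user input message (appended later)
current_tokens_used += 20  # summary reserve, like agent.history.max_summary_tlength

history = [["step 1 result"], ["step 2 result"], ["step 3 result"]]  # oldest first

added: list[str] = []
for cycle in reversed(history):  # newest cycles first, as in chat_with_ai
    tokens_to_add = count_tokens(cycle)
    if current_tokens_used + tokens_to_add > token_limit:
        break
    added = cycle + added  # keep chronological order in the final prompt
    current_tokens_used += tokens_to_add

tokens_remaining = token_limit - current_tokens_used
print(f"{len(added)} history messages added, {tokens_remaining} tokens left for the reply")
```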

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import functools
import time
from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, Optional
+from typing import List, Optional
from unittest.mock import patch
import openai
@@ -12,9 +12,6 @@ from colorama import Fore, Style
from openai.error import APIError, RateLimitError, ServiceUnavailableError, Timeout
from openai.openai_object import OpenAIObject
-if TYPE_CHECKING:
-    from autogpt.agent.agent import Agent
from autogpt.llm.base import (
ChatModelInfo,
EmbeddingModelInfo,
@@ -23,6 +20,7 @@ from autogpt.llm.base import (
TText,
)
from autogpt.logs import logger
+from autogpt.models.command_registry import CommandRegistry
OPEN_AI_CHAT_MODELS = {
info.name: info
@@ -301,13 +299,13 @@ class OpenAIFunctionSpec:
@dataclass
class ParameterSpec:
name: str
-type: str
+type: str  # TODO: add enum support
description: Optional[str]
required: bool = False
@property
-def __dict__(self):
-    """Output an OpenAI-consumable function specification"""
+def schema(self) -> dict[str, str | dict | list]:
+    """Returns an OpenAI-consumable function specification"""
return {
"name": self.name,
"description": self.description,
@@ -326,14 +324,44 @@ class OpenAIFunctionSpec:
},
}
@property
def prompt_format(self) -> str:
"""Returns the function formatted similarly to the way OpenAI does it internally:
https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/18
Example:
```ts
// Get the current weather in a given location
type get_current_weather = (_: {
// The city and state, e.g. San Francisco, CA
location: string,
unit?: "celsius" | "fahrenheit",
}) => any;
```
"""
def param_signature(p_spec: OpenAIFunctionSpec.ParameterSpec) -> str:
# TODO: enum type support
return (
f"// {p_spec.description}\n" if p_spec.description else ""
) + f"{p_spec.name}{'' if p_spec.required else '?'}: {p_spec.type},"
return "\n".join(
[
f"// {self.description}",
f"type {self.name} = (_ :{{",
*[param_signature(p) for p in self.parameters.values()],
"}) => any;",
]
)
-def get_openai_command_specs(agent: Agent) -> list[OpenAIFunctionSpec]:
+def get_openai_command_specs(
+    command_registry: CommandRegistry,
+) -> list[OpenAIFunctionSpec]:
"""Get OpenAI-consumable function specs for the agent's available commands.
see https://platform.openai.com/docs/guides/gpt/function-calling
"""
-if not agent.config.openai_functions:
-    return []
return [
OpenAIFunctionSpec(
name=command.name,
@@ -348,5 +376,48 @@ def get_openai_command_specs(agent: Agent) -> list[OpenAIFunctionSpec]:
for param in command.parameters
},
)
-for command in agent.command_registry.commands.values()
+for command in command_registry.commands.values()
]
def count_openai_functions_tokens(
functions: list[OpenAIFunctionSpec], for_model: str
) -> int:
"""Returns the number of tokens taken up by a set of function definitions
Reference: https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/18
"""
from autogpt.llm.utils import count_string_tokens
return count_string_tokens(
f"# Tools\n\n## functions\n\n{format_function_specs_as_typescript_ns(functions)}",
for_model,
)
def format_function_specs_as_typescript_ns(functions: list[OpenAIFunctionSpec]) -> str:
"""Returns a function signature block in the format used by OpenAI internally:
https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/18
For use with `count_string_tokens` to determine token usage of provided functions.
Example:
```ts
namespace functions {
// Get the current weather in a given location
type get_current_weather = (_: {
// The city and state, e.g. San Francisco, CA
location: string,
unit?: "celsius" | "fahrenheit",
}) => any;
} // namespace functions
```
"""
return (
"namespace functions {\n\n"
+ "\n\n".join(f.prompt_format for f in functions)
+ "\n\n} // namespace functions"
)
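The counting approach above can be sanity-checked outside the codebase. A sketch, assuming tiktoken is installed (the library count_string_tokens is built on): render the signatures in the TypeScript-namespace style shown in the docstring and count the tokens of the resulting string. The weather function is the docstring's own example.

```python
import tiktoken  # assumed available

functions_block = """namespace functions {

// Get the current weather in a given location
type get_current_weather = (_: {
// The city and state, e.g. San Francisco, CA
location: string,
unit?: "celsius" | "fahrenheit",
}) => any;

} // namespace functions"""

# Same overhead text as count_openai_functions_tokens wraps around the namespace
prompt_text = f"# Tools\n\n## functions\n\n{functions_block}"
encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
print(len(encoding.encode(prompt_text)), "tokens reserved for function definitions")
```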

View File

@@ -7,12 +7,19 @@ from colorama import Fore
from autogpt.config import Config
from ..api_manager import ApiManager
-from ..base import ChatModelResponse, ChatSequence, Message
+from ..base import (
+    ChatModelResponse,
+    ChatSequence,
+    FunctionCallDict,
+    Message,
+    ResponseMessageDict,
+)
from ..providers import openai as iopenai
from ..providers.openai import (
OPEN_AI_CHAT_MODELS,
OpenAIFunctionCall,
OpenAIFunctionSpec,
+count_openai_functions_tokens,
)
from .token_counter import *
@@ -114,7 +121,13 @@ def create_chat_completion(
if temperature is None:
temperature = config.temperature
if max_tokens is None:
-max_tokens = OPEN_AI_CHAT_MODELS[model].max_tokens - prompt.token_length
+prompt_tlength = prompt.token_length
+max_tokens = OPEN_AI_CHAT_MODELS[model].max_tokens - prompt_tlength
+logger.debug(f"Prompt length: {prompt_tlength} tokens")
+if functions:
+    functions_tlength = count_openai_functions_tokens(functions, model)
+    max_tokens -= functions_tlength
+    logger.debug(f"Functions take up {functions_tlength} tokens in API call")
logger.debug(
f"{Fore.GREEN}Creating chat completion with model {model}, temperature {temperature}, max_tokens {max_tokens}{Fore.RESET}"
@@ -143,9 +156,8 @@ def create_chat_completion(
if functions:
chat_completion_kwargs["functions"] = [
-function.__dict__ for function in functions
+function.schema for function in functions
]
-logger.debug(f"Function dicts: {chat_completion_kwargs['functions']}")
response = iopenai.create_chat_completion(
messages=prompt.raw(),
@@ -157,19 +169,24 @@ def create_chat_completion(
logger.error(response.error)
raise RuntimeError(response.error)
-first_message = response.choices[0].message
+first_message: ResponseMessageDict = response.choices[0].message
content: str | None = first_message.get("content")
-function_call: OpenAIFunctionCall | None = first_message.get("function_call")
+function_call: FunctionCallDict | None = first_message.get("function_call")
for plugin in config.plugins:
if not plugin.can_handle_on_response():
continue
+# TODO: function call support in plugin.on_response()
content = plugin.on_response(content)
return ChatModelResponse(
model_info=OPEN_AI_CHAT_MODELS[model],
content=content,
-function_call=function_call,
+function_call=OpenAIFunctionCall(
+    name=function_call["name"], arguments=function_call["arguments"]
+)
+if function_call
+else None,
)
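A standalone sketch of the new response handling at the end of create_chat_completion: the raw function_call dict from the API is wrapped in a small dataclass before being returned. The dataclass here is a local stand-in with the same two fields as the real OpenAIFunctionCall, and the message payload (including the command name) is invented.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class OpenAIFunctionCall:  # stand-in with the same fields as the real class
    name: str
    arguments: str


raw_message = {  # roughly what response.choices[0].message might look like
    "role": "assistant",
    "content": None,
    "function_call": {"name": "list_files", "arguments": '{"directory": "."}'},
}

function_call: Optional[dict] = raw_message.get("function_call")
parsed = (
    OpenAIFunctionCall(name=function_call["name"], arguments=function_call["arguments"])
    if function_call
    else None
)
print(parsed)
```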