mirror of https://github.com/aljazceru/Auto-GPT.git
synced 2026-01-08 08:44:23 +01:00

Improve token counting; account for functions (#4919)

* Improvements to token counting, including functions

Co-authored-by: James Collins <collijk@uw.edu>

committed by GitHub
parent e8b6676b22
commit 51d8b43fbf
@@ -20,6 +20,17 @@ class MessageDict(TypedDict):
    content: str


class ResponseMessageDict(TypedDict):
    role: Literal["assistant"]
    content: Optional[str]
    function_call: Optional[FunctionCallDict]


class FunctionCallDict(TypedDict):
    name: str
    arguments: str


@dataclass
class Message:
    """OpenAI Message object containing a role and the message content"""
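For context, a hedged illustration of the kind of raw API message these new TypedDicts describe; the command name and arguments below are invented, only the dict shape comes from the diff.

```python
from autogpt.llm.base import ResponseMessageDict

# Example only: an assistant message that requests a function call. Per the
# OpenAI API, "arguments" is a JSON-encoded string, not a nested object.
message: ResponseMessageDict = {
    "role": "assistant",
    "content": None,
    "function_call": {
        "name": "get_current_weather",
        "arguments": '{"location": "San Francisco, CA", "unit": "celsius"}',
    },
}
```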
@@ -167,8 +178,6 @@ class LLMResponse:
    """Standard response struct for a response from an LLM model."""

    model_info: ModelInfo
    prompt_tokens_used: int = 0
    completion_tokens_used: int = 0


@dataclass
@@ -177,14 +186,10 @@ class EmbeddingModelResponse(LLMResponse):

    embedding: list[float] = field(default_factory=list)

    def __post_init__(self):
        if self.completion_tokens_used:
            raise ValueError("Embeddings should not have completion tokens used.")


@dataclass
class ChatModelResponse(LLMResponse):
    """Standard response struct for a response from an LLM model."""
    """Standard response struct for a response from a chat LLM."""

    content: Optional[str] = None
    function_call: Optional[OpenAIFunctionCall] = None
    content: Optional[str]
    function_call: Optional[OpenAIFunctionCall]
@@ -3,14 +3,16 @@ from __future__ import annotations

import time
from typing import TYPE_CHECKING

from autogpt.llm.providers.openai import get_openai_command_specs

if TYPE_CHECKING:
    from autogpt.agent.agent import Agent

from autogpt.config import Config
from autogpt.llm.api_manager import ApiManager
from autogpt.llm.base import ChatSequence, Message
from autogpt.llm.providers.openai import (
    count_openai_functions_tokens,
    get_openai_command_specs,
)
from autogpt.llm.utils import count_message_tokens, create_chat_completion
from autogpt.logs import CURRENT_CONTEXT_FILE_NAME, logger
@@ -72,23 +74,17 @@ def chat_with_ai(
        ],
    )

    # Add messages from the full message history until we reach the token limit
    next_message_to_add_index = len(agent.history) - 1
    insertion_index = len(message_sequence)
    # Count the currently used tokens
    current_tokens_used = message_sequence.token_length
    insertion_index = len(message_sequence)

    # while current_tokens_used > 2500:
    #     # remove memories until we are under 2500 tokens
    #     relevant_memory = relevant_memory[:-1]
    #     (
    #         next_message_to_add_index,
    #         current_tokens_used,
    #         insertion_index,
    #         current_context,
    #     ) = generate_context(
    #         prompt, relevant_memory, agent.history, model
    #     )
    # Account for tokens used by OpenAI functions
    openai_functions = None
    if agent.config.openai_functions:
        openai_functions = get_openai_command_specs(agent.command_registry)
        functions_tlength = count_openai_functions_tokens(openai_functions, model)
        current_tokens_used += functions_tlength
        logger.debug(f"OpenAI Functions take up {functions_tlength} tokens in API call")

    # Account for user input (appended later)
    user_input_msg = Message("user", triggering_prompt)
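For orientation, a rough sketch of the running budget this hunk maintains; every number below is invented, and the reserve values come from the surrounding function rather than from this hunk.

```python
# Illustrative arithmetic only: function specs now count against the prompt
# budget up front instead of being covered by a flat reserve.
token_limit = 4000          # assumed context budget passed into chat_with_ai
current_tokens_used = 1200  # system prompt + triggering prompt, already counted
functions_tlength = 300     # count_openai_functions_tokens(openai_functions, model)
summary_reserve = 500       # agent.history.max_summary_tlength (assumed value)

current_tokens_used += functions_tlength + summary_reserve
tokens_remaining = token_limit - current_tokens_used  # 2000 left for history and the reply
```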
@@ -97,7 +93,8 @@ def chat_with_ai(
    current_tokens_used += agent.history.max_summary_tlength  # Reserve space
    current_tokens_used += 500  # Reserve space for the openai functions TODO improve

    # Add Messages until the token limit is reached or there are no more messages to add.
    # Add historical Messages until the token limit is reached
    # or there are no more messages to add.
    for cycle in reversed(list(agent.history.per_cycle())):
        messages_to_add = [msg for msg in cycle if msg is not None]
        tokens_to_add = count_message_tokens(messages_to_add, model)
@@ -162,6 +159,8 @@ def chat_with_ai(
                logger.debug(f"Plugins remaining at stop: {plugin_count - i}")
                break
            message_sequence.add("system", plugin_response)
            current_tokens_used += tokens_to_add

    # Calculate remaining tokens
    tokens_remaining = token_limit - current_tokens_used
    # assert tokens_remaining >= 0, "Tokens remaining is negative.
@@ -193,7 +192,7 @@ def chat_with_ai(
    assistant_reply = create_chat_completion(
        prompt=message_sequence,
        config=agent.config,
        functions=get_openai_command_specs(agent),
        functions=openai_functions,
        max_tokens=tokens_remaining,
    )
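A minimal sketch of the resulting control flow, assuming the surrounding chat_with_ai variables: when the openai_functions setting is off, openai_functions stays None and no function specs are counted or sent.

```python
# Fragment of chat_with_ai, reconstructed from this diff for illustration.
openai_functions = None
if agent.config.openai_functions:
    openai_functions = get_openai_command_specs(agent.command_registry)

assistant_reply = create_chat_completion(
    prompt=message_sequence,
    config=agent.config,
    functions=openai_functions,  # None means the API call carries no functions
    max_tokens=tokens_remaining,
)
```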
@@ -3,7 +3,7 @@ from __future__ import annotations

import functools
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional
from typing import List, Optional
from unittest.mock import patch

import openai
@@ -12,9 +12,6 @@ from colorama import Fore, Style
from openai.error import APIError, RateLimitError, ServiceUnavailableError, Timeout
from openai.openai_object import OpenAIObject

if TYPE_CHECKING:
    from autogpt.agent.agent import Agent

from autogpt.llm.base import (
    ChatModelInfo,
    EmbeddingModelInfo,
@@ -23,6 +20,7 @@ from autogpt.llm.base import (
    TText,
)
from autogpt.logs import logger
from autogpt.models.command_registry import CommandRegistry

OPEN_AI_CHAT_MODELS = {
    info.name: info
@@ -301,13 +299,13 @@ class OpenAIFunctionSpec:
    @dataclass
    class ParameterSpec:
        name: str
        type: str
        type: str  # TODO: add enum support
        description: Optional[str]
        required: bool = False

    @property
    def __dict__(self):
        """Output an OpenAI-consumable function specification"""
    def schema(self) -> dict[str, str | dict | list]:
        """Returns an OpenAI-consumable function specification"""
        return {
            "name": self.name,
            "description": self.description,
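As a rough illustration of what the renamed schema property emits (the command and parameters are invented; the nested "parameters" layout follows the standard OpenAI function-calling format, which the diff truncates here):

```python
# Hypothetical output of OpenAIFunctionSpec.schema for a simple command.
example_schema = {
    "name": "get_current_weather",
    "description": "Get the current weather in a given location",
    "parameters": {
        "type": "object",
        "properties": {
            "location": {
                "type": "string",
                "description": "The city and state, e.g. San Francisco, CA",
            },
            "unit": {"type": "string", "description": "celsius or fahrenheit"},
        },
        "required": ["location"],
    },
}
```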
@@ -326,14 +324,44 @@ class OpenAIFunctionSpec:
            },
        }

    @property
    def prompt_format(self) -> str:
        """Returns the function formatted similarly to the way OpenAI does it internally:
        https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/18

def get_openai_command_specs(agent: Agent) -> list[OpenAIFunctionSpec]:
        Example:
        ```ts
        // Get the current weather in a given location
        type get_current_weather = (_: {
        // The city and state, e.g. San Francisco, CA
        location: string,
        unit?: "celsius" | "fahrenheit",
        }) => any;
        ```
        """

        def param_signature(p_spec: OpenAIFunctionSpec.ParameterSpec) -> str:
            # TODO: enum type support
            return (
                f"// {p_spec.description}\n" if p_spec.description else ""
            ) + f"{p_spec.name}{'' if p_spec.required else '?'}: {p_spec.type},"

        return "\n".join(
            [
                f"// {self.description}",
                f"type {self.name} = (_ :{{",
                *[param_signature(p) for p in self.parameters.values()],
                "}) => any;",
            ]
        )


def get_openai_command_specs(
    command_registry: CommandRegistry,
) -> list[OpenAIFunctionSpec]:
    """Get OpenAI-consumable function specs for the agent's available commands.
    see https://platform.openai.com/docs/guides/gpt/function-calling
    """
    if not agent.config.openai_functions:
        return []

    return [
        OpenAIFunctionSpec(
            name=command.name,
@@ -348,5 +376,48 @@ def get_openai_command_specs(agent: Agent) -> list[OpenAIFunctionSpec]:
                for param in command.parameters
            },
        )
        for command in agent.command_registry.commands.values()
        for command in command_registry.commands.values()
    ]


def count_openai_functions_tokens(
    functions: list[OpenAIFunctionSpec], for_model: str
) -> int:
    """Returns the number of tokens taken up by a set of function definitions

    Reference: https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/18
    """
    from autogpt.llm.utils import count_string_tokens

    return count_string_tokens(
        f"# Tools\n\n## functions\n\n{format_function_specs_as_typescript_ns(functions)}",
        for_model,
    )


def format_function_specs_as_typescript_ns(functions: list[OpenAIFunctionSpec]) -> str:
    """Returns a function signature block in the format used by OpenAI internally:
    https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/18

    For use with `count_string_tokens` to determine token usage of provided functions.

    Example:
    ```ts
    namespace functions {

    // Get the current weather in a given location
    type get_current_weather = (_: {
    // The city and state, e.g. San Francisco, CA
    location: string,
    unit?: "celsius" | "fahrenheit",
    }) => any;

    } // namespace functions
    ```
    """

    return (
        "namespace functions {\n\n"
        + "\n\n".join(f.prompt_format for f in functions)
        + "\n\n} // namespace functions"
    )
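A hedged usage sketch tying the new helpers together; the helper names come from this diff, while OpenAIFunctionSpec's exact constructor arguments and the example command are assumed.

```python
from autogpt.llm.providers.openai import (
    OpenAIFunctionSpec,
    count_openai_functions_tokens,
    format_function_specs_as_typescript_ns,
)

# Assumed constructor shape: a name, a description, and ParameterSpec values
# keyed by parameter name (the diff iterates self.parameters.values()).
weather = OpenAIFunctionSpec(
    name="get_current_weather",
    description="Get the current weather in a given location",
    parameters={
        "location": OpenAIFunctionSpec.ParameterSpec(
            name="location",
            type="string",
            description="The city and state, e.g. San Francisco, CA",
            required=True,
        ),
    },
)

# Rendered in OpenAI's internal TypeScript-namespace style, then counted with
# the model's tokenizer to get the prompt-token overhead of the function specs.
print(format_function_specs_as_typescript_ns([weather]))
print(count_openai_functions_tokens([weather], "gpt-3.5-turbo"))
```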
@@ -7,12 +7,19 @@ from colorama import Fore

from autogpt.config import Config

from ..api_manager import ApiManager
from ..base import ChatModelResponse, ChatSequence, Message
from ..base import (
    ChatModelResponse,
    ChatSequence,
    FunctionCallDict,
    Message,
    ResponseMessageDict,
)
from ..providers import openai as iopenai
from ..providers.openai import (
    OPEN_AI_CHAT_MODELS,
    OpenAIFunctionCall,
    OpenAIFunctionSpec,
    count_openai_functions_tokens,
)
from .token_counter import *
@@ -114,7 +121,13 @@ def create_chat_completion(
    if temperature is None:
        temperature = config.temperature
    if max_tokens is None:
        max_tokens = OPEN_AI_CHAT_MODELS[model].max_tokens - prompt.token_length
        prompt_tlength = prompt.token_length
        max_tokens = OPEN_AI_CHAT_MODELS[model].max_tokens - prompt_tlength
        logger.debug(f"Prompt length: {prompt_tlength} tokens")
        if functions:
            functions_tlength = count_openai_functions_tokens(functions, model)
            max_tokens -= functions_tlength
            logger.debug(f"Functions take up {functions_tlength} tokens in API call")

    logger.debug(
        f"{Fore.GREEN}Creating chat completion with model {model}, temperature {temperature}, max_tokens {max_tokens}{Fore.RESET}"
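A worked example of the budgeting this hunk introduces; all numbers are invented.

```python
# Illustrative numbers only. OPEN_AI_CHAT_MODELS[model].max_tokens is the
# model's context window as recorded by Auto-GPT, not an API parameter.
model_window = 4096      # e.g. context size of the chat model
prompt_tlength = 1800    # tokens in the assembled ChatSequence
functions_tlength = 250  # tokens the serialized function specs will occupy

max_tokens = model_window - prompt_tlength  # 2296 left after the prompt
max_tokens -= functions_tlength             # 2046 once function specs are accounted for
```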
@@ -143,9 +156,8 @@ def create_chat_completion(

    if functions:
        chat_completion_kwargs["functions"] = [
            function.__dict__ for function in functions
            function.schema for function in functions
        ]
        logger.debug(f"Function dicts: {chat_completion_kwargs['functions']}")

    response = iopenai.create_chat_completion(
        messages=prompt.raw(),
@@ -157,19 +169,24 @@ def create_chat_completion(
        logger.error(response.error)
        raise RuntimeError(response.error)

    first_message = response.choices[0].message
    first_message: ResponseMessageDict = response.choices[0].message
    content: str | None = first_message.get("content")
    function_call: OpenAIFunctionCall | None = first_message.get("function_call")
    function_call: FunctionCallDict | None = first_message.get("function_call")

    for plugin in config.plugins:
        if not plugin.can_handle_on_response():
            continue
        # TODO: function call support in plugin.on_response()
        content = plugin.on_response(content)

    return ChatModelResponse(
        model_info=OPEN_AI_CHAT_MODELS[model],
        content=content,
        function_call=function_call,
        function_call=OpenAIFunctionCall(
            name=function_call["name"], arguments=function_call["arguments"]
        )
        if function_call
        else None,
    )
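Putting the pieces together, a hedged sketch of what this hunk does with a function-calling reply; the raw message dict is invented, while the conversion mirrors the lines above.

```python
from autogpt.llm.base import ResponseMessageDict
from autogpt.llm.providers.openai import OpenAIFunctionCall

# Example only: a raw assistant message as the OpenAI API might return it.
first_message: ResponseMessageDict = {
    "role": "assistant",
    "content": None,
    "function_call": {
        "name": "browse_website",
        "arguments": '{"url": "https://example.com"}',
    },
}

function_call = first_message.get("function_call")
typed_call = (
    OpenAIFunctionCall(name=function_call["name"], arguments=function_call["arguments"])
    if function_call
    else None
)
# typed_call is now a structured OpenAIFunctionCall rather than a loose dict,
# which is what ChatModelResponse.function_call expects after this change.
```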