diff --git a/autogpt/llm/base.py b/autogpt/llm/base.py
index 5cf4993f..14a146b3 100644
--- a/autogpt/llm/base.py
+++ b/autogpt/llm/base.py
@@ -20,6 +20,17 @@ class MessageDict(TypedDict):
     content: str
 
 
+class ResponseMessageDict(TypedDict):
+    role: Literal["assistant"]
+    content: Optional[str]
+    function_call: Optional[FunctionCallDict]
+
+
+class FunctionCallDict(TypedDict):
+    name: str
+    arguments: str
+
+
 @dataclass
 class Message:
     """OpenAI Message object containing a role and the message content"""
@@ -167,8 +178,6 @@ class LLMResponse:
     """Standard response struct for a response from an LLM model."""
 
     model_info: ModelInfo
-    prompt_tokens_used: int = 0
-    completion_tokens_used: int = 0
 
 
 @dataclass
@@ -177,14 +186,10 @@ class EmbeddingModelResponse(LLMResponse):
 
     embedding: list[float] = field(default_factory=list)
 
-    def __post_init__(self):
-        if self.completion_tokens_used:
-            raise ValueError("Embeddings should not have completion tokens used.")
-
 
 @dataclass
 class ChatModelResponse(LLMResponse):
-    """Standard response struct for a response from an LLM model."""
+    """Standard response struct for a response from a chat LLM."""
 
-    content: Optional[str] = None
-    function_call: Optional[OpenAIFunctionCall] = None
+    content: Optional[str]
+    function_call: Optional[OpenAIFunctionCall]
diff --git a/autogpt/llm/chat.py b/autogpt/llm/chat.py
index 14e06737..4364cb1d 100644
--- a/autogpt/llm/chat.py
+++ b/autogpt/llm/chat.py
@@ -3,14 +3,16 @@ from __future__ import annotations
 import time
 from typing import TYPE_CHECKING
 
-from autogpt.llm.providers.openai import get_openai_command_specs
-
 if TYPE_CHECKING:
     from autogpt.agent.agent import Agent
     from autogpt.config import Config
 
 from autogpt.llm.api_manager import ApiManager
 from autogpt.llm.base import ChatSequence, Message
+from autogpt.llm.providers.openai import (
+    count_openai_functions_tokens,
+    get_openai_command_specs,
+)
 from autogpt.llm.utils import count_message_tokens, create_chat_completion
 from autogpt.logs import CURRENT_CONTEXT_FILE_NAME, logger
 
@@ -72,23 +74,17 @@ def chat_with_ai(
         ],
     )
 
-    # Add messages from the full message history until we reach the token limit
-    next_message_to_add_index = len(agent.history) - 1
-    insertion_index = len(message_sequence)
     # Count the currently used tokens
     current_tokens_used = message_sequence.token_length
+    insertion_index = len(message_sequence)
 
-    # while current_tokens_used > 2500:
-    #     # remove memories until we are under 2500 tokens
-    #     relevant_memory = relevant_memory[:-1]
-    #     (
-    #         next_message_to_add_index,
-    #         current_tokens_used,
-    #         insertion_index,
-    #         current_context,
-    #     ) = generate_context(
-    #         prompt, relevant_memory, agent.history, model
-    #     )
+    # Account for tokens used by OpenAI functions
+    openai_functions = None
+    if agent.config.openai_functions:
+        openai_functions = get_openai_command_specs(agent.command_registry)
+        functions_tlength = count_openai_functions_tokens(openai_functions, model)
+        current_tokens_used += functions_tlength
+        logger.debug(f"OpenAI Functions take up {functions_tlength} tokens in API call")
 
     # Account for user input (appended later)
     user_input_msg = Message("user", triggering_prompt)
@@ -97,7 +93,8 @@
     current_tokens_used += agent.history.max_summary_tlength  # Reserve space
     current_tokens_used += 500  # Reserve space for the openai functions TODO improve
 
-    # Add Messages until the token limit is reached or there are no more messages to add.
+    # Add historical Messages until the token limit is reached
+    # or there are no more messages to add.
     for cycle in reversed(list(agent.history.per_cycle())):
         messages_to_add = [msg for msg in cycle if msg is not None]
         tokens_to_add = count_message_tokens(messages_to_add, model)
@@ -162,6 +159,8 @@
             logger.debug(f"Plugins remaining at stop: {plugin_count - i}")
             break
         message_sequence.add("system", plugin_response)
+        current_tokens_used += tokens_to_add
+
     # Calculate remaining tokens
     tokens_remaining = token_limit - current_tokens_used
     # assert tokens_remaining >= 0, "Tokens remaining is negative.
@@ -193,7 +192,7 @@
     assistant_reply = create_chat_completion(
         prompt=message_sequence,
         config=agent.config,
-        functions=get_openai_command_specs(agent),
+        functions=openai_functions,
         max_tokens=tokens_remaining,
     )
 
diff --git a/autogpt/llm/providers/openai.py b/autogpt/llm/providers/openai.py
index baf7ab87..933f9435 100644
--- a/autogpt/llm/providers/openai.py
+++ b/autogpt/llm/providers/openai.py
@@ -3,7 +3,7 @@ from __future__ import annotations
 import functools
 import time
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, Optional
+from typing import List, Optional
 from unittest.mock import patch
 
 import openai
@@ -12,9 +12,6 @@ from colorama import Fore, Style
 from openai.error import APIError, RateLimitError, ServiceUnavailableError, Timeout
 from openai.openai_object import OpenAIObject
 
-if TYPE_CHECKING:
-    from autogpt.agent.agent import Agent
-
 from autogpt.llm.base import (
     ChatModelInfo,
     EmbeddingModelInfo,
@@ -23,6 +20,7 @@ from autogpt.llm.base import (
     TText,
 )
 from autogpt.logs import logger
+from autogpt.models.command_registry import CommandRegistry
 
 OPEN_AI_CHAT_MODELS = {
     info.name: info
@@ -301,13 +299,13 @@ class OpenAIFunctionSpec:
     @dataclass
     class ParameterSpec:
         name: str
-        type: str
+        type: str  # TODO: add enum support
        description: Optional[str]
         required: bool = False
 
     @property
-    def __dict__(self):
-        """Output an OpenAI-consumable function specification"""
+    def schema(self) -> dict[str, str | dict | list]:
+        """Returns an OpenAI-consumable function specification"""
         return {
             "name": self.name,
             "description": self.description,
@@ -326,14 +324,44 @@
             },
         }
 
+    @property
+    def prompt_format(self) -> str:
+        """Returns the function formatted similarly to the way OpenAI does it internally:
+        https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/18
 
-def get_openai_command_specs(agent: Agent) -> list[OpenAIFunctionSpec]:
+        Example:
+        ```ts
+        // Get the current weather in a given location
+        type get_current_weather = (_: {
+        // The city and state, e.g. San Francisco, CA
+        location: string,
+        unit?: "celsius" | "fahrenheit",
+        }) => any;
+        ```
+        """
+
+        def param_signature(p_spec: OpenAIFunctionSpec.ParameterSpec) -> str:
+            # TODO: enum type support
+            return (
+                f"// {p_spec.description}\n" if p_spec.description else ""
+            ) + f"{p_spec.name}{'' if p_spec.required else '?'}: {p_spec.type},"
+
+        return "\n".join(
+            [
+                f"// {self.description}",
+                f"type {self.name} = (_ :{{",
+                *[param_signature(p) for p in self.parameters.values()],
+                "}) => any;",
+            ]
+        )
+
+
+def get_openai_command_specs(
+    command_registry: CommandRegistry,
+) -> list[OpenAIFunctionSpec]:
     """Get OpenAI-consumable function specs for the agent's available commands.
     see https://platform.openai.com/docs/guides/gpt/function-calling
     """
-    if not agent.config.openai_functions:
-        return []
-
     return [
         OpenAIFunctionSpec(
             name=command.name,
@@ -348,5 +376,48 @@
                 for param in command.parameters
             },
         )
-        for command in agent.command_registry.commands.values()
+        for command in command_registry.commands.values()
     ]
+
+
+def count_openai_functions_tokens(
+    functions: list[OpenAIFunctionSpec], for_model: str
+) -> int:
+    """Returns the number of tokens taken up by a set of function definitions
+
+    Reference: https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/18
+    """
+    from autogpt.llm.utils import count_string_tokens
+
+    return count_string_tokens(
+        f"# Tools\n\n## functions\n\n{format_function_specs_as_typescript_ns(functions)}",
+        for_model,
+    )
+
+
+def format_function_specs_as_typescript_ns(functions: list[OpenAIFunctionSpec]) -> str:
+    """Returns a function signature block in the format used by OpenAI internally:
+    https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/18
+
+    For use with `count_string_tokens` to determine token usage of provided functions.
+
+    Example:
+    ```ts
+    namespace functions {
+
+    // Get the current weather in a given location
+    type get_current_weather = (_: {
+    // The city and state, e.g. San Francisco, CA
+    location: string,
+    unit?: "celsius" | "fahrenheit",
+    }) => any;
+
+    } // namespace functions
+    ```
+    """
+
+    return (
+        "namespace functions {\n\n"
+        + "\n\n".join(f.prompt_format for f in functions)
+        + "\n\n} // namespace functions"
+    )
diff --git a/autogpt/llm/utils/__init__.py b/autogpt/llm/utils/__init__.py
index 3c2835b7..74e88dc6 100644
--- a/autogpt/llm/utils/__init__.py
+++ b/autogpt/llm/utils/__init__.py
@@ -7,12 +7,19 @@ from colorama import Fore
 from autogpt.config import Config
 
 from ..api_manager import ApiManager
-from ..base import ChatModelResponse, ChatSequence, Message
+from ..base import (
+    ChatModelResponse,
+    ChatSequence,
+    FunctionCallDict,
+    Message,
+    ResponseMessageDict,
+)
 from ..providers import openai as iopenai
 from ..providers.openai import (
     OPEN_AI_CHAT_MODELS,
     OpenAIFunctionCall,
     OpenAIFunctionSpec,
+    count_openai_functions_tokens,
 )
 from .token_counter import *
 
@@ -114,7 +121,13 @@
     if temperature is None:
         temperature = config.temperature
     if max_tokens is None:
-        max_tokens = OPEN_AI_CHAT_MODELS[model].max_tokens - prompt.token_length
+        prompt_tlength = prompt.token_length
+        max_tokens = OPEN_AI_CHAT_MODELS[model].max_tokens - prompt_tlength
+        logger.debug(f"Prompt length: {prompt_tlength} tokens")
+        if functions:
+            functions_tlength = count_openai_functions_tokens(functions, model)
+            max_tokens -= functions_tlength
+            logger.debug(f"Functions take up {functions_tlength} tokens in API call")
 
     logger.debug(
         f"{Fore.GREEN}Creating chat completion with model {model}, temperature {temperature}, max_tokens {max_tokens}{Fore.RESET}"
@@ -143,9 +156,8 @@
 
     if functions:
         chat_completion_kwargs["functions"] = [
-            function.__dict__ for function in functions
+            function.schema for function in functions
         ]
-        logger.debug(f"Function dicts: {chat_completion_kwargs['functions']}")
 
     response = iopenai.create_chat_completion(
         messages=prompt.raw(),
@@ -157,19 +169,24 @@
         logger.error(response.error)
         raise RuntimeError(response.error)
 
-    first_message = response.choices[0].message
+    first_message: ResponseMessageDict = response.choices[0].message
     content: str | None = first_message.get("content")
-    function_call: OpenAIFunctionCall | None = first_message.get("function_call")
+    function_call: FunctionCallDict | None = first_message.get("function_call")
 
     for plugin in config.plugins:
         if not plugin.can_handle_on_response():
             continue
+        # TODO: function call support in plugin.on_response()
         content = plugin.on_response(content)
 
     return ChatModelResponse(
         model_info=OPEN_AI_CHAT_MODELS[model],
         content=content,
-        function_call=function_call,
+        function_call=OpenAIFunctionCall(
+            name=function_call["name"], arguments=function_call["arguments"]
+        )
+        if function_call
+        else None,
     )
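For reviewers, a rough usage sketch (not part of the patch) of how the new pieces fit together: `prompt_format` renders one command spec in the TypeScript-style layout OpenAI is believed to use internally, `format_function_specs_as_typescript_ns` wraps the specs in a `namespace functions` block, and `count_openai_functions_tokens` measures that block so `chat_with_ai` and `create_chat_completion` can subtract it from the token budget. The example spec and model name below are invented for illustration only.

```python
# Illustrative sketch, assuming the classes/functions added in this diff.
from autogpt.llm.providers.openai import (
    OpenAIFunctionSpec,
    count_openai_functions_tokens,
    format_function_specs_as_typescript_ns,
)

# Hypothetical function spec, mirroring the docstring example above.
get_current_weather = OpenAIFunctionSpec(
    name="get_current_weather",
    description="Get the current weather in a given location",
    parameters={
        "location": OpenAIFunctionSpec.ParameterSpec(
            name="location",
            type="string",
            description="The city and state, e.g. San Francisco, CA",
            required=True,
        ),
        # Enum types are not supported yet (see the TODO above), so "unit" is a plain string here.
        "unit": OpenAIFunctionSpec.ParameterSpec(
            name="unit", type="string", description=None, required=False
        ),
    },
)

# TypeScript-style block that the token count is based on.
print(format_function_specs_as_typescript_ns([get_current_weather]))

# Token overhead that is now subtracted from the budget before the API call
# ("gpt-3.5-turbo" is just an example model name).
print(count_openai_functions_tokens([get_current_weather], "gpt-3.5-turbo"))
```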