"""Functions for counting the number of tokens in a message or string."""
from __future__ import annotations

import logging
from typing import List, overload

import tiktoken

from autogpt.llm.base import Message

logger = logging.getLogger(__name__)


def _get_encoding(model_name: str) -> tiktoken.Encoding:
    """Return the tiktoken encoding for `model_name`, falling back to cl100k_base.

    `tiktoken.encoding_for_model` raises `KeyError` for model names it does not
    recognise; in that case a warning is logged and the "cl100k_base" encoding
    is returned instead of propagating the error.

    Args:
        model_name (str): The model name to look up (e.g. "gpt-3.5-turbo").

    Returns:
        tiktoken.Encoding: The encoding to tokenize text with.
    """
    try:
        return tiktoken.encoding_for_model(model_name)
    except KeyError:
        # `Logger.warn` is a deprecated alias; use `Logger.warning`.
        logger.warning("Warning: model not found. Using cl100k_base encoding.")
        return tiktoken.get_encoding("cl100k_base")


@overload
def count_message_tokens(messages: Message, model: str = "gpt-3.5-turbo") -> int:
    ...


@overload
def count_message_tokens(
    messages: List[Message], model: str = "gpt-3.5-turbo"
) -> int:
    ...


def count_message_tokens(
    messages: Message | List[Message], model: str = "gpt-3.5-turbo"
) -> int:
    """
    Returns the number of tokens used by a single message or a list of messages.

    Each message costs a fixed per-message overhead plus the encoded length of
    every field returned by `Message.raw()`. A "name" field adds a model-specific
    adjustment. Finally, 3 tokens are added for priming the assistant's reply.

    Args:
        messages (Message | List[Message]): A single message, or a list of
            messages, to count.
        model (str): The name of the model to use for tokenization. Must start
            with "gpt-3.5-turbo" or "gpt-4". Defaults to "gpt-3.5-turbo".

    Returns:
        int: The number of tokens used by the message(s).

    Raises:
        NotImplementedError: If `model` is not a gpt-3.5-turbo or gpt-4 model.
    """
    if isinstance(messages, Message):
        messages = [messages]

    if model.startswith("gpt-3.5-turbo"):
        tokens_per_message = (
            4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
        )
        tokens_per_name = -1  # if there's a name, the role is omitted
        encoding_model = "gpt-3.5-turbo"
    elif model.startswith("gpt-4"):
        tokens_per_message = 3
        tokens_per_name = 1
        encoding_model = "gpt-4"
    else:
        raise NotImplementedError(
            f"count_message_tokens() is not implemented for model {model}.\n"
            " See https://github.com/openai/openai-python/blob/main/chatml.md for"
            " information on how messages are converted to tokens."
        )

    encoding = _get_encoding(encoding_model)

    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        # NOTE(review): assumes raw() returns a dict of str -> str
        # (role/content and optionally name) — confirm against Message.
        for key, value in message.raw().items():
            num_tokens += len(encoding.encode(value))
            if key == "name":
                num_tokens += tokens_per_name
    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
    return num_tokens


def count_string_tokens(string: str, model_name: str) -> int:
    """
    Returns the number of tokens in a text string.

    Args:
        string (str): The text string.
        model_name (str): The name of the model whose encoding to use
            (e.g., "gpt-3.5-turbo"). Unrecognised model names fall back to the
            "cl100k_base" encoding and log a warning.

    Returns:
        int: The number of tokens in the text string.
    """
    encoding = _get_encoding(model_name)
    return len(encoding.encode(string))