Files
Auto-GPT/tests/test_token_counter.py
Andres Caicedo f8dfedf1c6 Add function and class descriptions to tests (#2715)
Co-authored-by: Reinier van der Leer <github@pwuts.nl>
2023-04-24 14:55:49 +02:00

73 lines
2.5 KiB
Python

import unittest
import tests.context
from autogpt.token_counter import count_message_tokens, count_string_tokens
class TestTokenCounter(unittest.TestCase):
    """Unit tests for ``count_message_tokens`` and ``count_string_tokens``."""

    @staticmethod
    def _greeting_messages():
        """Return a fresh two-message user/assistant exchange.

        A new list is built on every call so no test can observe another
        test's mutations of the fixture.
        """
        return [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
        ]

    def test_count_message_tokens(self):
        """The greeting exchange counts as 17 tokens when no model is given."""
        self.assertEqual(count_message_tokens(self._greeting_messages()), 17)

    def test_count_message_tokens_with_name(self):
        """Adding a ``name`` key to a message leaves the expected count at 17."""
        messages = [
            {"role": "user", "content": "Hello", "name": "John"},
            {"role": "assistant", "content": "Hi there!"},
        ]
        self.assertEqual(count_message_tokens(messages), 17)

    def test_count_message_tokens_empty_input(self):
        """An empty message list is expected to count as 3 tokens."""
        self.assertEqual(count_message_tokens([]), 3)

    def test_count_message_tokens_invalid_model(self):
        """An unrecognised model name must raise ``NotImplementedError``.

        NOTE: this class previously defined a method with this exact name
        twice. The earlier one expected ``KeyError``, but the later
        definition replaced it in the class namespace, so it never ran.
        Only the expectation that was actually exercised is kept here.
        """
        with self.assertRaises(NotImplementedError):
            count_message_tokens(
                self._greeting_messages(), model="invalid_model"
            )

    def test_count_message_tokens_gpt_4(self):
        """The greeting exchange counts as 15 tokens for ``gpt-4-0314``."""
        self.assertEqual(
            count_message_tokens(self._greeting_messages(), model="gpt-4-0314"),
            15,
        )

    def test_count_string_tokens(self):
        """``"Hello, world!"`` counts as 4 tokens for ``gpt-3.5-turbo-0301``."""
        text = "Hello, world!"
        self.assertEqual(
            count_string_tokens(text, model_name="gpt-3.5-turbo-0301"), 4
        )

    def test_count_string_tokens_empty_input(self):
        """An empty string counts as 0 tokens."""
        self.assertEqual(
            count_string_tokens("", model_name="gpt-3.5-turbo-0301"), 0
        )

    def test_count_string_tokens_gpt_4(self):
        """``"Hello, world!"`` counts as 4 tokens for ``gpt-4-0314``."""
        text = "Hello, world!"
        self.assertEqual(count_string_tokens(text, model_name="gpt-4-0314"), 4)
# Run the test suite only when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()