diff --git a/autogpt/token_counter.py b/autogpt/token_counter.py index c1239722..a85a54be 100644 --- a/autogpt/token_counter.py +++ b/autogpt/token_counter.py @@ -27,7 +27,7 @@ def count_message_tokens( logger.warn("Warning: model not found. Using cl100k_base encoding.") encoding = tiktoken.get_encoding("cl100k_base") if model == "gpt-3.5-turbo": - # !Node: gpt-3.5-turbo may change over time. + # !Note: gpt-3.5-turbo may change over time. # Returning num tokens assuming gpt-3.5-turbo-0301.") return count_message_tokens(messages, model="gpt-3.5-turbo-0301") elif model == "gpt-4": diff --git a/tests/test_token_counter.py b/tests/test_token_counter.py new file mode 100644 index 00000000..d13f2ae0 --- /dev/null +++ b/tests/test_token_counter.py @@ -0,0 +1,61 @@ +import unittest +import tests.context + +from scripts.token_counter import count_message_tokens, count_string_tokens + +class TestTokenCounter(unittest.TestCase): + + def test_count_message_tokens(self): + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"} + ] + self.assertEqual(count_message_tokens(messages), 17) + + def test_count_message_tokens_with_name(self): + messages = [ + {"role": "user", "content": "Hello", "name": "John"}, + {"role": "assistant", "content": "Hi there!"} + ] + self.assertEqual(count_message_tokens(messages), 17) + + def test_count_message_tokens_empty_input(self): + self.assertEqual(count_message_tokens([]), 3) + + def test_count_message_tokens_invalid_model(self): + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"} + ] + with self.assertRaises(KeyError): + count_message_tokens(messages, model="invalid_model") + + def test_count_message_tokens_gpt_4(self): + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"} + ] + self.assertEqual(count_message_tokens(messages, model="gpt-4-0314"), 15) + + def test_count_string_tokens(self): + string = "Hello, world!" + self.assertEqual(count_string_tokens(string, model_name="gpt-3.5-turbo-0301"), 4) + + def test_count_string_tokens_empty_input(self): + self.assertEqual(count_string_tokens("", model_name="gpt-3.5-turbo-0301"), 0) + + def test_count_message_tokens_invalid_model(self): + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"} + ] + with self.assertRaises(NotImplementedError): + count_message_tokens(messages, model="invalid_model") + + def test_count_string_tokens_gpt_4(self): + string = "Hello, world!" + self.assertEqual(count_string_tokens(string, model_name="gpt-4-0314"), 4) + + +if __name__ == '__main__': + unittest.main()