diff --git a/tests/test_token_counter.py b/tests/test_token_counter.py index f7c84672..ac25796d 100644 --- a/tests/test_token_counter.py +++ b/tests/test_token_counter.py @@ -1,72 +1,73 @@ -import unittest +import pytest import tests.context from autogpt.token_counter import count_message_tokens, count_string_tokens -class TestTokenCounter(unittest.TestCase): - def test_count_message_tokens(self): - messages = [ - {"role": "user", "content": "Hello"}, - {"role": "assistant", "content": "Hi there!"}, - ] - self.assertEqual(count_message_tokens(messages), 17) - - def test_count_message_tokens_with_name(self): - messages = [ - {"role": "user", "content": "Hello", "name": "John"}, - {"role": "assistant", "content": "Hi there!"}, - ] - self.assertEqual(count_message_tokens(messages), 17) - - def test_count_message_tokens_empty_input(self): - # Empty input should return 3 tokens - self.assertEqual(count_message_tokens([]), 3) - - def test_count_message_tokens_invalid_model(self): - # Invalid model should raise a KeyError - messages = [ - {"role": "user", "content": "Hello"}, - {"role": "assistant", "content": "Hi there!"}, - ] - with self.assertRaises(KeyError): - count_message_tokens(messages, model="invalid_model") - - def test_count_message_tokens_gpt_4(self): - messages = [ - {"role": "user", "content": "Hello"}, - {"role": "assistant", "content": "Hi there!"}, - ] - self.assertEqual(count_message_tokens(messages, model="gpt-4-0314"), 15) - - def test_count_string_tokens(self): - """Test that the string tokens are counted correctly.""" - - string = "Hello, world!" - self.assertEqual( - count_string_tokens(string, model_name="gpt-3.5-turbo-0301"), 4 - ) - - def test_count_string_tokens_empty_input(self): - """Test that the string tokens are counted correctly.""" - - self.assertEqual(count_string_tokens("", model_name="gpt-3.5-turbo-0301"), 0) - - def test_count_message_tokens_invalid_model(self): - # Invalid model should raise a NotImplementedError - messages = [ - {"role": "user", "content": "Hello"}, - {"role": "assistant", "content": "Hi there!"}, - ] - with self.assertRaises(NotImplementedError): - count_message_tokens(messages, model="invalid_model") - - def test_count_string_tokens_gpt_4(self): - """Test that the string tokens are counted correctly.""" - - string = "Hello, world!" - self.assertEqual(count_string_tokens(string, model_name="gpt-4-0314"), 4) +def test_count_message_tokens(): + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + assert count_message_tokens(messages) == 17 -if __name__ == "__main__": - unittest.main() +def test_count_message_tokens_with_name(): + messages = [ + {"role": "user", "content": "Hello", "name": "John"}, + {"role": "assistant", "content": "Hi there!"}, + ] + assert count_message_tokens(messages) == 17 + + +def test_count_message_tokens_empty_input(): + """Empty input should return 3 tokens""" + assert count_message_tokens([]) == 3 + + +def test_count_message_tokens_invalid_model(): + """Invalid model should raise a KeyError""" + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + with pytest.raises(KeyError): + count_message_tokens(messages, model="invalid_model") + + +def test_count_message_tokens_gpt_4(): + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + assert count_message_tokens(messages, model="gpt-4-0314") == 15 + + +def test_count_string_tokens(): + """Test that the string tokens are counted correctly.""" + + string = "Hello, world!" + assert count_string_tokens(string, model_name="gpt-3.5-turbo-0301") == 4 + + +def test_count_string_tokens_empty_input(): + """Test that the string tokens are counted correctly.""" + + assert count_string_tokens("", model_name="gpt-3.5-turbo-0301") == 0 + + +def test_count_message_tokens_invalid_model(): + """Invalid model should raise a NotImplementedError""" + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + with pytest.raises(NotImplementedError): + count_message_tokens(messages, model="invalid_model") + + +def test_count_string_tokens_gpt_4(): + """Test that the string tokens are counted correctly.""" + + string = "Hello, world!" + assert count_string_tokens(string, model_name="gpt-4-0314") == 4