Remove unittest in favor of pytest in the test_token_counter module (#3453)

* init remove unittest for pytest

* docstrings

* black

---------

Co-authored-by: James Collins <collijk@uw.edu>
This commit is contained in:
Media
2023-04-28 18:48:30 +02:00
committed by GitHub
parent cf5fdabdfc
commit aebe891489

View File

@@ -1,72 +1,73 @@
import unittest
import pytest
import tests.context
from autogpt.token_counter import count_message_tokens, count_string_tokens
class TestTokenCounter(unittest.TestCase):
def test_count_message_tokens(self):
messages = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
]
self.assertEqual(count_message_tokens(messages), 17)
def test_count_message_tokens_with_name(self):
messages = [
{"role": "user", "content": "Hello", "name": "John"},
{"role": "assistant", "content": "Hi there!"},
]
self.assertEqual(count_message_tokens(messages), 17)
def test_count_message_tokens_empty_input(self):
# Empty input should return 3 tokens
self.assertEqual(count_message_tokens([]), 3)
def test_count_message_tokens_invalid_model(self):
# Invalid model should raise a KeyError
messages = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
]
with self.assertRaises(KeyError):
count_message_tokens(messages, model="invalid_model")
def test_count_message_tokens_gpt_4(self):
messages = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
]
self.assertEqual(count_message_tokens(messages, model="gpt-4-0314"), 15)
def test_count_string_tokens(self):
"""Test that the string tokens are counted correctly."""
string = "Hello, world!"
self.assertEqual(
count_string_tokens(string, model_name="gpt-3.5-turbo-0301"), 4
)
def test_count_string_tokens_empty_input(self):
"""Test that the string tokens are counted correctly."""
self.assertEqual(count_string_tokens("", model_name="gpt-3.5-turbo-0301"), 0)
def test_count_message_tokens_invalid_model(self):
# Invalid model should raise a NotImplementedError
messages = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
]
with self.assertRaises(NotImplementedError):
count_message_tokens(messages, model="invalid_model")
def test_count_string_tokens_gpt_4(self):
"""Test that the string tokens are counted correctly."""
string = "Hello, world!"
self.assertEqual(count_string_tokens(string, model_name="gpt-4-0314"), 4)
def test_count_message_tokens():
messages = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
]
assert count_message_tokens(messages) == 17
if __name__ == "__main__":
unittest.main()
def test_count_message_tokens_with_name():
messages = [
{"role": "user", "content": "Hello", "name": "John"},
{"role": "assistant", "content": "Hi there!"},
]
assert count_message_tokens(messages) == 17
def test_count_message_tokens_empty_input():
"""Empty input should return 3 tokens"""
assert count_message_tokens([]) == 3
def test_count_message_tokens_invalid_model():
"""Invalid model should raise a KeyError"""
messages = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
]
with pytest.raises(KeyError):
count_message_tokens(messages, model="invalid_model")
def test_count_message_tokens_gpt_4():
messages = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
]
assert count_message_tokens(messages, model="gpt-4-0314") == 15
def test_count_string_tokens():
"""Test that the string tokens are counted correctly."""
string = "Hello, world!"
assert count_string_tokens(string, model_name="gpt-3.5-turbo-0301") == 4
def test_count_string_tokens_empty_input():
"""Test that the string tokens are counted correctly."""
assert count_string_tokens("", model_name="gpt-3.5-turbo-0301") == 0
def test_count_message_tokens_invalid_model():
"""Invalid model should raise a NotImplementedError"""
messages = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
]
with pytest.raises(NotImplementedError):
count_message_tokens(messages, model="invalid_model")
def test_count_string_tokens_gpt_4():
"""Test that the string tokens are counted correctly."""
string = "Hello, world!"
assert count_string_tokens(string, model_name="gpt-4-0314") == 4