mirror of
https://github.com/aljazceru/Auto-GPT.git
synced 2026-02-11 09:14:19 +01:00
test: Write unit tests for token_counter
This commit is contained in:
@@ -27,7 +27,7 @@ def count_message_tokens(
|
||||
logger.warn("Warning: model not found. Using cl100k_base encoding.")
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
if model == "gpt-3.5-turbo":
|
||||
# !Node: gpt-3.5-turbo may change over time.
|
||||
# !Note: gpt-3.5-turbo may change over time.
|
||||
# Returning num tokens assuming gpt-3.5-turbo-0301.")
|
||||
return count_message_tokens(messages, model="gpt-3.5-turbo-0301")
|
||||
elif model == "gpt-4":
|
||||
|
||||
61
tests/test_token_counter.py
Normal file
61
tests/test_token_counter.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import unittest
|
||||
import tests.context
|
||||
|
||||
from scripts.token_counter import count_message_tokens, count_string_tokens
|
||||
|
||||
class TestTokenCounter(unittest.TestCase):
|
||||
|
||||
def test_count_message_tokens(self):
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"}
|
||||
]
|
||||
self.assertEqual(count_message_tokens(messages), 17)
|
||||
|
||||
def test_count_message_tokens_with_name(self):
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello", "name": "John"},
|
||||
{"role": "assistant", "content": "Hi there!"}
|
||||
]
|
||||
self.assertEqual(count_message_tokens(messages), 17)
|
||||
|
||||
def test_count_message_tokens_empty_input(self):
|
||||
self.assertEqual(count_message_tokens([]), 3)
|
||||
|
||||
def test_count_message_tokens_invalid_model(self):
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"}
|
||||
]
|
||||
with self.assertRaises(KeyError):
|
||||
count_message_tokens(messages, model="invalid_model")
|
||||
|
||||
def test_count_message_tokens_gpt_4(self):
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"}
|
||||
]
|
||||
self.assertEqual(count_message_tokens(messages, model="gpt-4-0314"), 15)
|
||||
|
||||
def test_count_string_tokens(self):
|
||||
string = "Hello, world!"
|
||||
self.assertEqual(count_string_tokens(string, model_name="gpt-3.5-turbo-0301"), 4)
|
||||
|
||||
def test_count_string_tokens_empty_input(self):
|
||||
self.assertEqual(count_string_tokens("", model_name="gpt-3.5-turbo-0301"), 0)
|
||||
|
||||
def test_count_message_tokens_invalid_model(self):
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"}
|
||||
]
|
||||
with self.assertRaises(NotImplementedError):
|
||||
count_message_tokens(messages, model="invalid_model")
|
||||
|
||||
def test_count_string_tokens_gpt_4(self):
|
||||
string = "Hello, world!"
|
||||
self.assertEqual(count_string_tokens(string, model_name="gpt-4-0314"), 4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user