Add function and class descriptions to tests (#2715)

Co-authored-by: Reinier van der Leer <github@pwuts.nl>
This commit is contained in:
Andres Caicedo
2023-04-24 14:55:49 +02:00
committed by GitHub
parent 40a75c804c
commit f8dfedf1c6
19 changed files with 147 additions and 67 deletions

View File

@@ -10,7 +10,10 @@ from browse import extract_hyperlinks
class TestBrowseLinks(unittest.TestCase): class TestBrowseLinks(unittest.TestCase):
"""Unit tests for the browse module functions that extract hyperlinks."""
def test_extract_hyperlinks(self): def test_extract_hyperlinks(self):
"""Test the extract_hyperlinks function with a simple HTML body."""
body = """ body = """
<body> <body>
<a href="https://google.com">Google</a> <a href="https://google.com">Google</a>

View File

@@ -1,6 +1,7 @@
import os import os
import sys import sys
# Add the scripts directory to the path so that we can import the browse module.
sys.path.insert( sys.path.insert(
0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../scripts")) 0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../scripts"))
) )

View File

@@ -9,10 +9,11 @@ from autogpt.memory.local import LocalCache
class TestLocalCache(unittest.TestCase): class TestLocalCache(unittest.TestCase):
def random_string(self, length): def generate_random_string(self, length):
return "".join(random.choice(string.ascii_letters) for _ in range(length)) return "".join(random.choice(string.ascii_letters) for _ in range(length))
def setUp(self): def setUp(self):
"""Set up the test environment for the LocalCache tests."""
cfg = cfg = Config() cfg = cfg = Config()
self.cache = LocalCache(cfg) self.cache = LocalCache(cfg)
self.cache.clear() self.cache.clear()
@@ -24,15 +25,15 @@ class TestLocalCache(unittest.TestCase):
"The cake is a lie, but the pie is always true", "The cake is a lie, but the pie is always true",
"ChatGPT is an advanced AI model for conversation", "ChatGPT is an advanced AI model for conversation",
] ]
for text in self.example_texts: for text in self.example_texts:
self.cache.add(text) self.cache.add(text)
# Add some random strings to test noise # Add some random strings to test noise
for _ in range(5): for _ in range(5):
self.cache.add(self.random_string(10)) self.cache.add(self.generate_random_string(10))
def test_get_relevant(self): def test_get_relevant(self):
"""Test getting relevant texts from the cache."""
query = "I'm interested in artificial intelligence and NLP" query = "I'm interested in artificial intelligence and NLP"
k = 3 k = 3
relevant_texts = self.cache.get_relevant(query, k) relevant_texts = self.cache.get_relevant(query, k)

View File

@@ -10,14 +10,12 @@ from autogpt.memory.milvus import MilvusMemory
try: try:
class TestMilvusMemory(unittest.TestCase): class TestMilvusMemory(unittest.TestCase):
"""Tests for the MilvusMemory class.""" """Unit tests for the MilvusMemory class."""
def random_string(self, length: int) -> str: def generate_random_string(self, length: int) -> str:
"""Generate a random string of the given length."""
return "".join(random.choice(string.ascii_letters) for _ in range(length)) return "".join(random.choice(string.ascii_letters) for _ in range(length))
def setUp(self) -> None: def setUp(self) -> None:
"""Set up the test environment."""
cfg = Config() cfg = Config()
cfg.milvus_addr = "localhost:19530" cfg.milvus_addr = "localhost:19530"
self.memory = MilvusMemory(cfg) self.memory = MilvusMemory(cfg)
@@ -36,7 +34,7 @@ try:
# Add some random strings to test noise # Add some random strings to test noise
for _ in range(5): for _ in range(5):
self.memory.add(self.random_string(10)) self.memory.add(self.generate_random_string(10))
def test_get_relevant(self) -> None: def test_get_relevant(self) -> None:
"""Test getting relevant texts from the cache.""" """Test getting relevant texts from the cache."""

View File

@@ -16,6 +16,7 @@ class TestWeaviateMemory(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
"""Set up the test environment for the WeaviateMemory tests."""
# only create the connection to weaviate once # only create the connection to weaviate once
cls.cfg = Config() cls.cfg = Config()
@@ -47,6 +48,7 @@ class TestWeaviateMemory(unittest.TestCase):
""" """
def setUp(self): def setUp(self):
"""Set up the test environment for the WeaviateMemory tests."""
try: try:
self.client.schema.delete_class(self.index) self.client.schema.delete_class(self.index)
except: except:
@@ -55,6 +57,7 @@ class TestWeaviateMemory(unittest.TestCase):
self.memory = WeaviateMemory(self.cfg) self.memory = WeaviateMemory(self.cfg)
def test_add(self): def test_add(self):
"""Test adding a text to the cache"""
doc = "You are a Titan name Thanos and you are looking for the Infinity Stones" doc = "You are a Titan name Thanos and you are looking for the Infinity Stones"
self.memory.add(doc) self.memory.add(doc)
result = self.client.query.get(self.index, ["raw_text"]).do() result = self.client.query.get(self.index, ["raw_text"]).do()
@@ -64,8 +67,9 @@ class TestWeaviateMemory(unittest.TestCase):
self.assertEqual(actual[0]["raw_text"], doc) self.assertEqual(actual[0]["raw_text"], doc)
def test_get(self): def test_get(self):
"""Test getting a text from the cache"""
doc = "You are an Avenger and swore to defend the Galaxy from a menace called Thanos" doc = "You are an Avenger and swore to defend the Galaxy from a menace called Thanos"
# add the document to the cache
with self.client.batch as batch: with self.client.batch as batch:
batch.add_data_object( batch.add_data_object(
uuid=get_valid_uuid(uuid4()), uuid=get_valid_uuid(uuid4()),
@@ -82,6 +86,7 @@ class TestWeaviateMemory(unittest.TestCase):
self.assertEqual(actual[0], doc) self.assertEqual(actual[0], doc)
def test_get_stats(self): def test_get_stats(self):
"""Test getting the stats of the cache"""
docs = [ docs = [
"You are now about to count the number of docs in this index", "You are now about to count the number of docs in this index",
"And then you about to find out if you can count correctly", "And then you about to find out if you can count correctly",
@@ -96,6 +101,7 @@ class TestWeaviateMemory(unittest.TestCase):
self.assertEqual(stats["count"], 2) self.assertEqual(stats["count"], 2)
def test_clear(self): def test_clear(self):
"""Test clearing the cache"""
docs = [ docs = [
"Shame this is the last test for this class", "Shame this is the last test for this class",
"Testing is fun when someone else is doing it", "Testing is fun when someone else is doing it",

View File

@@ -8,7 +8,8 @@ try:
from autogpt.memory.milvus import MilvusMemory from autogpt.memory.milvus import MilvusMemory
def mock_config() -> dict: def mock_config() -> dict:
"""Mock the Config class""" """Mock the config object for testing purposes."""
# Return a mock config object with the required attributes
return type( return type(
"MockConfig", "MockConfig",
(object,), (object,),

View File

@@ -3,4 +3,5 @@ from autogpt.commands.command import command
@command("function_based", "Function-based test command") @command("function_based", "Function-based test command")
def function_based(arg1: int, arg2: str) -> str: def function_based(arg1: int, arg2: str) -> str:
"""A function-based test command that returns a string with the two arguments separated by a dash."""
return f"{arg1} - {arg2}" return f"{arg1} - {arg2}"

View File

@@ -9,41 +9,55 @@ from autogpt.commands.command import Command, CommandRegistry
class TestCommand: class TestCommand:
"""Test cases for the Command class."""
@staticmethod @staticmethod
def example_function(arg1: int, arg2: str) -> str: def example_command_method(arg1: int, arg2: str) -> str:
"""Example function for testing the Command class."""
# This function is static because it is not used by any other test cases.
return f"{arg1} - {arg2}" return f"{arg1} - {arg2}"
def test_command_creation(self): def test_command_creation(self):
"""Test that a Command object can be created with the correct attributes."""
cmd = Command( cmd = Command(
name="example", description="Example command", method=self.example_function name="example",
description="Example command",
method=self.example_command_method,
) )
assert cmd.name == "example" assert cmd.name == "example"
assert cmd.description == "Example command" assert cmd.description == "Example command"
assert cmd.method == self.example_function assert cmd.method == self.example_command_method
assert cmd.signature == "(arg1: int, arg2: str) -> str" assert cmd.signature == "(arg1: int, arg2: str) -> str"
def test_command_call(self): def test_command_call(self):
"""Test that Command(*args) calls and returns the result of method(*args)."""
# Create a Command object with the example_command_method.
cmd = Command( cmd = Command(
name="example", description="Example command", method=self.example_function name="example",
description="Example command",
method=self.example_command_method,
) )
result = cmd(arg1=1, arg2="test") result = cmd(arg1=1, arg2="test")
assert result == "1 - test" assert result == "1 - test"
def test_command_call_with_invalid_arguments(self): def test_command_call_with_invalid_arguments(self):
"""Test that calling a Command object with invalid arguments raises a TypeError."""
cmd = Command( cmd = Command(
name="example", description="Example command", method=self.example_function name="example",
description="Example command",
method=self.example_command_method,
) )
with pytest.raises(TypeError): with pytest.raises(TypeError):
cmd(arg1="invalid", does_not_exist="test") cmd(arg1="invalid", does_not_exist="test")
def test_command_default_signature(self): def test_command_default_signature(self):
"""Test that the default signature is generated correctly."""
cmd = Command( cmd = Command(
name="example", description="Example command", method=self.example_function name="example",
description="Example command",
method=self.example_command_method,
) )
assert cmd.signature == "(arg1: int, arg2: str) -> str" assert cmd.signature == "(arg1: int, arg2: str) -> str"
def test_command_custom_signature(self): def test_command_custom_signature(self):
@@ -51,7 +65,7 @@ class TestCommand:
cmd = Command( cmd = Command(
name="example", name="example",
description="Example command", description="Example command",
method=self.example_function, method=self.example_command_method,
signature=custom_signature, signature=custom_signature,
) )
@@ -60,14 +74,16 @@ class TestCommand:
class TestCommandRegistry: class TestCommandRegistry:
@staticmethod @staticmethod
def example_function(arg1: int, arg2: str) -> str: def example_command_method(arg1: int, arg2: str) -> str:
return f"{arg1} - {arg2}" return f"{arg1} - {arg2}"
def test_register_command(self): def test_register_command(self):
"""Test that a command can be registered to the registry.""" """Test that a command can be registered to the registry."""
registry = CommandRegistry() registry = CommandRegistry()
cmd = Command( cmd = Command(
name="example", description="Example command", method=self.example_function name="example",
description="Example command",
method=self.example_command_method,
) )
registry.register(cmd) registry.register(cmd)
@@ -79,7 +95,9 @@ class TestCommandRegistry:
"""Test that a command can be unregistered from the registry.""" """Test that a command can be unregistered from the registry."""
registry = CommandRegistry() registry = CommandRegistry()
cmd = Command( cmd = Command(
name="example", description="Example command", method=self.example_function name="example",
description="Example command",
method=self.example_command_method,
) )
registry.register(cmd) registry.register(cmd)
@@ -91,7 +109,9 @@ class TestCommandRegistry:
"""Test that a command can be retrieved from the registry.""" """Test that a command can be retrieved from the registry."""
registry = CommandRegistry() registry = CommandRegistry()
cmd = Command( cmd = Command(
name="example", description="Example command", method=self.example_function name="example",
description="Example command",
method=self.example_command_method,
) )
registry.register(cmd) registry.register(cmd)
@@ -110,7 +130,9 @@ class TestCommandRegistry:
"""Test that a command can be called through the registry.""" """Test that a command can be called through the registry."""
registry = CommandRegistry() registry = CommandRegistry()
cmd = Command( cmd = Command(
name="example", description="Example command", method=self.example_function name="example",
description="Example command",
method=self.example_command_method,
) )
registry.register(cmd) registry.register(cmd)
@@ -129,7 +151,9 @@ class TestCommandRegistry:
"""Test that the command prompt is correctly formatted.""" """Test that the command prompt is correctly formatted."""
registry = CommandRegistry() registry = CommandRegistry()
cmd = Command( cmd = Command(
name="example", description="Example command", method=self.example_function name="example",
description="Example command",
method=self.example_command_method,
) )
registry.register(cmd) registry.register(cmd)
@@ -152,7 +176,11 @@ class TestCommandRegistry:
) )
def test_import_temp_command_file_module(self, tmp_path): def test_import_temp_command_file_module(self, tmp_path):
"""Test that the registry can import a command plugins module from a temp file.""" """
Test that the registry can import a command plugins module from a temp file.
Args:
tmp_path (pathlib.Path): Path to a temporary directory.
"""
registry = CommandRegistry() registry = CommandRegistry()
# Create a temp command file # Create a temp command file

View File

@@ -13,6 +13,7 @@ from tests.utils import requires_api_key
def lst(txt): def lst(txt):
"""Extract the file path from the output of `generate_image()`"""
return Path(txt.split(":")[1].strip()) return Path(txt.split(":")[1].strip())
@@ -30,6 +31,7 @@ class TestImageGen(unittest.TestCase):
@requires_api_key("OPENAI_API_KEY") @requires_api_key("OPENAI_API_KEY")
def test_dalle(self): def test_dalle(self):
"""Test DALL-E image generation."""
self.config.image_provider = "dalle" self.config.image_provider = "dalle"
# Test using size 256 # Test using size 256
@@ -47,6 +49,7 @@ class TestImageGen(unittest.TestCase):
@requires_api_key("HUGGINGFACE_API_TOKEN") @requires_api_key("HUGGINGFACE_API_TOKEN")
def test_huggingface(self): def test_huggingface(self):
"""Test HuggingFace image generation."""
self.config.image_provider = "huggingface" self.config.image_provider = "huggingface"
# Test usin SD 1.4 model and size 512 # Test usin SD 1.4 model and size 512
@@ -65,6 +68,7 @@ class TestImageGen(unittest.TestCase):
image_path.unlink() image_path.unlink()
def test_sd_webui(self): def test_sd_webui(self):
"""Test SD WebUI image generation."""
self.config.image_provider = "sd_webui" self.config.image_provider = "sd_webui"
return return

View File

@@ -6,32 +6,32 @@ from autogpt.json_utils.json_fix_llm import fix_and_parse_json
class TestParseJson(unittest.TestCase): class TestParseJson(unittest.TestCase):
def test_valid_json(self): def test_valid_json(self):
# Test that a valid JSON string is parsed correctly """Test that a valid JSON string is parsed correctly."""
json_str = '{"name": "John", "age": 30, "city": "New York"}' json_str = '{"name": "John", "age": 30, "city": "New York"}'
obj = fix_and_parse_json(json_str) obj = fix_and_parse_json(json_str)
self.assertEqual(obj, {"name": "John", "age": 30, "city": "New York"}) self.assertEqual(obj, {"name": "John", "age": 30, "city": "New York"})
def test_invalid_json_minor(self): def test_invalid_json_minor(self):
# Test that an invalid JSON string can be fixed with gpt """Test that an invalid JSON string can be fixed with gpt"""
json_str = '{"name": "John", "age": 30, "city": "New York",}' json_str = '{"name": "John", "age": 30, "city": "New York",}'
with self.assertRaises(Exception): with self.assertRaises(Exception):
fix_and_parse_json(json_str, try_to_fix_with_gpt=False) fix_and_parse_json(json_str, try_to_fix_with_gpt=False)
def test_invalid_json_major_with_gpt(self): def test_invalid_json_major_with_gpt(self):
# Test that an invalid JSON string raises an error when try_to_fix_with_gpt is False """Test that an invalid JSON string raises an error when try_to_fix_with_gpt is False"""
json_str = 'BEGIN: "name": "John" - "age": 30 - "city": "New York" :END' json_str = 'BEGIN: "name": "John" - "age": 30 - "city": "New York" :END'
with self.assertRaises(Exception): with self.assertRaises(Exception):
fix_and_parse_json(json_str, try_to_fix_with_gpt=False) fix_and_parse_json(json_str, try_to_fix_with_gpt=False)
def test_invalid_json_major_without_gpt(self): def test_invalid_json_major_without_gpt(self):
# Test that a REALLY invalid JSON string raises an error when try_to_fix_with_gpt is False """Test that a REALLY invalid JSON string raises an error when try_to_fix_with_gpt is False"""
json_str = 'BEGIN: "name": "John" - "age": 30 - "city": "New York" :END' json_str = 'BEGIN: "name": "John" - "age": 30 - "city": "New York" :END'
# Assert that this raises an exception: # Assert that this raises an exception:
with self.assertRaises(Exception): with self.assertRaises(Exception):
fix_and_parse_json(json_str, try_to_fix_with_gpt=False) fix_and_parse_json(json_str, try_to_fix_with_gpt=False)
def test_invalid_json_leading_sentence_with_gpt(self): def test_invalid_json_leading_sentence_with_gpt(self):
# Test that a REALLY invalid JSON string raises an error when try_to_fix_with_gpt is False """Test that a REALLY invalid JSON string raises an error when try_to_fix_with_gpt is False"""
json_str = """I suggest we start by browsing the repository to find any issues that we can fix. json_str = """I suggest we start by browsing the repository to find any issues that we can fix.
{ {
@@ -69,7 +69,7 @@ class TestParseJson(unittest.TestCase):
) )
def test_invalid_json_leading_sentence_with_gpt(self): def test_invalid_json_leading_sentence_with_gpt(self):
# Test that a REALLY invalid JSON string raises an error when try_to_fix_with_gpt is False """Test that a REALLY invalid JSON string raises an error when try_to_fix_with_gpt is False"""
json_str = """I will first need to browse the repository (https://github.com/Torantulino/Auto-GPT) and identify any potential bugs that need fixing. I will use the "browse_website" command for this. json_str = """I will first need to browse the repository (https://github.com/Torantulino/Auto-GPT) and identify any potential bugs that need fixing. I will use the "browse_website" command for this.
{ {

View File

@@ -20,9 +20,11 @@ class TestTokenCounter(unittest.TestCase):
self.assertEqual(count_message_tokens(messages), 17) self.assertEqual(count_message_tokens(messages), 17)
def test_count_message_tokens_empty_input(self): def test_count_message_tokens_empty_input(self):
# Empty input should return 3 tokens
self.assertEqual(count_message_tokens([]), 3) self.assertEqual(count_message_tokens([]), 3)
def test_count_message_tokens_invalid_model(self): def test_count_message_tokens_invalid_model(self):
# Invalid model should raise a KeyError
messages = [ messages = [
{"role": "user", "content": "Hello"}, {"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"}, {"role": "assistant", "content": "Hi there!"},
@@ -38,15 +40,20 @@ class TestTokenCounter(unittest.TestCase):
self.assertEqual(count_message_tokens(messages, model="gpt-4-0314"), 15) self.assertEqual(count_message_tokens(messages, model="gpt-4-0314"), 15)
def test_count_string_tokens(self): def test_count_string_tokens(self):
"""Test that the string tokens are counted correctly."""
string = "Hello, world!" string = "Hello, world!"
self.assertEqual( self.assertEqual(
count_string_tokens(string, model_name="gpt-3.5-turbo-0301"), 4 count_string_tokens(string, model_name="gpt-3.5-turbo-0301"), 4
) )
def test_count_string_tokens_empty_input(self): def test_count_string_tokens_empty_input(self):
"""Test that the string tokens are counted correctly."""
self.assertEqual(count_string_tokens("", model_name="gpt-3.5-turbo-0301"), 0) self.assertEqual(count_string_tokens("", model_name="gpt-3.5-turbo-0301"), 0)
def test_count_message_tokens_invalid_model(self): def test_count_message_tokens_invalid_model(self):
# Invalid model should raise a NotImplementedError
messages = [ messages = [
{"role": "user", "content": "Hello"}, {"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"}, {"role": "assistant", "content": "Hi there!"},
@@ -55,6 +62,8 @@ class TestTokenCounter(unittest.TestCase):
count_message_tokens(messages, model="invalid_model") count_message_tokens(messages, model="invalid_model")
def test_count_string_tokens_gpt_4(self): def test_count_string_tokens_gpt_4(self):
"""Test that the string tokens are counted correctly."""
string = "Hello, world!" string = "Hello, world!"
self.assertEqual(count_string_tokens(string, model_name="gpt-4-0314"), 4) self.assertEqual(count_string_tokens(string, model_name="gpt-4-0314"), 4)

View File

@@ -5,13 +5,13 @@ from autogpt.json_utils.json_fix_llm import fix_and_parse_json
class TestParseJson(unittest.TestCase): class TestParseJson(unittest.TestCase):
def test_valid_json(self): def test_valid_json(self):
# Test that a valid JSON string is parsed correctly """Test that a valid JSON string is parsed correctly."""
json_str = '{"name": "John", "age": 30, "city": "New York"}' json_str = '{"name": "John", "age": 30, "city": "New York"}'
obj = fix_and_parse_json(json_str) obj = fix_and_parse_json(json_str)
self.assertEqual(obj, {"name": "John", "age": 30, "city": "New York"}) self.assertEqual(obj, {"name": "John", "age": 30, "city": "New York"})
def test_invalid_json_minor(self): def test_invalid_json_minor(self):
# Test that an invalid JSON string can be fixed with gpt """Test that an invalid JSON string can be fixed with gpt."""
json_str = '{"name": "John", "age": 30, "city": "New York",}' json_str = '{"name": "John", "age": 30, "city": "New York",}'
self.assertEqual( self.assertEqual(
fix_and_parse_json(json_str, try_to_fix_with_gpt=False), fix_and_parse_json(json_str, try_to_fix_with_gpt=False),
@@ -19,7 +19,7 @@ class TestParseJson(unittest.TestCase):
) )
def test_invalid_json_major_with_gpt(self): def test_invalid_json_major_with_gpt(self):
# Test that an invalid JSON string raises an error when try_to_fix_with_gpt is False """Test that an invalid JSON string raises an error when try_to_fix_with_gpt is False."""
json_str = 'BEGIN: "name": "John" - "age": 30 - "city": "New York" :END' json_str = 'BEGIN: "name": "John" - "age": 30 - "city": "New York" :END'
self.assertEqual( self.assertEqual(
fix_and_parse_json(json_str, try_to_fix_with_gpt=True), fix_and_parse_json(json_str, try_to_fix_with_gpt=True),
@@ -27,14 +27,15 @@ class TestParseJson(unittest.TestCase):
) )
def test_invalid_json_major_without_gpt(self): def test_invalid_json_major_without_gpt(self):
# Test that a REALLY invalid JSON string raises an error when try_to_fix_with_gpt is False """Test that a REALLY invalid JSON string raises an error when try_to_fix_with_gpt is False."""
json_str = 'BEGIN: "name": "John" - "age": 30 - "city": "New York" :END' json_str = 'BEGIN: "name": "John" - "age": 30 - "city": "New York" :END'
# Assert that this raises an exception: # Assert that this raises an exception:
with self.assertRaises(Exception): with self.assertRaises(Exception):
fix_and_parse_json(json_str, try_to_fix_with_gpt=False) fix_and_parse_json(json_str, try_to_fix_with_gpt=False)
def test_invalid_json_leading_sentence_with_gpt(self): def test_invalid_json_leading_sentence_with_gpt(self):
# Test that a REALLY invalid JSON string raises an error when try_to_fix_with_gpt is False """Test that a REALLY invalid JSON string raises an error when try_to_fix_with_gpt is False."""
json_str = """I suggest we start by browsing the repository to find any issues that we can fix. json_str = """I suggest we start by browsing the repository to find any issues that we can fix.
{ {
@@ -72,7 +73,7 @@ class TestParseJson(unittest.TestCase):
) )
def test_invalid_json_leading_sentence_with_gpt(self): def test_invalid_json_leading_sentence_with_gpt(self):
# Test that a REALLY invalid JSON string raises an error when try_to_fix_with_gpt is False """Test that a REALLY invalid JSON string raises an error when try_to_fix_with_gpt is False."""
json_str = """I will first need to browse the repository (https://github.com/Torantulino/Auto-GPT) and identify any potential bugs that need fixing. I will use the "browse_website" command for this. json_str = """I will first need to browse the repository (https://github.com/Torantulino/Auto-GPT) and identify any potential bugs that need fixing. I will use the "browse_website" command for this.
{ {

View File

@@ -10,11 +10,14 @@ from autogpt.models.base_open_ai_plugin import (
class DummyPlugin(BaseOpenAIPlugin): class DummyPlugin(BaseOpenAIPlugin):
"""A dummy plugin for testing purposes."""
pass pass
@pytest.fixture @pytest.fixture
def dummy_plugin(): def dummy_plugin():
"""A dummy plugin for testing purposes."""
manifests_specs_clients = { manifests_specs_clients = {
"manifest": { "manifest": {
"name_for_model": "Dummy", "name_for_model": "Dummy",
@@ -28,22 +31,27 @@ def dummy_plugin():
def test_dummy_plugin_inheritance(dummy_plugin): def test_dummy_plugin_inheritance(dummy_plugin):
"""Test that the DummyPlugin class inherits from the BaseOpenAIPlugin class."""
assert isinstance(dummy_plugin, BaseOpenAIPlugin) assert isinstance(dummy_plugin, BaseOpenAIPlugin)
def test_dummy_plugin_name(dummy_plugin): def test_dummy_plugin_name(dummy_plugin):
"""Test that the DummyPlugin class has the correct name."""
assert dummy_plugin._name == "Dummy" assert dummy_plugin._name == "Dummy"
def test_dummy_plugin_version(dummy_plugin): def test_dummy_plugin_version(dummy_plugin):
"""Test that the DummyPlugin class has the correct version."""
assert dummy_plugin._version == "1.0" assert dummy_plugin._version == "1.0"
def test_dummy_plugin_description(dummy_plugin): def test_dummy_plugin_description(dummy_plugin):
"""Test that the DummyPlugin class has the correct description."""
assert dummy_plugin._description == "A dummy plugin for testing purposes" assert dummy_plugin._description == "A dummy plugin for testing purposes"
def test_dummy_plugin_default_methods(dummy_plugin): def test_dummy_plugin_default_methods(dummy_plugin):
"""Test that the DummyPlugin class has the correct default methods."""
assert not dummy_plugin.can_handle_on_response() assert not dummy_plugin.can_handle_on_response()
assert not dummy_plugin.can_handle_post_prompt() assert not dummy_plugin.can_handle_post_prompt()
assert not dummy_plugin.can_handle_on_planning() assert not dummy_plugin.can_handle_on_planning()

View File

@@ -38,8 +38,11 @@ requests and parse HTML content, respectively.
class TestScrapeLinks: class TestScrapeLinks:
# Tests that the function returns a list of formatted hyperlinks when """
# provided with a valid url that returns a webpage with hyperlinks. Tests that the function returns a list of formatted hyperlinks when
provided with a valid url that returns a webpage with hyperlinks.
"""
def test_valid_url_with_hyperlinks(self): def test_valid_url_with_hyperlinks(self):
url = "https://www.google.com" url = "https://www.google.com"
result = scrape_links(url) result = scrape_links(url)
@@ -47,8 +50,8 @@ class TestScrapeLinks:
assert isinstance(result, list) assert isinstance(result, list)
assert isinstance(result[0], str) assert isinstance(result[0], str)
# Tests that the function returns correctly formatted hyperlinks when given a valid url.
def test_valid_url(self, mocker): def test_valid_url(self, mocker):
"""Test that the function returns correctly formatted hyperlinks when given a valid url."""
# Mock the requests.get() function to return a response with sample HTML containing hyperlinks # Mock the requests.get() function to return a response with sample HTML containing hyperlinks
mock_response = mocker.Mock() mock_response = mocker.Mock()
mock_response.status_code = 200 mock_response.status_code = 200
@@ -63,8 +66,8 @@ class TestScrapeLinks:
# Assert that the function returns correctly formatted hyperlinks # Assert that the function returns correctly formatted hyperlinks
assert result == ["Google (https://www.google.com)"] assert result == ["Google (https://www.google.com)"]
# Tests that the function returns "error" when given an invalid url.
def test_invalid_url(self, mocker): def test_invalid_url(self, mocker):
"""Test that the function returns "error" when given an invalid url."""
# Mock the requests.get() function to return an HTTP error response # Mock the requests.get() function to return an HTTP error response
mock_response = mocker.Mock() mock_response = mocker.Mock()
mock_response.status_code = 404 mock_response.status_code = 404
@@ -76,8 +79,8 @@ class TestScrapeLinks:
# Assert that the function returns "error" # Assert that the function returns "error"
assert "Error:" in result assert "Error:" in result
# Tests that the function returns an empty list when the html contains no hyperlinks.
def test_no_hyperlinks(self, mocker): def test_no_hyperlinks(self, mocker):
"""Test that the function returns an empty list when the html contains no hyperlinks."""
# Mock the requests.get() function to return a response with sample HTML containing no hyperlinks # Mock the requests.get() function to return a response with sample HTML containing no hyperlinks
mock_response = mocker.Mock() mock_response = mocker.Mock()
mock_response.status_code = 200 mock_response.status_code = 200
@@ -90,10 +93,8 @@ class TestScrapeLinks:
# Assert that the function returns an empty list # Assert that the function returns an empty list
assert result == [] assert result == []
# Tests that scrape_links() correctly extracts and formats hyperlinks from
# a sample HTML containing a few hyperlinks.
def test_scrape_links_with_few_hyperlinks(self, mocker): def test_scrape_links_with_few_hyperlinks(self, mocker):
# Mock the requests.get() function to return a response with a sample HTML containing hyperlinks """Test that scrape_links() correctly extracts and formats hyperlinks from a sample HTML containing a few hyperlinks."""
mock_response = mocker.Mock() mock_response = mocker.Mock()
mock_response.status_code = 200 mock_response.status_code = 200
mock_response.text = """ mock_response.text = """

View File

@@ -42,8 +42,8 @@ Additional aspects:
class TestScrapeText: class TestScrapeText:
# Tests that scrape_text() returns the expected text when given a valid URL.
def test_scrape_text_with_valid_url(self, mocker): def test_scrape_text_with_valid_url(self, mocker):
"""Tests that scrape_text() returns the expected text when given a valid URL."""
# Mock the requests.get() method to return a response with expected text # Mock the requests.get() method to return a response with expected text
expected_text = "This is some sample text" expected_text = "This is some sample text"
mock_response = mocker.Mock() mock_response = mocker.Mock()
@@ -59,14 +59,13 @@ class TestScrapeText:
url = "http://www.example.com" url = "http://www.example.com"
assert scrape_text(url) == expected_text assert scrape_text(url) == expected_text
# Tests that an error is raised when an invalid url is provided.
def test_invalid_url(self): def test_invalid_url(self):
"""Tests that an error is raised when an invalid url is provided."""
url = "invalidurl.com" url = "invalidurl.com"
pytest.raises(ValueError, scrape_text, url) pytest.raises(ValueError, scrape_text, url)
# Tests that the function returns an error message when an unreachable
# url is provided.
def test_unreachable_url(self, mocker): def test_unreachable_url(self, mocker):
"""Test that scrape_text returns an error message when an invalid or unreachable url is provided."""
# Mock the requests.get() method to raise an exception # Mock the requests.get() method to raise an exception
mocker.patch( mocker.patch(
"requests.Session.get", side_effect=requests.exceptions.RequestException "requests.Session.get", side_effect=requests.exceptions.RequestException
@@ -78,9 +77,8 @@ class TestScrapeText:
error_message = scrape_text(url) error_message = scrape_text(url)
assert "Error:" in error_message assert "Error:" in error_message
# Tests that the function returns an empty string when the html page contains no
# text to be scraped.
def test_no_text(self, mocker): def test_no_text(self, mocker):
"""Test that scrape_text returns an empty string when the html page contains no text to be scraped."""
# Mock the requests.get() method to return a response with no text # Mock the requests.get() method to return a response with no text
mock_response = mocker.Mock() mock_response = mocker.Mock()
mock_response.status_code = 200 mock_response.status_code = 200
@@ -91,9 +89,8 @@ class TestScrapeText:
url = "http://www.example.com" url = "http://www.example.com"
assert scrape_text(url) == "" assert scrape_text(url) == ""
# Tests that the function returns an error message when the response status code is
# an http error (>=400).
def test_http_error(self, mocker): def test_http_error(self, mocker):
"""Test that scrape_text returns an error message when the response status code is an http error (>=400)."""
# Mock the requests.get() method to return a response with a 404 status code # Mock the requests.get() method to return a response with a 404 status code
mocker.patch("requests.Session.get", return_value=mocker.Mock(status_code=404)) mocker.patch("requests.Session.get", return_value=mocker.Mock(status_code=404))
@@ -103,8 +100,8 @@ class TestScrapeText:
# Check that the function returns an error message # Check that the function returns an error message
assert result == "Error: HTTP 404 error" assert result == "Error: HTTP 404 error"
# Tests that scrape_text() properly handles HTML tags.
def test_scrape_text_with_html_tags(self, mocker): def test_scrape_text_with_html_tags(self, mocker):
"""Test that scrape_text() properly handles HTML tags."""
# Create a mock response object with HTML containing tags # Create a mock response object with HTML containing tags
html = "<html><body><p>This is <b>bold</b> text.</p></body></html>" html = "<html><body><p>This is <b>bold</b> text.</p></body></html>"
mock_response = mocker.Mock() mock_response = mocker.Mock()

View File

@@ -7,19 +7,21 @@ from autogpt.chat import create_chat_message, generate_context
class TestChat(unittest.TestCase): class TestChat(unittest.TestCase):
# Tests that the function returns a dictionary with the correct keys and values when valid strings are provided for role and content. """Test the chat module functions."""
def test_happy_path_role_content(self): def test_happy_path_role_content(self):
"""Test that the function returns a dictionary with the correct keys and values when valid strings are provided for role and content."""
result = create_chat_message("system", "Hello, world!") result = create_chat_message("system", "Hello, world!")
self.assertEqual(result, {"role": "system", "content": "Hello, world!"}) self.assertEqual(result, {"role": "system", "content": "Hello, world!"})
# Tests that the function returns a dictionary with the correct keys and values when empty strings are provided for role and content.
def test_empty_role_content(self): def test_empty_role_content(self):
"""Test that the function returns a dictionary with the correct keys and values when empty strings are provided for role and content."""
result = create_chat_message("", "") result = create_chat_message("", "")
self.assertEqual(result, {"role": "", "content": ""}) self.assertEqual(result, {"role": "", "content": ""})
# Tests the behavior of the generate_context function when all input parameters are empty.
@patch("time.strftime") @patch("time.strftime")
def test_generate_context_empty_inputs(self, mock_strftime): def test_generate_context_empty_inputs(self, mock_strftime):
"""Test the behavior of the generate_context function when all input parameters are empty."""
# Mock the time.strftime function to return a fixed value # Mock the time.strftime function to return a fixed value
mock_strftime.return_value = "Sat Apr 15 00:00:00 2023" mock_strftime.return_value = "Sat Apr 15 00:00:00 2023"
# Arrange # Arrange
@@ -50,8 +52,8 @@ class TestChat(unittest.TestCase):
) )
self.assertEqual(result, expected_result) self.assertEqual(result, expected_result)
# Tests that the function successfully generates a current_context given valid inputs.
def test_generate_context_valid_inputs(self): def test_generate_context_valid_inputs(self):
"""Test that the function successfully generates a current_context given valid inputs."""
# Given # Given
prompt = "What is your favorite color?" prompt = "What is your favorite color?"
relevant_memory = "You once painted your room blue." relevant_memory = "You once painted your room blue."

View File

@@ -11,7 +11,8 @@ from tests.utils import requires_api_key
@pytest.mark.integration_test @pytest.mark.integration_test
@requires_api_key("OPENAI_API_KEY") @requires_api_key("OPENAI_API_KEY")
def test_make_agent() -> None: def test_make_agent() -> None:
"""Test the make_agent command""" """Test that an agent can be created"""
# Use the mock agent manager to avoid creating a real agent
with patch("openai.ChatCompletion.create") as mock: with patch("openai.ChatCompletion.create") as mock:
obj = MagicMock() obj = MagicMock()
obj.response.choices[0].messages[0].content = "Test message" obj.response.choices[0].messages[0].content = "Test message"

View File

@@ -21,6 +21,8 @@ def test_inspect_zip_for_modules():
@pytest.fixture @pytest.fixture
def mock_config_denylist_allowlist_check(): def mock_config_denylist_allowlist_check():
class MockConfig: class MockConfig:
"""Mock config object for testing the denylist_allowlist_check function"""
plugins_denylist = ["BadPlugin"] plugins_denylist = ["BadPlugin"]
plugins_allowlist = ["GoodPlugin"] plugins_allowlist = ["GoodPlugin"]
@@ -30,6 +32,7 @@ def mock_config_denylist_allowlist_check():
def test_denylist_allowlist_check_denylist( def test_denylist_allowlist_check_denylist(
mock_config_denylist_allowlist_check, monkeypatch mock_config_denylist_allowlist_check, monkeypatch
): ):
# Test that the function returns False when the plugin is in the denylist
monkeypatch.setattr("builtins.input", lambda _: "y") monkeypatch.setattr("builtins.input", lambda _: "y")
assert not denylist_allowlist_check( assert not denylist_allowlist_check(
"BadPlugin", mock_config_denylist_allowlist_check "BadPlugin", mock_config_denylist_allowlist_check
@@ -39,6 +42,7 @@ def test_denylist_allowlist_check_denylist(
def test_denylist_allowlist_check_allowlist( def test_denylist_allowlist_check_allowlist(
mock_config_denylist_allowlist_check, monkeypatch mock_config_denylist_allowlist_check, monkeypatch
): ):
# Test that the function returns True when the plugin is in the allowlist
monkeypatch.setattr("builtins.input", lambda _: "y") monkeypatch.setattr("builtins.input", lambda _: "y")
assert denylist_allowlist_check("GoodPlugin", mock_config_denylist_allowlist_check) assert denylist_allowlist_check("GoodPlugin", mock_config_denylist_allowlist_check)
@@ -46,6 +50,7 @@ def test_denylist_allowlist_check_allowlist(
def test_denylist_allowlist_check_user_input_yes( def test_denylist_allowlist_check_user_input_yes(
mock_config_denylist_allowlist_check, monkeypatch mock_config_denylist_allowlist_check, monkeypatch
): ):
# Test that the function returns True when the user inputs "y"
monkeypatch.setattr("builtins.input", lambda _: "y") monkeypatch.setattr("builtins.input", lambda _: "y")
assert denylist_allowlist_check( assert denylist_allowlist_check(
"UnknownPlugin", mock_config_denylist_allowlist_check "UnknownPlugin", mock_config_denylist_allowlist_check
@@ -55,6 +60,7 @@ def test_denylist_allowlist_check_user_input_yes(
def test_denylist_allowlist_check_user_input_no( def test_denylist_allowlist_check_user_input_no(
mock_config_denylist_allowlist_check, monkeypatch mock_config_denylist_allowlist_check, monkeypatch
): ):
# Test that the function returns False when the user inputs "n"
monkeypatch.setattr("builtins.input", lambda _: "n") monkeypatch.setattr("builtins.input", lambda _: "n")
assert not denylist_allowlist_check( assert not denylist_allowlist_check(
"UnknownPlugin", mock_config_denylist_allowlist_check "UnknownPlugin", mock_config_denylist_allowlist_check
@@ -64,6 +70,7 @@ def test_denylist_allowlist_check_user_input_no(
def test_denylist_allowlist_check_user_input_invalid( def test_denylist_allowlist_check_user_input_invalid(
mock_config_denylist_allowlist_check, monkeypatch mock_config_denylist_allowlist_check, monkeypatch
): ):
# Test that the function returns False when the user inputs an invalid value
monkeypatch.setattr("builtins.input", lambda _: "invalid") monkeypatch.setattr("builtins.input", lambda _: "invalid")
assert not denylist_allowlist_check( assert not denylist_allowlist_check(
"UnknownPlugin", mock_config_denylist_allowlist_check "UnknownPlugin", mock_config_denylist_allowlist_check
@@ -72,6 +79,8 @@ def test_denylist_allowlist_check_user_input_invalid(
@pytest.fixture @pytest.fixture
def config_with_plugins(): def config_with_plugins():
"""Mock config object for testing the scan_plugins function"""
# Test that the function returns the correct number of plugins
cfg = Config() cfg = Config()
cfg.plugins_dir = PLUGINS_TEST_DIR cfg.plugins_dir = PLUGINS_TEST_DIR
cfg.plugins_openai = ["https://weathergpt.vercel.app/"] cfg.plugins_openai = ["https://weathergpt.vercel.app/"]
@@ -80,7 +89,11 @@ def config_with_plugins():
@pytest.fixture @pytest.fixture
def mock_config_openai_plugin(): def mock_config_openai_plugin():
"""Mock config object for testing the scan_plugins function"""
class MockConfig: class MockConfig:
"""Mock config object for testing the scan_plugins function"""
plugins_dir = PLUGINS_TEST_DIR plugins_dir = PLUGINS_TEST_DIR
plugins_openai = [PLUGIN_TEST_OPENAI] plugins_openai = [PLUGIN_TEST_OPENAI]
plugins_denylist = ["AutoGPTPVicuna"] plugins_denylist = ["AutoGPTPVicuna"]
@@ -90,12 +103,16 @@ def mock_config_openai_plugin():
def test_scan_plugins_openai(mock_config_openai_plugin): def test_scan_plugins_openai(mock_config_openai_plugin):
# Test that the function returns the correct number of plugins
result = scan_plugins(mock_config_openai_plugin, debug=True) result = scan_plugins(mock_config_openai_plugin, debug=True)
assert len(result) == 1 assert len(result) == 1
@pytest.fixture @pytest.fixture
def mock_config_generic_plugin(): def mock_config_generic_plugin():
"""Mock config object for testing the scan_plugins function"""
# Test that the function returns the correct number of plugins
class MockConfig: class MockConfig:
plugins_dir = PLUGINS_TEST_DIR plugins_dir = PLUGINS_TEST_DIR
plugins_openai = [] plugins_openai = []
@@ -106,5 +123,6 @@ def mock_config_generic_plugin():
def test_scan_plugins_generic(mock_config_generic_plugin): def test_scan_plugins_generic(mock_config_generic_plugin):
# Test that the function returns the correct number of plugins
result = scan_plugins(mock_config_generic_plugin, debug=True) result = scan_plugins(mock_config_generic_plugin, debug=True)
assert len(result) == 1 assert len(result) == 1

View File

@@ -30,28 +30,29 @@ PLEASE_WAIT = "Please wait..."
class TestSpinner(unittest.TestCase): class TestSpinner(unittest.TestCase):
# Tests that the spinner initializes with default values.
def test_spinner_initializes_with_default_values(self): def test_spinner_initializes_with_default_values(self):
"""Tests that the spinner initializes with default values."""
with Spinner() as spinner: with Spinner() as spinner:
self.assertEqual(spinner.message, "Loading...") self.assertEqual(spinner.message, "Loading...")
self.assertEqual(spinner.delay, 0.1) self.assertEqual(spinner.delay, 0.1)
# Tests that the spinner initializes with custom message and delay values.
def test_spinner_initializes_with_custom_values(self): def test_spinner_initializes_with_custom_values(self):
"""Tests that the spinner initializes with custom message and delay values."""
with Spinner(message=PLEASE_WAIT, delay=0.2) as spinner: with Spinner(message=PLEASE_WAIT, delay=0.2) as spinner:
self.assertEqual(spinner.message, PLEASE_WAIT) self.assertEqual(spinner.message, PLEASE_WAIT)
self.assertEqual(spinner.delay, 0.2) self.assertEqual(spinner.delay, 0.2)
# Tests that the spinner starts spinning and stops spinning without errors. #
def test_spinner_stops_spinning(self): def test_spinner_stops_spinning(self):
"""Tests that the spinner starts spinning and stops spinning without errors."""
with Spinner() as spinner: with Spinner() as spinner:
time.sleep(1) time.sleep(1)
spinner.update_message(ALMOST_DONE_MESSAGE) spinner.update_message(ALMOST_DONE_MESSAGE)
time.sleep(1) time.sleep(1)
self.assertFalse(spinner.running) self.assertFalse(spinner.running)
# Tests that the spinner message can be updated while the spinner is running and the spinner continues spinning.
def test_spinner_updates_message_and_still_spins(self): def test_spinner_updates_message_and_still_spins(self):
"""Tests that the spinner message can be updated while the spinner is running and the spinner continues spinning."""
with Spinner() as spinner: with Spinner() as spinner:
self.assertTrue(spinner.running) self.assertTrue(spinner.running)
time.sleep(1) time.sleep(1)
@@ -60,9 +61,8 @@ class TestSpinner(unittest.TestCase):
self.assertEqual(spinner.message, ALMOST_DONE_MESSAGE) self.assertEqual(spinner.message, ALMOST_DONE_MESSAGE)
self.assertFalse(spinner.running) self.assertFalse(spinner.running)
# Tests that the spinner can be used as a context manager.
def test_spinner_can_be_used_as_context_manager(self): def test_spinner_can_be_used_as_context_manager(self):
"""Tests that the spinner can be used as a context manager."""
with Spinner() as spinner: with Spinner() as spinner:
self.assertTrue(spinner.running) self.assertTrue(spinner.running)
self.assertFalse(spinner.running) self.assertFalse(spinner.running)