Add function and class descriptions to tests (#2715)

Co-authored-by: Reinier van der Leer <github@pwuts.nl>
This commit is contained in:
Andres Caicedo
2023-04-24 14:55:49 +02:00
committed by GitHub
parent 40a75c804c
commit f8dfedf1c6
19 changed files with 147 additions and 67 deletions

View File

@@ -10,7 +10,10 @@ from browse import extract_hyperlinks
class TestBrowseLinks(unittest.TestCase):
"""Unit tests for the browse module functions that extract hyperlinks."""
def test_extract_hyperlinks(self):
"""Test the extract_hyperlinks function with a simple HTML body."""
body = """
<body>
<a href="https://google.com">Google</a>

View File

@@ -1,6 +1,7 @@
import os
import sys
# Add the scripts directory to the path so that we can import the browse module.
sys.path.insert(
0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../scripts"))
)

View File

@@ -9,10 +9,11 @@ from autogpt.memory.local import LocalCache
class TestLocalCache(unittest.TestCase):
def random_string(self, length):
def generate_random_string(self, length):
return "".join(random.choice(string.ascii_letters) for _ in range(length))
def setUp(self):
"""Set up the test environment for the LocalCache tests."""
cfg = cfg = Config()
self.cache = LocalCache(cfg)
self.cache.clear()
@@ -24,15 +25,15 @@ class TestLocalCache(unittest.TestCase):
"The cake is a lie, but the pie is always true",
"ChatGPT is an advanced AI model for conversation",
]
for text in self.example_texts:
self.cache.add(text)
# Add some random strings to test noise
for _ in range(5):
self.cache.add(self.random_string(10))
self.cache.add(self.generate_random_string(10))
def test_get_relevant(self):
"""Test getting relevant texts from the cache."""
query = "I'm interested in artificial intelligence and NLP"
k = 3
relevant_texts = self.cache.get_relevant(query, k)

View File

@@ -10,14 +10,12 @@ from autogpt.memory.milvus import MilvusMemory
try:
class TestMilvusMemory(unittest.TestCase):
"""Tests for the MilvusMemory class."""
"""Unit tests for the MilvusMemory class."""
def random_string(self, length: int) -> str:
"""Generate a random string of the given length."""
def generate_random_string(self, length: int) -> str:
return "".join(random.choice(string.ascii_letters) for _ in range(length))
def setUp(self) -> None:
"""Set up the test environment."""
cfg = Config()
cfg.milvus_addr = "localhost:19530"
self.memory = MilvusMemory(cfg)
@@ -36,7 +34,7 @@ try:
# Add some random strings to test noise
for _ in range(5):
self.memory.add(self.random_string(10))
self.memory.add(self.generate_random_string(10))
def test_get_relevant(self) -> None:
"""Test getting relevant texts from the cache."""

View File

@@ -16,6 +16,7 @@ class TestWeaviateMemory(unittest.TestCase):
@classmethod
def setUpClass(cls):
"""Set up the test environment for the WeaviateMemory tests."""
# only create the connection to weaviate once
cls.cfg = Config()
@@ -47,6 +48,7 @@ class TestWeaviateMemory(unittest.TestCase):
"""
def setUp(self):
"""Set up the test environment for the WeaviateMemory tests."""
try:
self.client.schema.delete_class(self.index)
except:
@@ -55,6 +57,7 @@ class TestWeaviateMemory(unittest.TestCase):
self.memory = WeaviateMemory(self.cfg)
def test_add(self):
"""Test adding a text to the cache"""
doc = "You are a Titan name Thanos and you are looking for the Infinity Stones"
self.memory.add(doc)
result = self.client.query.get(self.index, ["raw_text"]).do()
@@ -64,8 +67,9 @@ class TestWeaviateMemory(unittest.TestCase):
self.assertEqual(actual[0]["raw_text"], doc)
def test_get(self):
"""Test getting a text from the cache"""
doc = "You are an Avenger and swore to defend the Galaxy from a menace called Thanos"
# add the document to the cache
with self.client.batch as batch:
batch.add_data_object(
uuid=get_valid_uuid(uuid4()),
@@ -82,6 +86,7 @@ class TestWeaviateMemory(unittest.TestCase):
self.assertEqual(actual[0], doc)
def test_get_stats(self):
"""Test getting the stats of the cache"""
docs = [
"You are now about to count the number of docs in this index",
"And then you about to find out if you can count correctly",
@@ -96,6 +101,7 @@ class TestWeaviateMemory(unittest.TestCase):
self.assertEqual(stats["count"], 2)
def test_clear(self):
"""Test clearing the cache"""
docs = [
"Shame this is the last test for this class",
"Testing is fun when someone else is doing it",

View File

@@ -8,7 +8,8 @@ try:
from autogpt.memory.milvus import MilvusMemory
def mock_config() -> dict:
"""Mock the Config class"""
"""Mock the config object for testing purposes."""
# Return a mock config object with the required attributes
return type(
"MockConfig",
(object,),

View File

@@ -3,4 +3,5 @@ from autogpt.commands.command import command
@command("function_based", "Function-based test command")
def function_based(arg1: int, arg2: str) -> str:
"""A function-based test command that returns a string with the two arguments separated by a dash."""
return f"{arg1} - {arg2}"

View File

@@ -9,41 +9,55 @@ from autogpt.commands.command import Command, CommandRegistry
class TestCommand:
"""Test cases for the Command class."""
@staticmethod
def example_function(arg1: int, arg2: str) -> str:
def example_command_method(arg1: int, arg2: str) -> str:
"""Example function for testing the Command class."""
# This function is static because it is not used by any other test cases.
return f"{arg1} - {arg2}"
def test_command_creation(self):
"""Test that a Command object can be created with the correct attributes."""
cmd = Command(
name="example", description="Example command", method=self.example_function
name="example",
description="Example command",
method=self.example_command_method,
)
assert cmd.name == "example"
assert cmd.description == "Example command"
assert cmd.method == self.example_function
assert cmd.method == self.example_command_method
assert cmd.signature == "(arg1: int, arg2: str) -> str"
def test_command_call(self):
"""Test that Command(*args) calls and returns the result of method(*args)."""
# Create a Command object with the example_command_method.
cmd = Command(
name="example", description="Example command", method=self.example_function
name="example",
description="Example command",
method=self.example_command_method,
)
result = cmd(arg1=1, arg2="test")
assert result == "1 - test"
def test_command_call_with_invalid_arguments(self):
"""Test that calling a Command object with invalid arguments raises a TypeError."""
cmd = Command(
name="example", description="Example command", method=self.example_function
name="example",
description="Example command",
method=self.example_command_method,
)
with pytest.raises(TypeError):
cmd(arg1="invalid", does_not_exist="test")
def test_command_default_signature(self):
"""Test that the default signature is generated correctly."""
cmd = Command(
name="example", description="Example command", method=self.example_function
name="example",
description="Example command",
method=self.example_command_method,
)
assert cmd.signature == "(arg1: int, arg2: str) -> str"
def test_command_custom_signature(self):
@@ -51,7 +65,7 @@ class TestCommand:
cmd = Command(
name="example",
description="Example command",
method=self.example_function,
method=self.example_command_method,
signature=custom_signature,
)
@@ -60,14 +74,16 @@ class TestCommand:
class TestCommandRegistry:
@staticmethod
def example_function(arg1: int, arg2: str) -> str:
def example_command_method(arg1: int, arg2: str) -> str:
return f"{arg1} - {arg2}"
def test_register_command(self):
"""Test that a command can be registered to the registry."""
registry = CommandRegistry()
cmd = Command(
name="example", description="Example command", method=self.example_function
name="example",
description="Example command",
method=self.example_command_method,
)
registry.register(cmd)
@@ -79,7 +95,9 @@ class TestCommandRegistry:
"""Test that a command can be unregistered from the registry."""
registry = CommandRegistry()
cmd = Command(
name="example", description="Example command", method=self.example_function
name="example",
description="Example command",
method=self.example_command_method,
)
registry.register(cmd)
@@ -91,7 +109,9 @@ class TestCommandRegistry:
"""Test that a command can be retrieved from the registry."""
registry = CommandRegistry()
cmd = Command(
name="example", description="Example command", method=self.example_function
name="example",
description="Example command",
method=self.example_command_method,
)
registry.register(cmd)
@@ -110,7 +130,9 @@ class TestCommandRegistry:
"""Test that a command can be called through the registry."""
registry = CommandRegistry()
cmd = Command(
name="example", description="Example command", method=self.example_function
name="example",
description="Example command",
method=self.example_command_method,
)
registry.register(cmd)
@@ -129,7 +151,9 @@ class TestCommandRegistry:
"""Test that the command prompt is correctly formatted."""
registry = CommandRegistry()
cmd = Command(
name="example", description="Example command", method=self.example_function
name="example",
description="Example command",
method=self.example_command_method,
)
registry.register(cmd)
@@ -152,7 +176,11 @@ class TestCommandRegistry:
)
def test_import_temp_command_file_module(self, tmp_path):
"""Test that the registry can import a command plugins module from a temp file."""
"""
Test that the registry can import a command plugins module from a temp file.
Args:
tmp_path (pathlib.Path): Path to a temporary directory.
"""
registry = CommandRegistry()
# Create a temp command file

View File

@@ -13,6 +13,7 @@ from tests.utils import requires_api_key
def lst(txt):
"""Extract the file path from the output of `generate_image()`"""
return Path(txt.split(":")[1].strip())
@@ -30,6 +31,7 @@ class TestImageGen(unittest.TestCase):
@requires_api_key("OPENAI_API_KEY")
def test_dalle(self):
"""Test DALL-E image generation."""
self.config.image_provider = "dalle"
# Test using size 256
@@ -47,6 +49,7 @@ class TestImageGen(unittest.TestCase):
@requires_api_key("HUGGINGFACE_API_TOKEN")
def test_huggingface(self):
"""Test HuggingFace image generation."""
self.config.image_provider = "huggingface"
# Test usin SD 1.4 model and size 512
@@ -65,6 +68,7 @@ class TestImageGen(unittest.TestCase):
image_path.unlink()
def test_sd_webui(self):
"""Test SD WebUI image generation."""
self.config.image_provider = "sd_webui"
return

View File

@@ -6,32 +6,32 @@ from autogpt.json_utils.json_fix_llm import fix_and_parse_json
class TestParseJson(unittest.TestCase):
def test_valid_json(self):
# Test that a valid JSON string is parsed correctly
"""Test that a valid JSON string is parsed correctly."""
json_str = '{"name": "John", "age": 30, "city": "New York"}'
obj = fix_and_parse_json(json_str)
self.assertEqual(obj, {"name": "John", "age": 30, "city": "New York"})
def test_invalid_json_minor(self):
# Test that an invalid JSON string can be fixed with gpt
"""Test that an invalid JSON string can be fixed with gpt"""
json_str = '{"name": "John", "age": 30, "city": "New York",}'
with self.assertRaises(Exception):
fix_and_parse_json(json_str, try_to_fix_with_gpt=False)
def test_invalid_json_major_with_gpt(self):
# Test that an invalid JSON string raises an error when try_to_fix_with_gpt is False
"""Test that an invalid JSON string raises an error when try_to_fix_with_gpt is False"""
json_str = 'BEGIN: "name": "John" - "age": 30 - "city": "New York" :END'
with self.assertRaises(Exception):
fix_and_parse_json(json_str, try_to_fix_with_gpt=False)
def test_invalid_json_major_without_gpt(self):
# Test that a REALLY invalid JSON string raises an error when try_to_fix_with_gpt is False
"""Test that a REALLY invalid JSON string raises an error when try_to_fix_with_gpt is False"""
json_str = 'BEGIN: "name": "John" - "age": 30 - "city": "New York" :END'
# Assert that this raises an exception:
with self.assertRaises(Exception):
fix_and_parse_json(json_str, try_to_fix_with_gpt=False)
def test_invalid_json_leading_sentence_with_gpt(self):
# Test that a REALLY invalid JSON string raises an error when try_to_fix_with_gpt is False
"""Test that a REALLY invalid JSON string raises an error when try_to_fix_with_gpt is False"""
json_str = """I suggest we start by browsing the repository to find any issues that we can fix.
{
@@ -69,7 +69,7 @@ class TestParseJson(unittest.TestCase):
)
def test_invalid_json_leading_sentence_with_gpt(self):
# Test that a REALLY invalid JSON string raises an error when try_to_fix_with_gpt is False
"""Test that a REALLY invalid JSON string raises an error when try_to_fix_with_gpt is False"""
json_str = """I will first need to browse the repository (https://github.com/Torantulino/Auto-GPT) and identify any potential bugs that need fixing. I will use the "browse_website" command for this.
{

View File

@@ -20,9 +20,11 @@ class TestTokenCounter(unittest.TestCase):
self.assertEqual(count_message_tokens(messages), 17)
def test_count_message_tokens_empty_input(self):
# Empty input should return 3 tokens
self.assertEqual(count_message_tokens([]), 3)
def test_count_message_tokens_invalid_model(self):
# Invalid model should raise a KeyError
messages = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
@@ -38,15 +40,20 @@ class TestTokenCounter(unittest.TestCase):
self.assertEqual(count_message_tokens(messages, model="gpt-4-0314"), 15)
def test_count_string_tokens(self):
"""Test that the string tokens are counted correctly."""
string = "Hello, world!"
self.assertEqual(
count_string_tokens(string, model_name="gpt-3.5-turbo-0301"), 4
)
def test_count_string_tokens_empty_input(self):
"""Test that the string tokens are counted correctly."""
self.assertEqual(count_string_tokens("", model_name="gpt-3.5-turbo-0301"), 0)
def test_count_message_tokens_invalid_model(self):
# Invalid model should raise a NotImplementedError
messages = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
@@ -55,6 +62,8 @@ class TestTokenCounter(unittest.TestCase):
count_message_tokens(messages, model="invalid_model")
def test_count_string_tokens_gpt_4(self):
"""Test that the string tokens are counted correctly."""
string = "Hello, world!"
self.assertEqual(count_string_tokens(string, model_name="gpt-4-0314"), 4)

View File

@@ -5,13 +5,13 @@ from autogpt.json_utils.json_fix_llm import fix_and_parse_json
class TestParseJson(unittest.TestCase):
def test_valid_json(self):
# Test that a valid JSON string is parsed correctly
"""Test that a valid JSON string is parsed correctly."""
json_str = '{"name": "John", "age": 30, "city": "New York"}'
obj = fix_and_parse_json(json_str)
self.assertEqual(obj, {"name": "John", "age": 30, "city": "New York"})
def test_invalid_json_minor(self):
# Test that an invalid JSON string can be fixed with gpt
"""Test that an invalid JSON string can be fixed with gpt."""
json_str = '{"name": "John", "age": 30, "city": "New York",}'
self.assertEqual(
fix_and_parse_json(json_str, try_to_fix_with_gpt=False),
@@ -19,7 +19,7 @@ class TestParseJson(unittest.TestCase):
)
def test_invalid_json_major_with_gpt(self):
# Test that an invalid JSON string raises an error when try_to_fix_with_gpt is False
"""Test that an invalid JSON string raises an error when try_to_fix_with_gpt is False."""
json_str = 'BEGIN: "name": "John" - "age": 30 - "city": "New York" :END'
self.assertEqual(
fix_and_parse_json(json_str, try_to_fix_with_gpt=True),
@@ -27,14 +27,15 @@ class TestParseJson(unittest.TestCase):
)
def test_invalid_json_major_without_gpt(self):
# Test that a REALLY invalid JSON string raises an error when try_to_fix_with_gpt is False
"""Test that a REALLY invalid JSON string raises an error when try_to_fix_with_gpt is False."""
json_str = 'BEGIN: "name": "John" - "age": 30 - "city": "New York" :END'
# Assert that this raises an exception:
with self.assertRaises(Exception):
fix_and_parse_json(json_str, try_to_fix_with_gpt=False)
def test_invalid_json_leading_sentence_with_gpt(self):
# Test that a REALLY invalid JSON string raises an error when try_to_fix_with_gpt is False
"""Test that a REALLY invalid JSON string raises an error when try_to_fix_with_gpt is False."""
json_str = """I suggest we start by browsing the repository to find any issues that we can fix.
{
@@ -72,7 +73,7 @@ class TestParseJson(unittest.TestCase):
)
def test_invalid_json_leading_sentence_with_gpt(self):
# Test that a REALLY invalid JSON string raises an error when try_to_fix_with_gpt is False
"""Test that a REALLY invalid JSON string raises an error when try_to_fix_with_gpt is False."""
json_str = """I will first need to browse the repository (https://github.com/Torantulino/Auto-GPT) and identify any potential bugs that need fixing. I will use the "browse_website" command for this.
{

View File

@@ -10,11 +10,14 @@ from autogpt.models.base_open_ai_plugin import (
class DummyPlugin(BaseOpenAIPlugin):
"""A dummy plugin for testing purposes."""
pass
@pytest.fixture
def dummy_plugin():
"""A dummy plugin for testing purposes."""
manifests_specs_clients = {
"manifest": {
"name_for_model": "Dummy",
@@ -28,22 +31,27 @@ def dummy_plugin():
def test_dummy_plugin_inheritance(dummy_plugin):
"""Test that the DummyPlugin class inherits from the BaseOpenAIPlugin class."""
assert isinstance(dummy_plugin, BaseOpenAIPlugin)
def test_dummy_plugin_name(dummy_plugin):
"""Test that the DummyPlugin class has the correct name."""
assert dummy_plugin._name == "Dummy"
def test_dummy_plugin_version(dummy_plugin):
"""Test that the DummyPlugin class has the correct version."""
assert dummy_plugin._version == "1.0"
def test_dummy_plugin_description(dummy_plugin):
"""Test that the DummyPlugin class has the correct description."""
assert dummy_plugin._description == "A dummy plugin for testing purposes"
def test_dummy_plugin_default_methods(dummy_plugin):
"""Test that the DummyPlugin class has the correct default methods."""
assert not dummy_plugin.can_handle_on_response()
assert not dummy_plugin.can_handle_post_prompt()
assert not dummy_plugin.can_handle_on_planning()

View File

@@ -38,8 +38,11 @@ requests and parse HTML content, respectively.
class TestScrapeLinks:
# Tests that the function returns a list of formatted hyperlinks when
# provided with a valid url that returns a webpage with hyperlinks.
"""
Tests that the function returns a list of formatted hyperlinks when
provided with a valid url that returns a webpage with hyperlinks.
"""
def test_valid_url_with_hyperlinks(self):
url = "https://www.google.com"
result = scrape_links(url)
@@ -47,8 +50,8 @@ class TestScrapeLinks:
assert isinstance(result, list)
assert isinstance(result[0], str)
# Tests that the function returns correctly formatted hyperlinks when given a valid url.
def test_valid_url(self, mocker):
"""Test that the function returns correctly formatted hyperlinks when given a valid url."""
# Mock the requests.get() function to return a response with sample HTML containing hyperlinks
mock_response = mocker.Mock()
mock_response.status_code = 200
@@ -63,8 +66,8 @@ class TestScrapeLinks:
# Assert that the function returns correctly formatted hyperlinks
assert result == ["Google (https://www.google.com)"]
# Tests that the function returns "error" when given an invalid url.
def test_invalid_url(self, mocker):
"""Test that the function returns "error" when given an invalid url."""
# Mock the requests.get() function to return an HTTP error response
mock_response = mocker.Mock()
mock_response.status_code = 404
@@ -76,8 +79,8 @@ class TestScrapeLinks:
# Assert that the function returns "error"
assert "Error:" in result
# Tests that the function returns an empty list when the html contains no hyperlinks.
def test_no_hyperlinks(self, mocker):
"""Test that the function returns an empty list when the html contains no hyperlinks."""
# Mock the requests.get() function to return a response with sample HTML containing no hyperlinks
mock_response = mocker.Mock()
mock_response.status_code = 200
@@ -90,10 +93,8 @@ class TestScrapeLinks:
# Assert that the function returns an empty list
assert result == []
# Tests that scrape_links() correctly extracts and formats hyperlinks from
# a sample HTML containing a few hyperlinks.
def test_scrape_links_with_few_hyperlinks(self, mocker):
# Mock the requests.get() function to return a response with a sample HTML containing hyperlinks
"""Test that scrape_links() correctly extracts and formats hyperlinks from a sample HTML containing a few hyperlinks."""
mock_response = mocker.Mock()
mock_response.status_code = 200
mock_response.text = """

View File

@@ -42,8 +42,8 @@ Additional aspects:
class TestScrapeText:
# Tests that scrape_text() returns the expected text when given a valid URL.
def test_scrape_text_with_valid_url(self, mocker):
"""Tests that scrape_text() returns the expected text when given a valid URL."""
# Mock the requests.get() method to return a response with expected text
expected_text = "This is some sample text"
mock_response = mocker.Mock()
@@ -59,14 +59,13 @@ class TestScrapeText:
url = "http://www.example.com"
assert scrape_text(url) == expected_text
# Tests that an error is raised when an invalid url is provided.
def test_invalid_url(self):
"""Tests that an error is raised when an invalid url is provided."""
url = "invalidurl.com"
pytest.raises(ValueError, scrape_text, url)
# Tests that the function returns an error message when an unreachable
# url is provided.
def test_unreachable_url(self, mocker):
"""Test that scrape_text returns an error message when an invalid or unreachable url is provided."""
# Mock the requests.get() method to raise an exception
mocker.patch(
"requests.Session.get", side_effect=requests.exceptions.RequestException
@@ -78,9 +77,8 @@ class TestScrapeText:
error_message = scrape_text(url)
assert "Error:" in error_message
# Tests that the function returns an empty string when the html page contains no
# text to be scraped.
def test_no_text(self, mocker):
"""Test that scrape_text returns an empty string when the html page contains no text to be scraped."""
# Mock the requests.get() method to return a response with no text
mock_response = mocker.Mock()
mock_response.status_code = 200
@@ -91,9 +89,8 @@ class TestScrapeText:
url = "http://www.example.com"
assert scrape_text(url) == ""
# Tests that the function returns an error message when the response status code is
# an http error (>=400).
def test_http_error(self, mocker):
"""Test that scrape_text returns an error message when the response status code is an http error (>=400)."""
# Mock the requests.get() method to return a response with a 404 status code
mocker.patch("requests.Session.get", return_value=mocker.Mock(status_code=404))
@@ -103,8 +100,8 @@ class TestScrapeText:
# Check that the function returns an error message
assert result == "Error: HTTP 404 error"
# Tests that scrape_text() properly handles HTML tags.
def test_scrape_text_with_html_tags(self, mocker):
"""Test that scrape_text() properly handles HTML tags."""
# Create a mock response object with HTML containing tags
html = "<html><body><p>This is <b>bold</b> text.</p></body></html>"
mock_response = mocker.Mock()

View File

@@ -7,19 +7,21 @@ from autogpt.chat import create_chat_message, generate_context
class TestChat(unittest.TestCase):
# Tests that the function returns a dictionary with the correct keys and values when valid strings are provided for role and content.
"""Test the chat module functions."""
def test_happy_path_role_content(self):
"""Test that the function returns a dictionary with the correct keys and values when valid strings are provided for role and content."""
result = create_chat_message("system", "Hello, world!")
self.assertEqual(result, {"role": "system", "content": "Hello, world!"})
# Tests that the function returns a dictionary with the correct keys and values when empty strings are provided for role and content.
def test_empty_role_content(self):
"""Test that the function returns a dictionary with the correct keys and values when empty strings are provided for role and content."""
result = create_chat_message("", "")
self.assertEqual(result, {"role": "", "content": ""})
# Tests the behavior of the generate_context function when all input parameters are empty.
@patch("time.strftime")
def test_generate_context_empty_inputs(self, mock_strftime):
"""Test the behavior of the generate_context function when all input parameters are empty."""
# Mock the time.strftime function to return a fixed value
mock_strftime.return_value = "Sat Apr 15 00:00:00 2023"
# Arrange
@@ -50,8 +52,8 @@ class TestChat(unittest.TestCase):
)
self.assertEqual(result, expected_result)
# Tests that the function successfully generates a current_context given valid inputs.
def test_generate_context_valid_inputs(self):
"""Test that the function successfully generates a current_context given valid inputs."""
# Given
prompt = "What is your favorite color?"
relevant_memory = "You once painted your room blue."

View File

@@ -11,7 +11,8 @@ from tests.utils import requires_api_key
@pytest.mark.integration_test
@requires_api_key("OPENAI_API_KEY")
def test_make_agent() -> None:
"""Test the make_agent command"""
"""Test that an agent can be created"""
# Use the mock agent manager to avoid creating a real agent
with patch("openai.ChatCompletion.create") as mock:
obj = MagicMock()
obj.response.choices[0].messages[0].content = "Test message"

View File

@@ -21,6 +21,8 @@ def test_inspect_zip_for_modules():
@pytest.fixture
def mock_config_denylist_allowlist_check():
class MockConfig:
"""Mock config object for testing the denylist_allowlist_check function"""
plugins_denylist = ["BadPlugin"]
plugins_allowlist = ["GoodPlugin"]
@@ -30,6 +32,7 @@ def mock_config_denylist_allowlist_check():
def test_denylist_allowlist_check_denylist(
mock_config_denylist_allowlist_check, monkeypatch
):
# Test that the function returns False when the plugin is in the denylist
monkeypatch.setattr("builtins.input", lambda _: "y")
assert not denylist_allowlist_check(
"BadPlugin", mock_config_denylist_allowlist_check
@@ -39,6 +42,7 @@ def test_denylist_allowlist_check_denylist(
def test_denylist_allowlist_check_allowlist(
mock_config_denylist_allowlist_check, monkeypatch
):
# Test that the function returns True when the plugin is in the allowlist
monkeypatch.setattr("builtins.input", lambda _: "y")
assert denylist_allowlist_check("GoodPlugin", mock_config_denylist_allowlist_check)
@@ -46,6 +50,7 @@ def test_denylist_allowlist_check_allowlist(
def test_denylist_allowlist_check_user_input_yes(
mock_config_denylist_allowlist_check, monkeypatch
):
# Test that the function returns True when the user inputs "y"
monkeypatch.setattr("builtins.input", lambda _: "y")
assert denylist_allowlist_check(
"UnknownPlugin", mock_config_denylist_allowlist_check
@@ -55,6 +60,7 @@ def test_denylist_allowlist_check_user_input_yes(
def test_denylist_allowlist_check_user_input_no(
mock_config_denylist_allowlist_check, monkeypatch
):
# Test that the function returns False when the user inputs "n"
monkeypatch.setattr("builtins.input", lambda _: "n")
assert not denylist_allowlist_check(
"UnknownPlugin", mock_config_denylist_allowlist_check
@@ -64,6 +70,7 @@ def test_denylist_allowlist_check_user_input_no(
def test_denylist_allowlist_check_user_input_invalid(
mock_config_denylist_allowlist_check, monkeypatch
):
# Test that the function returns False when the user inputs an invalid value
monkeypatch.setattr("builtins.input", lambda _: "invalid")
assert not denylist_allowlist_check(
"UnknownPlugin", mock_config_denylist_allowlist_check
@@ -72,6 +79,8 @@ def test_denylist_allowlist_check_user_input_invalid(
@pytest.fixture
def config_with_plugins():
"""Mock config object for testing the scan_plugins function"""
# Test that the function returns the correct number of plugins
cfg = Config()
cfg.plugins_dir = PLUGINS_TEST_DIR
cfg.plugins_openai = ["https://weathergpt.vercel.app/"]
@@ -80,7 +89,11 @@ def config_with_plugins():
@pytest.fixture
def mock_config_openai_plugin():
"""Mock config object for testing the scan_plugins function"""
class MockConfig:
"""Mock config object for testing the scan_plugins function"""
plugins_dir = PLUGINS_TEST_DIR
plugins_openai = [PLUGIN_TEST_OPENAI]
plugins_denylist = ["AutoGPTPVicuna"]
@@ -90,12 +103,16 @@ def mock_config_openai_plugin():
def test_scan_plugins_openai(mock_config_openai_plugin):
# Test that the function returns the correct number of plugins
result = scan_plugins(mock_config_openai_plugin, debug=True)
assert len(result) == 1
@pytest.fixture
def mock_config_generic_plugin():
"""Mock config object for testing the scan_plugins function"""
# Test that the function returns the correct number of plugins
class MockConfig:
plugins_dir = PLUGINS_TEST_DIR
plugins_openai = []
@@ -106,5 +123,6 @@ def mock_config_generic_plugin():
def test_scan_plugins_generic(mock_config_generic_plugin):
# Test that the function returns the correct number of plugins
result = scan_plugins(mock_config_generic_plugin, debug=True)
assert len(result) == 1

View File

@@ -30,28 +30,29 @@ PLEASE_WAIT = "Please wait..."
class TestSpinner(unittest.TestCase):
# Tests that the spinner initializes with default values.
def test_spinner_initializes_with_default_values(self):
"""Tests that the spinner initializes with default values."""
with Spinner() as spinner:
self.assertEqual(spinner.message, "Loading...")
self.assertEqual(spinner.delay, 0.1)
# Tests that the spinner initializes with custom message and delay values.
def test_spinner_initializes_with_custom_values(self):
"""Tests that the spinner initializes with custom message and delay values."""
with Spinner(message=PLEASE_WAIT, delay=0.2) as spinner:
self.assertEqual(spinner.message, PLEASE_WAIT)
self.assertEqual(spinner.delay, 0.2)
# Tests that the spinner starts spinning and stops spinning without errors.
#
def test_spinner_stops_spinning(self):
"""Tests that the spinner starts spinning and stops spinning without errors."""
with Spinner() as spinner:
time.sleep(1)
spinner.update_message(ALMOST_DONE_MESSAGE)
time.sleep(1)
self.assertFalse(spinner.running)
# Tests that the spinner message can be updated while the spinner is running and the spinner continues spinning.
def test_spinner_updates_message_and_still_spins(self):
"""Tests that the spinner message can be updated while the spinner is running and the spinner continues spinning."""
with Spinner() as spinner:
self.assertTrue(spinner.running)
time.sleep(1)
@@ -60,9 +61,8 @@ class TestSpinner(unittest.TestCase):
self.assertEqual(spinner.message, ALMOST_DONE_MESSAGE)
self.assertFalse(spinner.running)
# Tests that the spinner can be used as a context manager.
def test_spinner_can_be_used_as_context_manager(self):
"""Tests that the spinner can be used as a context manager."""
with Spinner() as spinner:
self.assertTrue(spinner.running)
self.assertFalse(spinner.running)