mirror of
https://github.com/aljazceru/Auto-GPT.git
synced 2025-12-23 17:04:21 +01:00
Add function and class descriptions to tests (#2715)
Co-authored-by: Reinier van der Leer <github@pwuts.nl>
This commit is contained in:
@@ -10,7 +10,10 @@ from browse import extract_hyperlinks
|
||||
|
||||
|
||||
class TestBrowseLinks(unittest.TestCase):
|
||||
"""Unit tests for the browse module functions that extract hyperlinks."""
|
||||
|
||||
def test_extract_hyperlinks(self):
|
||||
"""Test the extract_hyperlinks function with a simple HTML body."""
|
||||
body = """
|
||||
<body>
|
||||
<a href="https://google.com">Google</a>
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Add the scripts directory to the path so that we can import the browse module.
|
||||
sys.path.insert(
|
||||
0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../scripts"))
|
||||
)
|
||||
|
||||
@@ -9,10 +9,11 @@ from autogpt.memory.local import LocalCache
|
||||
|
||||
|
||||
class TestLocalCache(unittest.TestCase):
|
||||
def random_string(self, length):
|
||||
def generate_random_string(self, length):
|
||||
return "".join(random.choice(string.ascii_letters) for _ in range(length))
|
||||
|
||||
def setUp(self):
|
||||
"""Set up the test environment for the LocalCache tests."""
|
||||
cfg = cfg = Config()
|
||||
self.cache = LocalCache(cfg)
|
||||
self.cache.clear()
|
||||
@@ -24,15 +25,15 @@ class TestLocalCache(unittest.TestCase):
|
||||
"The cake is a lie, but the pie is always true",
|
||||
"ChatGPT is an advanced AI model for conversation",
|
||||
]
|
||||
|
||||
for text in self.example_texts:
|
||||
self.cache.add(text)
|
||||
|
||||
# Add some random strings to test noise
|
||||
for _ in range(5):
|
||||
self.cache.add(self.random_string(10))
|
||||
self.cache.add(self.generate_random_string(10))
|
||||
|
||||
def test_get_relevant(self):
|
||||
"""Test getting relevant texts from the cache."""
|
||||
query = "I'm interested in artificial intelligence and NLP"
|
||||
k = 3
|
||||
relevant_texts = self.cache.get_relevant(query, k)
|
||||
|
||||
@@ -10,14 +10,12 @@ from autogpt.memory.milvus import MilvusMemory
|
||||
try:
|
||||
|
||||
class TestMilvusMemory(unittest.TestCase):
|
||||
"""Tests for the MilvusMemory class."""
|
||||
"""Unit tests for the MilvusMemory class."""
|
||||
|
||||
def random_string(self, length: int) -> str:
|
||||
"""Generate a random string of the given length."""
|
||||
def generate_random_string(self, length: int) -> str:
|
||||
return "".join(random.choice(string.ascii_letters) for _ in range(length))
|
||||
|
||||
def setUp(self) -> None:
|
||||
"""Set up the test environment."""
|
||||
cfg = Config()
|
||||
cfg.milvus_addr = "localhost:19530"
|
||||
self.memory = MilvusMemory(cfg)
|
||||
@@ -36,7 +34,7 @@ try:
|
||||
|
||||
# Add some random strings to test noise
|
||||
for _ in range(5):
|
||||
self.memory.add(self.random_string(10))
|
||||
self.memory.add(self.generate_random_string(10))
|
||||
|
||||
def test_get_relevant(self) -> None:
|
||||
"""Test getting relevant texts from the cache."""
|
||||
|
||||
@@ -16,6 +16,7 @@ class TestWeaviateMemory(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
"""Set up the test environment for the WeaviateMemory tests."""
|
||||
# only create the connection to weaviate once
|
||||
cls.cfg = Config()
|
||||
|
||||
@@ -47,6 +48,7 @@ class TestWeaviateMemory(unittest.TestCase):
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up the test environment for the WeaviateMemory tests."""
|
||||
try:
|
||||
self.client.schema.delete_class(self.index)
|
||||
except:
|
||||
@@ -55,6 +57,7 @@ class TestWeaviateMemory(unittest.TestCase):
|
||||
self.memory = WeaviateMemory(self.cfg)
|
||||
|
||||
def test_add(self):
|
||||
"""Test adding a text to the cache"""
|
||||
doc = "You are a Titan name Thanos and you are looking for the Infinity Stones"
|
||||
self.memory.add(doc)
|
||||
result = self.client.query.get(self.index, ["raw_text"]).do()
|
||||
@@ -64,8 +67,9 @@ class TestWeaviateMemory(unittest.TestCase):
|
||||
self.assertEqual(actual[0]["raw_text"], doc)
|
||||
|
||||
def test_get(self):
|
||||
"""Test getting a text from the cache"""
|
||||
doc = "You are an Avenger and swore to defend the Galaxy from a menace called Thanos"
|
||||
|
||||
# add the document to the cache
|
||||
with self.client.batch as batch:
|
||||
batch.add_data_object(
|
||||
uuid=get_valid_uuid(uuid4()),
|
||||
@@ -82,6 +86,7 @@ class TestWeaviateMemory(unittest.TestCase):
|
||||
self.assertEqual(actual[0], doc)
|
||||
|
||||
def test_get_stats(self):
|
||||
"""Test getting the stats of the cache"""
|
||||
docs = [
|
||||
"You are now about to count the number of docs in this index",
|
||||
"And then you about to find out if you can count correctly",
|
||||
@@ -96,6 +101,7 @@ class TestWeaviateMemory(unittest.TestCase):
|
||||
self.assertEqual(stats["count"], 2)
|
||||
|
||||
def test_clear(self):
|
||||
"""Test clearing the cache"""
|
||||
docs = [
|
||||
"Shame this is the last test for this class",
|
||||
"Testing is fun when someone else is doing it",
|
||||
|
||||
@@ -8,7 +8,8 @@ try:
|
||||
from autogpt.memory.milvus import MilvusMemory
|
||||
|
||||
def mock_config() -> dict:
|
||||
"""Mock the Config class"""
|
||||
"""Mock the config object for testing purposes."""
|
||||
# Return a mock config object with the required attributes
|
||||
return type(
|
||||
"MockConfig",
|
||||
(object,),
|
||||
|
||||
@@ -3,4 +3,5 @@ from autogpt.commands.command import command
|
||||
|
||||
@command("function_based", "Function-based test command")
|
||||
def function_based(arg1: int, arg2: str) -> str:
|
||||
"""A function-based test command that returns a string with the two arguments separated by a dash."""
|
||||
return f"{arg1} - {arg2}"
|
||||
|
||||
@@ -9,41 +9,55 @@ from autogpt.commands.command import Command, CommandRegistry
|
||||
|
||||
|
||||
class TestCommand:
|
||||
"""Test cases for the Command class."""
|
||||
|
||||
@staticmethod
|
||||
def example_function(arg1: int, arg2: str) -> str:
|
||||
def example_command_method(arg1: int, arg2: str) -> str:
|
||||
"""Example function for testing the Command class."""
|
||||
# This function is static because it is not used by any other test cases.
|
||||
return f"{arg1} - {arg2}"
|
||||
|
||||
def test_command_creation(self):
|
||||
"""Test that a Command object can be created with the correct attributes."""
|
||||
cmd = Command(
|
||||
name="example", description="Example command", method=self.example_function
|
||||
name="example",
|
||||
description="Example command",
|
||||
method=self.example_command_method,
|
||||
)
|
||||
|
||||
assert cmd.name == "example"
|
||||
assert cmd.description == "Example command"
|
||||
assert cmd.method == self.example_function
|
||||
assert cmd.method == self.example_command_method
|
||||
assert cmd.signature == "(arg1: int, arg2: str) -> str"
|
||||
|
||||
def test_command_call(self):
|
||||
"""Test that Command(*args) calls and returns the result of method(*args)."""
|
||||
# Create a Command object with the example_command_method.
|
||||
cmd = Command(
|
||||
name="example", description="Example command", method=self.example_function
|
||||
name="example",
|
||||
description="Example command",
|
||||
method=self.example_command_method,
|
||||
)
|
||||
|
||||
result = cmd(arg1=1, arg2="test")
|
||||
assert result == "1 - test"
|
||||
|
||||
def test_command_call_with_invalid_arguments(self):
|
||||
"""Test that calling a Command object with invalid arguments raises a TypeError."""
|
||||
cmd = Command(
|
||||
name="example", description="Example command", method=self.example_function
|
||||
name="example",
|
||||
description="Example command",
|
||||
method=self.example_command_method,
|
||||
)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
cmd(arg1="invalid", does_not_exist="test")
|
||||
|
||||
def test_command_default_signature(self):
|
||||
"""Test that the default signature is generated correctly."""
|
||||
cmd = Command(
|
||||
name="example", description="Example command", method=self.example_function
|
||||
name="example",
|
||||
description="Example command",
|
||||
method=self.example_command_method,
|
||||
)
|
||||
|
||||
assert cmd.signature == "(arg1: int, arg2: str) -> str"
|
||||
|
||||
def test_command_custom_signature(self):
|
||||
@@ -51,7 +65,7 @@ class TestCommand:
|
||||
cmd = Command(
|
||||
name="example",
|
||||
description="Example command",
|
||||
method=self.example_function,
|
||||
method=self.example_command_method,
|
||||
signature=custom_signature,
|
||||
)
|
||||
|
||||
@@ -60,14 +74,16 @@ class TestCommand:
|
||||
|
||||
class TestCommandRegistry:
|
||||
@staticmethod
|
||||
def example_function(arg1: int, arg2: str) -> str:
|
||||
def example_command_method(arg1: int, arg2: str) -> str:
|
||||
return f"{arg1} - {arg2}"
|
||||
|
||||
def test_register_command(self):
|
||||
"""Test that a command can be registered to the registry."""
|
||||
registry = CommandRegistry()
|
||||
cmd = Command(
|
||||
name="example", description="Example command", method=self.example_function
|
||||
name="example",
|
||||
description="Example command",
|
||||
method=self.example_command_method,
|
||||
)
|
||||
|
||||
registry.register(cmd)
|
||||
@@ -79,7 +95,9 @@ class TestCommandRegistry:
|
||||
"""Test that a command can be unregistered from the registry."""
|
||||
registry = CommandRegistry()
|
||||
cmd = Command(
|
||||
name="example", description="Example command", method=self.example_function
|
||||
name="example",
|
||||
description="Example command",
|
||||
method=self.example_command_method,
|
||||
)
|
||||
|
||||
registry.register(cmd)
|
||||
@@ -91,7 +109,9 @@ class TestCommandRegistry:
|
||||
"""Test that a command can be retrieved from the registry."""
|
||||
registry = CommandRegistry()
|
||||
cmd = Command(
|
||||
name="example", description="Example command", method=self.example_function
|
||||
name="example",
|
||||
description="Example command",
|
||||
method=self.example_command_method,
|
||||
)
|
||||
|
||||
registry.register(cmd)
|
||||
@@ -110,7 +130,9 @@ class TestCommandRegistry:
|
||||
"""Test that a command can be called through the registry."""
|
||||
registry = CommandRegistry()
|
||||
cmd = Command(
|
||||
name="example", description="Example command", method=self.example_function
|
||||
name="example",
|
||||
description="Example command",
|
||||
method=self.example_command_method,
|
||||
)
|
||||
|
||||
registry.register(cmd)
|
||||
@@ -129,7 +151,9 @@ class TestCommandRegistry:
|
||||
"""Test that the command prompt is correctly formatted."""
|
||||
registry = CommandRegistry()
|
||||
cmd = Command(
|
||||
name="example", description="Example command", method=self.example_function
|
||||
name="example",
|
||||
description="Example command",
|
||||
method=self.example_command_method,
|
||||
)
|
||||
|
||||
registry.register(cmd)
|
||||
@@ -152,7 +176,11 @@ class TestCommandRegistry:
|
||||
)
|
||||
|
||||
def test_import_temp_command_file_module(self, tmp_path):
|
||||
"""Test that the registry can import a command plugins module from a temp file."""
|
||||
"""
|
||||
Test that the registry can import a command plugins module from a temp file.
|
||||
Args:
|
||||
tmp_path (pathlib.Path): Path to a temporary directory.
|
||||
"""
|
||||
registry = CommandRegistry()
|
||||
|
||||
# Create a temp command file
|
||||
|
||||
@@ -13,6 +13,7 @@ from tests.utils import requires_api_key
|
||||
|
||||
|
||||
def lst(txt):
|
||||
"""Extract the file path from the output of `generate_image()`"""
|
||||
return Path(txt.split(":")[1].strip())
|
||||
|
||||
|
||||
@@ -30,6 +31,7 @@ class TestImageGen(unittest.TestCase):
|
||||
|
||||
@requires_api_key("OPENAI_API_KEY")
|
||||
def test_dalle(self):
|
||||
"""Test DALL-E image generation."""
|
||||
self.config.image_provider = "dalle"
|
||||
|
||||
# Test using size 256
|
||||
@@ -47,6 +49,7 @@ class TestImageGen(unittest.TestCase):
|
||||
|
||||
@requires_api_key("HUGGINGFACE_API_TOKEN")
|
||||
def test_huggingface(self):
|
||||
"""Test HuggingFace image generation."""
|
||||
self.config.image_provider = "huggingface"
|
||||
|
||||
# Test usin SD 1.4 model and size 512
|
||||
@@ -65,6 +68,7 @@ class TestImageGen(unittest.TestCase):
|
||||
image_path.unlink()
|
||||
|
||||
def test_sd_webui(self):
|
||||
"""Test SD WebUI image generation."""
|
||||
self.config.image_provider = "sd_webui"
|
||||
return
|
||||
|
||||
|
||||
@@ -6,32 +6,32 @@ from autogpt.json_utils.json_fix_llm import fix_and_parse_json
|
||||
|
||||
class TestParseJson(unittest.TestCase):
|
||||
def test_valid_json(self):
|
||||
# Test that a valid JSON string is parsed correctly
|
||||
"""Test that a valid JSON string is parsed correctly."""
|
||||
json_str = '{"name": "John", "age": 30, "city": "New York"}'
|
||||
obj = fix_and_parse_json(json_str)
|
||||
self.assertEqual(obj, {"name": "John", "age": 30, "city": "New York"})
|
||||
|
||||
def test_invalid_json_minor(self):
|
||||
# Test that an invalid JSON string can be fixed with gpt
|
||||
"""Test that an invalid JSON string can be fixed with gpt"""
|
||||
json_str = '{"name": "John", "age": 30, "city": "New York",}'
|
||||
with self.assertRaises(Exception):
|
||||
fix_and_parse_json(json_str, try_to_fix_with_gpt=False)
|
||||
|
||||
def test_invalid_json_major_with_gpt(self):
|
||||
# Test that an invalid JSON string raises an error when try_to_fix_with_gpt is False
|
||||
"""Test that an invalid JSON string raises an error when try_to_fix_with_gpt is False"""
|
||||
json_str = 'BEGIN: "name": "John" - "age": 30 - "city": "New York" :END'
|
||||
with self.assertRaises(Exception):
|
||||
fix_and_parse_json(json_str, try_to_fix_with_gpt=False)
|
||||
|
||||
def test_invalid_json_major_without_gpt(self):
|
||||
# Test that a REALLY invalid JSON string raises an error when try_to_fix_with_gpt is False
|
||||
"""Test that a REALLY invalid JSON string raises an error when try_to_fix_with_gpt is False"""
|
||||
json_str = 'BEGIN: "name": "John" - "age": 30 - "city": "New York" :END'
|
||||
# Assert that this raises an exception:
|
||||
with self.assertRaises(Exception):
|
||||
fix_and_parse_json(json_str, try_to_fix_with_gpt=False)
|
||||
|
||||
def test_invalid_json_leading_sentence_with_gpt(self):
|
||||
# Test that a REALLY invalid JSON string raises an error when try_to_fix_with_gpt is False
|
||||
"""Test that a REALLY invalid JSON string raises an error when try_to_fix_with_gpt is False"""
|
||||
json_str = """I suggest we start by browsing the repository to find any issues that we can fix.
|
||||
|
||||
{
|
||||
@@ -69,7 +69,7 @@ class TestParseJson(unittest.TestCase):
|
||||
)
|
||||
|
||||
def test_invalid_json_leading_sentence_with_gpt(self):
|
||||
# Test that a REALLY invalid JSON string raises an error when try_to_fix_with_gpt is False
|
||||
"""Test that a REALLY invalid JSON string raises an error when try_to_fix_with_gpt is False"""
|
||||
json_str = """I will first need to browse the repository (https://github.com/Torantulino/Auto-GPT) and identify any potential bugs that need fixing. I will use the "browse_website" command for this.
|
||||
|
||||
{
|
||||
|
||||
@@ -20,9 +20,11 @@ class TestTokenCounter(unittest.TestCase):
|
||||
self.assertEqual(count_message_tokens(messages), 17)
|
||||
|
||||
def test_count_message_tokens_empty_input(self):
|
||||
# Empty input should return 3 tokens
|
||||
self.assertEqual(count_message_tokens([]), 3)
|
||||
|
||||
def test_count_message_tokens_invalid_model(self):
|
||||
# Invalid model should raise a KeyError
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
@@ -38,15 +40,20 @@ class TestTokenCounter(unittest.TestCase):
|
||||
self.assertEqual(count_message_tokens(messages, model="gpt-4-0314"), 15)
|
||||
|
||||
def test_count_string_tokens(self):
|
||||
"""Test that the string tokens are counted correctly."""
|
||||
|
||||
string = "Hello, world!"
|
||||
self.assertEqual(
|
||||
count_string_tokens(string, model_name="gpt-3.5-turbo-0301"), 4
|
||||
)
|
||||
|
||||
def test_count_string_tokens_empty_input(self):
|
||||
"""Test that the string tokens are counted correctly."""
|
||||
|
||||
self.assertEqual(count_string_tokens("", model_name="gpt-3.5-turbo-0301"), 0)
|
||||
|
||||
def test_count_message_tokens_invalid_model(self):
|
||||
# Invalid model should raise a NotImplementedError
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
@@ -55,6 +62,8 @@ class TestTokenCounter(unittest.TestCase):
|
||||
count_message_tokens(messages, model="invalid_model")
|
||||
|
||||
def test_count_string_tokens_gpt_4(self):
|
||||
"""Test that the string tokens are counted correctly."""
|
||||
|
||||
string = "Hello, world!"
|
||||
self.assertEqual(count_string_tokens(string, model_name="gpt-4-0314"), 4)
|
||||
|
||||
|
||||
@@ -5,13 +5,13 @@ from autogpt.json_utils.json_fix_llm import fix_and_parse_json
|
||||
|
||||
class TestParseJson(unittest.TestCase):
|
||||
def test_valid_json(self):
|
||||
# Test that a valid JSON string is parsed correctly
|
||||
"""Test that a valid JSON string is parsed correctly."""
|
||||
json_str = '{"name": "John", "age": 30, "city": "New York"}'
|
||||
obj = fix_and_parse_json(json_str)
|
||||
self.assertEqual(obj, {"name": "John", "age": 30, "city": "New York"})
|
||||
|
||||
def test_invalid_json_minor(self):
|
||||
# Test that an invalid JSON string can be fixed with gpt
|
||||
"""Test that an invalid JSON string can be fixed with gpt."""
|
||||
json_str = '{"name": "John", "age": 30, "city": "New York",}'
|
||||
self.assertEqual(
|
||||
fix_and_parse_json(json_str, try_to_fix_with_gpt=False),
|
||||
@@ -19,7 +19,7 @@ class TestParseJson(unittest.TestCase):
|
||||
)
|
||||
|
||||
def test_invalid_json_major_with_gpt(self):
|
||||
# Test that an invalid JSON string raises an error when try_to_fix_with_gpt is False
|
||||
"""Test that an invalid JSON string raises an error when try_to_fix_with_gpt is False."""
|
||||
json_str = 'BEGIN: "name": "John" - "age": 30 - "city": "New York" :END'
|
||||
self.assertEqual(
|
||||
fix_and_parse_json(json_str, try_to_fix_with_gpt=True),
|
||||
@@ -27,14 +27,15 @@ class TestParseJson(unittest.TestCase):
|
||||
)
|
||||
|
||||
def test_invalid_json_major_without_gpt(self):
|
||||
# Test that a REALLY invalid JSON string raises an error when try_to_fix_with_gpt is False
|
||||
"""Test that a REALLY invalid JSON string raises an error when try_to_fix_with_gpt is False."""
|
||||
json_str = 'BEGIN: "name": "John" - "age": 30 - "city": "New York" :END'
|
||||
# Assert that this raises an exception:
|
||||
with self.assertRaises(Exception):
|
||||
fix_and_parse_json(json_str, try_to_fix_with_gpt=False)
|
||||
|
||||
def test_invalid_json_leading_sentence_with_gpt(self):
|
||||
# Test that a REALLY invalid JSON string raises an error when try_to_fix_with_gpt is False
|
||||
"""Test that a REALLY invalid JSON string raises an error when try_to_fix_with_gpt is False."""
|
||||
|
||||
json_str = """I suggest we start by browsing the repository to find any issues that we can fix.
|
||||
|
||||
{
|
||||
@@ -72,7 +73,7 @@ class TestParseJson(unittest.TestCase):
|
||||
)
|
||||
|
||||
def test_invalid_json_leading_sentence_with_gpt(self):
|
||||
# Test that a REALLY invalid JSON string raises an error when try_to_fix_with_gpt is False
|
||||
"""Test that a REALLY invalid JSON string raises an error when try_to_fix_with_gpt is False."""
|
||||
json_str = """I will first need to browse the repository (https://github.com/Torantulino/Auto-GPT) and identify any potential bugs that need fixing. I will use the "browse_website" command for this.
|
||||
|
||||
{
|
||||
|
||||
@@ -10,11 +10,14 @@ from autogpt.models.base_open_ai_plugin import (
|
||||
|
||||
|
||||
class DummyPlugin(BaseOpenAIPlugin):
|
||||
"""A dummy plugin for testing purposes."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_plugin():
|
||||
"""A dummy plugin for testing purposes."""
|
||||
manifests_specs_clients = {
|
||||
"manifest": {
|
||||
"name_for_model": "Dummy",
|
||||
@@ -28,22 +31,27 @@ def dummy_plugin():
|
||||
|
||||
|
||||
def test_dummy_plugin_inheritance(dummy_plugin):
|
||||
"""Test that the DummyPlugin class inherits from the BaseOpenAIPlugin class."""
|
||||
assert isinstance(dummy_plugin, BaseOpenAIPlugin)
|
||||
|
||||
|
||||
def test_dummy_plugin_name(dummy_plugin):
|
||||
"""Test that the DummyPlugin class has the correct name."""
|
||||
assert dummy_plugin._name == "Dummy"
|
||||
|
||||
|
||||
def test_dummy_plugin_version(dummy_plugin):
|
||||
"""Test that the DummyPlugin class has the correct version."""
|
||||
assert dummy_plugin._version == "1.0"
|
||||
|
||||
|
||||
def test_dummy_plugin_description(dummy_plugin):
|
||||
"""Test that the DummyPlugin class has the correct description."""
|
||||
assert dummy_plugin._description == "A dummy plugin for testing purposes"
|
||||
|
||||
|
||||
def test_dummy_plugin_default_methods(dummy_plugin):
|
||||
"""Test that the DummyPlugin class has the correct default methods."""
|
||||
assert not dummy_plugin.can_handle_on_response()
|
||||
assert not dummy_plugin.can_handle_post_prompt()
|
||||
assert not dummy_plugin.can_handle_on_planning()
|
||||
|
||||
@@ -38,8 +38,11 @@ requests and parse HTML content, respectively.
|
||||
|
||||
|
||||
class TestScrapeLinks:
|
||||
# Tests that the function returns a list of formatted hyperlinks when
|
||||
# provided with a valid url that returns a webpage with hyperlinks.
|
||||
"""
|
||||
Tests that the function returns a list of formatted hyperlinks when
|
||||
provided with a valid url that returns a webpage with hyperlinks.
|
||||
"""
|
||||
|
||||
def test_valid_url_with_hyperlinks(self):
|
||||
url = "https://www.google.com"
|
||||
result = scrape_links(url)
|
||||
@@ -47,8 +50,8 @@ class TestScrapeLinks:
|
||||
assert isinstance(result, list)
|
||||
assert isinstance(result[0], str)
|
||||
|
||||
# Tests that the function returns correctly formatted hyperlinks when given a valid url.
|
||||
def test_valid_url(self, mocker):
|
||||
"""Test that the function returns correctly formatted hyperlinks when given a valid url."""
|
||||
# Mock the requests.get() function to return a response with sample HTML containing hyperlinks
|
||||
mock_response = mocker.Mock()
|
||||
mock_response.status_code = 200
|
||||
@@ -63,8 +66,8 @@ class TestScrapeLinks:
|
||||
# Assert that the function returns correctly formatted hyperlinks
|
||||
assert result == ["Google (https://www.google.com)"]
|
||||
|
||||
# Tests that the function returns "error" when given an invalid url.
|
||||
def test_invalid_url(self, mocker):
|
||||
"""Test that the function returns "error" when given an invalid url."""
|
||||
# Mock the requests.get() function to return an HTTP error response
|
||||
mock_response = mocker.Mock()
|
||||
mock_response.status_code = 404
|
||||
@@ -76,8 +79,8 @@ class TestScrapeLinks:
|
||||
# Assert that the function returns "error"
|
||||
assert "Error:" in result
|
||||
|
||||
# Tests that the function returns an empty list when the html contains no hyperlinks.
|
||||
def test_no_hyperlinks(self, mocker):
|
||||
"""Test that the function returns an empty list when the html contains no hyperlinks."""
|
||||
# Mock the requests.get() function to return a response with sample HTML containing no hyperlinks
|
||||
mock_response = mocker.Mock()
|
||||
mock_response.status_code = 200
|
||||
@@ -90,10 +93,8 @@ class TestScrapeLinks:
|
||||
# Assert that the function returns an empty list
|
||||
assert result == []
|
||||
|
||||
# Tests that scrape_links() correctly extracts and formats hyperlinks from
|
||||
# a sample HTML containing a few hyperlinks.
|
||||
def test_scrape_links_with_few_hyperlinks(self, mocker):
|
||||
# Mock the requests.get() function to return a response with a sample HTML containing hyperlinks
|
||||
"""Test that scrape_links() correctly extracts and formats hyperlinks from a sample HTML containing a few hyperlinks."""
|
||||
mock_response = mocker.Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.text = """
|
||||
|
||||
@@ -42,8 +42,8 @@ Additional aspects:
|
||||
|
||||
|
||||
class TestScrapeText:
|
||||
# Tests that scrape_text() returns the expected text when given a valid URL.
|
||||
def test_scrape_text_with_valid_url(self, mocker):
|
||||
"""Tests that scrape_text() returns the expected text when given a valid URL."""
|
||||
# Mock the requests.get() method to return a response with expected text
|
||||
expected_text = "This is some sample text"
|
||||
mock_response = mocker.Mock()
|
||||
@@ -59,14 +59,13 @@ class TestScrapeText:
|
||||
url = "http://www.example.com"
|
||||
assert scrape_text(url) == expected_text
|
||||
|
||||
# Tests that an error is raised when an invalid url is provided.
|
||||
def test_invalid_url(self):
|
||||
"""Tests that an error is raised when an invalid url is provided."""
|
||||
url = "invalidurl.com"
|
||||
pytest.raises(ValueError, scrape_text, url)
|
||||
|
||||
# Tests that the function returns an error message when an unreachable
|
||||
# url is provided.
|
||||
def test_unreachable_url(self, mocker):
|
||||
"""Test that scrape_text returns an error message when an invalid or unreachable url is provided."""
|
||||
# Mock the requests.get() method to raise an exception
|
||||
mocker.patch(
|
||||
"requests.Session.get", side_effect=requests.exceptions.RequestException
|
||||
@@ -78,9 +77,8 @@ class TestScrapeText:
|
||||
error_message = scrape_text(url)
|
||||
assert "Error:" in error_message
|
||||
|
||||
# Tests that the function returns an empty string when the html page contains no
|
||||
# text to be scraped.
|
||||
def test_no_text(self, mocker):
|
||||
"""Test that scrape_text returns an empty string when the html page contains no text to be scraped."""
|
||||
# Mock the requests.get() method to return a response with no text
|
||||
mock_response = mocker.Mock()
|
||||
mock_response.status_code = 200
|
||||
@@ -91,9 +89,8 @@ class TestScrapeText:
|
||||
url = "http://www.example.com"
|
||||
assert scrape_text(url) == ""
|
||||
|
||||
# Tests that the function returns an error message when the response status code is
|
||||
# an http error (>=400).
|
||||
def test_http_error(self, mocker):
|
||||
"""Test that scrape_text returns an error message when the response status code is an http error (>=400)."""
|
||||
# Mock the requests.get() method to return a response with a 404 status code
|
||||
mocker.patch("requests.Session.get", return_value=mocker.Mock(status_code=404))
|
||||
|
||||
@@ -103,8 +100,8 @@ class TestScrapeText:
|
||||
# Check that the function returns an error message
|
||||
assert result == "Error: HTTP 404 error"
|
||||
|
||||
# Tests that scrape_text() properly handles HTML tags.
|
||||
def test_scrape_text_with_html_tags(self, mocker):
|
||||
"""Test that scrape_text() properly handles HTML tags."""
|
||||
# Create a mock response object with HTML containing tags
|
||||
html = "<html><body><p>This is <b>bold</b> text.</p></body></html>"
|
||||
mock_response = mocker.Mock()
|
||||
|
||||
@@ -7,19 +7,21 @@ from autogpt.chat import create_chat_message, generate_context
|
||||
|
||||
|
||||
class TestChat(unittest.TestCase):
|
||||
# Tests that the function returns a dictionary with the correct keys and values when valid strings are provided for role and content.
|
||||
"""Test the chat module functions."""
|
||||
|
||||
def test_happy_path_role_content(self):
|
||||
"""Test that the function returns a dictionary with the correct keys and values when valid strings are provided for role and content."""
|
||||
result = create_chat_message("system", "Hello, world!")
|
||||
self.assertEqual(result, {"role": "system", "content": "Hello, world!"})
|
||||
|
||||
# Tests that the function returns a dictionary with the correct keys and values when empty strings are provided for role and content.
|
||||
def test_empty_role_content(self):
|
||||
"""Test that the function returns a dictionary with the correct keys and values when empty strings are provided for role and content."""
|
||||
result = create_chat_message("", "")
|
||||
self.assertEqual(result, {"role": "", "content": ""})
|
||||
|
||||
# Tests the behavior of the generate_context function when all input parameters are empty.
|
||||
@patch("time.strftime")
|
||||
def test_generate_context_empty_inputs(self, mock_strftime):
|
||||
"""Test the behavior of the generate_context function when all input parameters are empty."""
|
||||
# Mock the time.strftime function to return a fixed value
|
||||
mock_strftime.return_value = "Sat Apr 15 00:00:00 2023"
|
||||
# Arrange
|
||||
@@ -50,8 +52,8 @@ class TestChat(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(result, expected_result)
|
||||
|
||||
# Tests that the function successfully generates a current_context given valid inputs.
|
||||
def test_generate_context_valid_inputs(self):
|
||||
"""Test that the function successfully generates a current_context given valid inputs."""
|
||||
# Given
|
||||
prompt = "What is your favorite color?"
|
||||
relevant_memory = "You once painted your room blue."
|
||||
|
||||
@@ -11,7 +11,8 @@ from tests.utils import requires_api_key
|
||||
@pytest.mark.integration_test
|
||||
@requires_api_key("OPENAI_API_KEY")
|
||||
def test_make_agent() -> None:
|
||||
"""Test the make_agent command"""
|
||||
"""Test that an agent can be created"""
|
||||
# Use the mock agent manager to avoid creating a real agent
|
||||
with patch("openai.ChatCompletion.create") as mock:
|
||||
obj = MagicMock()
|
||||
obj.response.choices[0].messages[0].content = "Test message"
|
||||
|
||||
@@ -21,6 +21,8 @@ def test_inspect_zip_for_modules():
|
||||
@pytest.fixture
|
||||
def mock_config_denylist_allowlist_check():
|
||||
class MockConfig:
|
||||
"""Mock config object for testing the denylist_allowlist_check function"""
|
||||
|
||||
plugins_denylist = ["BadPlugin"]
|
||||
plugins_allowlist = ["GoodPlugin"]
|
||||
|
||||
@@ -30,6 +32,7 @@ def mock_config_denylist_allowlist_check():
|
||||
def test_denylist_allowlist_check_denylist(
|
||||
mock_config_denylist_allowlist_check, monkeypatch
|
||||
):
|
||||
# Test that the function returns False when the plugin is in the denylist
|
||||
monkeypatch.setattr("builtins.input", lambda _: "y")
|
||||
assert not denylist_allowlist_check(
|
||||
"BadPlugin", mock_config_denylist_allowlist_check
|
||||
@@ -39,6 +42,7 @@ def test_denylist_allowlist_check_denylist(
|
||||
def test_denylist_allowlist_check_allowlist(
|
||||
mock_config_denylist_allowlist_check, monkeypatch
|
||||
):
|
||||
# Test that the function returns True when the plugin is in the allowlist
|
||||
monkeypatch.setattr("builtins.input", lambda _: "y")
|
||||
assert denylist_allowlist_check("GoodPlugin", mock_config_denylist_allowlist_check)
|
||||
|
||||
@@ -46,6 +50,7 @@ def test_denylist_allowlist_check_allowlist(
|
||||
def test_denylist_allowlist_check_user_input_yes(
|
||||
mock_config_denylist_allowlist_check, monkeypatch
|
||||
):
|
||||
# Test that the function returns True when the user inputs "y"
|
||||
monkeypatch.setattr("builtins.input", lambda _: "y")
|
||||
assert denylist_allowlist_check(
|
||||
"UnknownPlugin", mock_config_denylist_allowlist_check
|
||||
@@ -55,6 +60,7 @@ def test_denylist_allowlist_check_user_input_yes(
|
||||
def test_denylist_allowlist_check_user_input_no(
|
||||
mock_config_denylist_allowlist_check, monkeypatch
|
||||
):
|
||||
# Test that the function returns False when the user inputs "n"
|
||||
monkeypatch.setattr("builtins.input", lambda _: "n")
|
||||
assert not denylist_allowlist_check(
|
||||
"UnknownPlugin", mock_config_denylist_allowlist_check
|
||||
@@ -64,6 +70,7 @@ def test_denylist_allowlist_check_user_input_no(
|
||||
def test_denylist_allowlist_check_user_input_invalid(
|
||||
mock_config_denylist_allowlist_check, monkeypatch
|
||||
):
|
||||
# Test that the function returns False when the user inputs an invalid value
|
||||
monkeypatch.setattr("builtins.input", lambda _: "invalid")
|
||||
assert not denylist_allowlist_check(
|
||||
"UnknownPlugin", mock_config_denylist_allowlist_check
|
||||
@@ -72,6 +79,8 @@ def test_denylist_allowlist_check_user_input_invalid(
|
||||
|
||||
@pytest.fixture
|
||||
def config_with_plugins():
|
||||
"""Mock config object for testing the scan_plugins function"""
|
||||
# Test that the function returns the correct number of plugins
|
||||
cfg = Config()
|
||||
cfg.plugins_dir = PLUGINS_TEST_DIR
|
||||
cfg.plugins_openai = ["https://weathergpt.vercel.app/"]
|
||||
@@ -80,7 +89,11 @@ def config_with_plugins():
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config_openai_plugin():
|
||||
"""Mock config object for testing the scan_plugins function"""
|
||||
|
||||
class MockConfig:
|
||||
"""Mock config object for testing the scan_plugins function"""
|
||||
|
||||
plugins_dir = PLUGINS_TEST_DIR
|
||||
plugins_openai = [PLUGIN_TEST_OPENAI]
|
||||
plugins_denylist = ["AutoGPTPVicuna"]
|
||||
@@ -90,12 +103,16 @@ def mock_config_openai_plugin():
|
||||
|
||||
|
||||
def test_scan_plugins_openai(mock_config_openai_plugin):
|
||||
# Test that the function returns the correct number of plugins
|
||||
result = scan_plugins(mock_config_openai_plugin, debug=True)
|
||||
assert len(result) == 1
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config_generic_plugin():
|
||||
"""Mock config object for testing the scan_plugins function"""
|
||||
|
||||
# Test that the function returns the correct number of plugins
|
||||
class MockConfig:
|
||||
plugins_dir = PLUGINS_TEST_DIR
|
||||
plugins_openai = []
|
||||
@@ -106,5 +123,6 @@ def mock_config_generic_plugin():
|
||||
|
||||
|
||||
def test_scan_plugins_generic(mock_config_generic_plugin):
|
||||
# Test that the function returns the correct number of plugins
|
||||
result = scan_plugins(mock_config_generic_plugin, debug=True)
|
||||
assert len(result) == 1
|
||||
|
||||
@@ -30,28 +30,29 @@ PLEASE_WAIT = "Please wait..."
|
||||
|
||||
|
||||
class TestSpinner(unittest.TestCase):
|
||||
# Tests that the spinner initializes with default values.
|
||||
def test_spinner_initializes_with_default_values(self):
|
||||
"""Tests that the spinner initializes with default values."""
|
||||
with Spinner() as spinner:
|
||||
self.assertEqual(spinner.message, "Loading...")
|
||||
self.assertEqual(spinner.delay, 0.1)
|
||||
|
||||
# Tests that the spinner initializes with custom message and delay values.
|
||||
def test_spinner_initializes_with_custom_values(self):
|
||||
"""Tests that the spinner initializes with custom message and delay values."""
|
||||
with Spinner(message=PLEASE_WAIT, delay=0.2) as spinner:
|
||||
self.assertEqual(spinner.message, PLEASE_WAIT)
|
||||
self.assertEqual(spinner.delay, 0.2)
|
||||
|
||||
# Tests that the spinner starts spinning and stops spinning without errors.
|
||||
#
|
||||
def test_spinner_stops_spinning(self):
|
||||
"""Tests that the spinner starts spinning and stops spinning without errors."""
|
||||
with Spinner() as spinner:
|
||||
time.sleep(1)
|
||||
spinner.update_message(ALMOST_DONE_MESSAGE)
|
||||
time.sleep(1)
|
||||
self.assertFalse(spinner.running)
|
||||
|
||||
# Tests that the spinner message can be updated while the spinner is running and the spinner continues spinning.
|
||||
def test_spinner_updates_message_and_still_spins(self):
|
||||
"""Tests that the spinner message can be updated while the spinner is running and the spinner continues spinning."""
|
||||
with Spinner() as spinner:
|
||||
self.assertTrue(spinner.running)
|
||||
time.sleep(1)
|
||||
@@ -60,9 +61,8 @@ class TestSpinner(unittest.TestCase):
|
||||
self.assertEqual(spinner.message, ALMOST_DONE_MESSAGE)
|
||||
self.assertFalse(spinner.running)
|
||||
|
||||
# Tests that the spinner can be used as a context manager.
|
||||
|
||||
def test_spinner_can_be_used_as_context_manager(self):
|
||||
"""Tests that the spinner can be used as a context manager."""
|
||||
with Spinner() as spinner:
|
||||
self.assertTrue(spinner.running)
|
||||
self.assertFalse(spinner.running)
|
||||
|
||||
Reference in New Issue
Block a user