From 97ccaba45f3bcbbf4733226adf3b10f8420e5fc2 Mon Sep 17 00:00:00 2001
From: Reinier van der Leer
Date: Wed, 23 Aug 2023 02:26:39 +0200
Subject: [PATCH] Fix broken tests (casualties from the past few days)

---
 autogpt/agents/agent.py                  |  2 +-
 autogpt/agents/features/context.py       |  2 +-
 autogpt/commands/file_operations.py      |  2 +-
 autogpt/llm/providers/openai.py          |  6 ++--
 autogpt/logs/config.py                   | 20 +++++++----
 autogpt/logs/filters.py                  | 12 +++++++
 autogpt/logs/helpers.py                  | 16 +++++++--
 autogpt/models/command_registry.py       |  4 +--
 tests/conftest.py                        |  2 --
 tests/unit/test_agent.py                 |  2 +-
 tests/unit/test_ai_config.py             |  4 ---
 tests/unit/test_commands.py              | 24 ++++++++-----
 tests/unit/test_git_commands.py          |  4 +--
 tests/unit/test_retry_provider_openai.py | 46 ++++++++++++------------
 14 files changed, 89 insertions(+), 57 deletions(-)
 create mode 100644 autogpt/logs/filters.py

diff --git a/autogpt/agents/agent.py b/autogpt/agents/agent.py
index e4893311..29686c09 100644
--- a/autogpt/agents/agent.py
+++ b/autogpt/agents/agent.py
@@ -347,7 +347,7 @@ def execute_command(
             raise CommandExecutionError(str(e))
 
     # Handle non-native commands (e.g. from plugins)
-    for command in agent.ai_config.prompt_generator.commands:
+    for command in agent.prompt_generator.commands:
         if (
             command_name == command.label.lower()
             or command_name == command.name.lower()
diff --git a/autogpt/agents/features/context.py b/autogpt/agents/features/context.py
index 847f318d..b957d963 100644
--- a/autogpt/agents/features/context.py
+++ b/autogpt/agents/features/context.py
@@ -53,7 +53,7 @@ class ContextMixin:
                 0, Message("system", "# Context\n" + self.context.format_numbered())
             )
 
-        return super(ContextMixin, self).construct_base_prompt(*args, **kwargs)
+        return super(ContextMixin, self).construct_base_prompt(*args, **kwargs)  # type: ignore
 
 
 def get_agent_context(agent: BaseAgent) -> AgentContext | None:
diff --git a/autogpt/commands/file_operations.py b/autogpt/commands/file_operations.py
index 9e4faaab..1aa73014 100644
--- a/autogpt/commands/file_operations.py
+++ b/autogpt/commands/file_operations.py
@@ -19,7 +19,7 @@ from autogpt.command_decorator import command
 from autogpt.memory.vector import MemoryItem, VectorMemory
 
 from .decorators import sanitize_path_arg
-from .file_context import open_file, open_folder # NOQA
+from .file_context import open_file, open_folder  # NOQA
 from .file_operations_utils import read_textual_file
 
 logger = logging.getLogger(__name__)
diff --git a/autogpt/llm/providers/openai.py b/autogpt/llm/providers/openai.py
index ec9dbb11..d4c2b15f 100644
--- a/autogpt/llm/providers/openai.py
+++ b/autogpt/llm/providers/openai.py
@@ -169,15 +169,15 @@ def retry_api(
         warn_user bool: Whether to warn the user. Defaults to True.
     """
     error_messages = {
-        ServiceUnavailableError: f"{Fore.RED}Error: The OpenAI API engine is currently overloaded{Fore.RESET}",
-        RateLimitError: f"{Fore.RED}Error: Reached rate limit{Fore.RESET}",
+        ServiceUnavailableError: "The OpenAI API engine is currently overloaded",
+        RateLimitError: "Reached rate limit",
     }
     api_key_error_msg = (
         f"Please double check that you have setup a "
         f"{Fore.CYAN + Style.BRIGHT}PAID{Style.RESET_ALL} OpenAI API Account. You can "
         f"read more here: {Fore.CYAN}https://docs.agpt.co/setup/#getting-an-api-key{Fore.RESET}"
     )
-    backoff_msg = f"{Fore.RED}Waiting {{backoff}} seconds...{Fore.RESET}"
+    backoff_msg = "Waiting {backoff} seconds..."
 
     def _wrapper(func: Callable):
         @functools.wraps(func)
diff --git a/autogpt/logs/config.py b/autogpt/logs/config.py
index e7dbab4d..f38bde40 100644
--- a/autogpt/logs/config.py
+++ b/autogpt/logs/config.py
@@ -12,6 +12,7 @@ from openai.util import logger as openai_logger
 if TYPE_CHECKING:
     from autogpt.config import Config
 
+from .filters import BelowLevelFilter
 from .formatters import AutoGptFormatter
 from .handlers import TTSHandler, TypingConsoleHandler
 
@@ -42,10 +43,14 @@ def configure_logging(config: Config, log_dir: Path = LOG_DIR) -> None:
     log_format = DEBUG_LOG_FORMAT if config.debug_mode else SIMPLE_LOG_FORMAT
     console_formatter = AutoGptFormatter(log_format)
 
-    # Console output handler
-    console_handler = logging.StreamHandler(stream=sys.stdout)
-    console_handler.setLevel(log_level)
-    console_handler.setFormatter(console_formatter)
+    # Console output handlers
+    stdout = logging.StreamHandler(stream=sys.stdout)
+    stdout.setLevel(log_level)
+    stdout.addFilter(BelowLevelFilter(logging.WARNING))
+    stdout.setFormatter(console_formatter)
+    stderr = logging.StreamHandler()
+    stderr.setLevel(logging.WARNING)
+    stderr.setFormatter(console_formatter)
 
     # INFO log file handler
     activity_log_handler = logging.FileHandler(log_dir / LOG_FILE, "a", "utf-8")
@@ -68,7 +73,7 @@ def configure_logging(config: Config, log_dir: Path = LOG_DIR) -> None:
         format=log_format,
         level=log_level,
         handlers=(
-            [console_handler, activity_log_handler, error_log_handler]
+            [stdout, stderr, activity_log_handler, error_log_handler]
             + ([debug_log_handler] if config.debug_mode else [])
         ),
     )
@@ -81,13 +86,14 @@ def configure_logging(config: Config, log_dir: Path = LOG_DIR) -> None:
     typing_console_handler.setFormatter(console_formatter)
 
     user_friendly_output_logger = logging.getLogger(USER_FRIENDLY_OUTPUT_LOGGER)
+    user_friendly_output_logger.setLevel(logging.INFO)
     user_friendly_output_logger.addHandler(
-        typing_console_handler if not config.plain_output else console_handler
+        typing_console_handler if not config.plain_output else stdout
     )
     user_friendly_output_logger.addHandler(TTSHandler(config))
     user_friendly_output_logger.addHandler(activity_log_handler)
     user_friendly_output_logger.addHandler(error_log_handler)
-    user_friendly_output_logger.setLevel(logging.INFO)
+    user_friendly_output_logger.addHandler(stderr)
     user_friendly_output_logger.propagate = False
 
     # JSON logger with better formatting
diff --git a/autogpt/logs/filters.py b/autogpt/logs/filters.py
new file mode 100644
index 00000000..7a0ccd75
--- /dev/null
+++ b/autogpt/logs/filters.py
@@ -0,0 +1,12 @@
+import logging
+
+
+class BelowLevelFilter(logging.Filter):
+    """Filter for logging levels below a certain threshold."""
+
+    def __init__(self, below_level: int):
+        super().__init__()
+        self.below_level = below_level
+
+    def filter(self, record: logging.LogRecord):
+        return record.levelno < self.below_level
diff --git a/autogpt/logs/helpers.py b/autogpt/logs/helpers.py
index e0c90b52..482057ec 100644
--- a/autogpt/logs/helpers.py
+++ b/autogpt/logs/helpers.py
@@ -11,6 +11,7 @@ def user_friendly_output(
     level: int = logging.INFO,
     title: str = "",
     title_color: str = "",
+    preserve_message_color: bool = False,
 ) -> None:
     """Outputs a message to the user in a user-friendly way.
 
@@ -24,7 +25,15 @@ def user_friendly_output(
         for plugin in _chat_plugins:
             plugin.report(f"{title}: {message}")
 
-    logger.log(level, message, extra={"title": title, "title_color": title_color})
+    logger.log(
+        level,
+        message,
+        extra={
+            "title": title,
+            "title_color": title_color,
+            "preserve_color": preserve_message_color,
+        },
+    )
 
 
 def print_attribute(
@@ -51,5 +60,8 @@ def request_user_double_check(additionalText: Optional[str] = None) -> None:
         )
 
     user_friendly_output(
-        additionalText, level=logging.WARN, title="DOUBLE CHECK CONFIGURATION"
+        additionalText,
+        level=logging.WARN,
+        title="DOUBLE CHECK CONFIGURATION",
+        preserve_message_color=True,
     )
diff --git a/autogpt/models/command_registry.py b/autogpt/models/command_registry.py
index 0b94499f..eb61eb59 100644
--- a/autogpt/models/command_registry.py
+++ b/autogpt/models/command_registry.py
@@ -93,9 +93,9 @@ class CommandRegistry:
         if name in self.commands_aliases:
             return self.commands_aliases[name]
 
-    def call(self, command_name: str, **kwargs) -> Any:
+    def call(self, command_name: str, agent: BaseAgent, **kwargs) -> Any:
         if command := self.get_command(command_name):
-            return command(**kwargs)
+            return command(**kwargs, agent=agent)
         raise KeyError(f"Command '{command_name}' not found in registry")
 
     def list_available_commands(self, agent: BaseAgent) -> Iterator[Command]:
diff --git a/tests/conftest.py b/tests/conftest.py
index c4dfb347..21dc900c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -57,7 +57,6 @@ def config(
     config.plugins_dir = "tests/unit/data/test_plugins"
     config.plugins_config_file = temp_plugins_config_file
 
-    # HACK: this is necessary to ensure PLAIN_OUTPUT takes effect
     config.plain_output = True
 
     configure_logging(config, Path(__file__).parent / "logs")
@@ -95,7 +94,6 @@ def agent(config: Config) -> Agent:
     )
 
     command_registry = CommandRegistry()
-    ai_config.command_registry = command_registry
     config.memory_backend = "json_file"
     memory_json_file = get_memory(config)
     memory_json_file.clear()
diff --git a/tests/unit/test_agent.py b/tests/unit/test_agent.py
index 7e36d925..ef5ef28a 100644
--- a/tests/unit/test_agent.py
+++ b/tests/unit/test_agent.py
@@ -11,7 +11,7 @@ def test_agent_initialization(agent: Agent):
 def test_execute_command_plugin(agent: Agent):
     """Test that executing a command that came from a plugin works as expected"""
     command_name = "check_plan"
-    agent.ai_config.prompt_generator.add_command(
+    agent.prompt_generator.add_command(
         command_name,
         "Read the plan.md with the next goals to achieve",
         {},
diff --git a/tests/unit/test_ai_config.py b/tests/unit/test_ai_config.py
index e3c31d5d..6c999c2d 100644
--- a/tests/unit/test_ai_config.py
+++ b/tests/unit/test_ai_config.py
@@ -55,8 +55,6 @@ def test_ai_config_file_not_exists(workspace):
     assert ai_config.ai_role == ""
     assert ai_config.ai_goals == []
     assert ai_config.api_budget == 0.0
-    assert ai_config.prompt_generator is None
-    assert ai_config.command_registry is None
 
 
 def test_ai_config_file_is_empty(workspace):
@@ -70,5 +68,3 @@
     assert ai_config.ai_role == ""
     assert ai_config.ai_goals == []
     assert ai_config.api_budget == 0.0
-    assert ai_config.prompt_generator is None
-    assert ai_config.command_registry is None
diff --git a/tests/unit/test_commands.py b/tests/unit/test_commands.py
index 25867f21..0fb2869d 100644
--- a/tests/unit/test_commands.py
+++ b/tests/unit/test_commands.py
@@ -1,10 +1,16 @@
+from __future__ import annotations
+
 import os
 import shutil
 import sys
 from pathlib import Path
+from typing import TYPE_CHECKING
 
 import pytest
 
+if TYPE_CHECKING:
+    from autogpt.agents import Agent, BaseAgent
+
 from autogpt.models.command import Command, CommandParameter
 from autogpt.models.command_registry import CommandRegistry
 
@@ -14,7 +20,7 @@ PARAMETERS = [
 ]
 
 
-def example_command_method(arg1: int, arg2: str) -> str:
+def example_command_method(arg1: int, arg2: str, agent: BaseAgent) -> str:
     """Example function for testing the Command class."""
     # This function is static because it is not used by any other test cases.
     return f"{arg1} - {arg2}"
@@ -47,16 +53,16 @@ def example_command():
     )
 
 
-def test_command_call(example_command: Command):
+def test_command_call(example_command: Command, agent: Agent):
     """Test that Command(*args) calls and returns the result of method(*args)."""
-    result = example_command(arg1=1, arg2="test")
+    result = example_command(arg1=1, arg2="test", agent=agent)
     assert result == "1 - test"
 
 
-def test_command_call_with_invalid_arguments(example_command: Command):
+def test_command_call_with_invalid_arguments(example_command: Command, agent: Agent):
     """Test that calling a Command object with invalid arguments raises a TypeError."""
     with pytest.raises(TypeError):
-        example_command(arg1="invalid", does_not_exist="test")
+        example_command(arg1="invalid", does_not_exist="test", agent=agent)
 
 
 def test_register_command(example_command: Command):
@@ -148,7 +154,7 @@ def test_get_nonexistent_command():
     assert "nonexistent_command" not in registry
 
 
-def test_call_command():
+def test_call_command(agent: Agent):
     """Test that a command can be called through the registry."""
     registry = CommandRegistry()
     cmd = Command(
@@ -159,17 +165,17 @@
     )
     registry.register(cmd)
 
-    result = registry.call("example", arg1=1, arg2="test")
+    result = registry.call("example", arg1=1, arg2="test", agent=agent)
 
     assert result == "1 - test"
 
 
-def test_call_nonexistent_command():
+def test_call_nonexistent_command(agent: Agent):
     """Test that attempting to call a nonexistent command raises a KeyError."""
     registry = CommandRegistry()
 
     with pytest.raises(KeyError):
-        registry.call("nonexistent_command", arg1=1, arg2="test")
+        registry.call("nonexistent_command", arg1=1, arg2="test", agent=agent)
 
 
 def test_import_mock_commands_module():
diff --git a/tests/unit/test_git_commands.py b/tests/unit/test_git_commands.py
index 072c56f3..81395868 100644
--- a/tests/unit/test_git_commands.py
+++ b/tests/unit/test_git_commands.py
@@ -18,7 +18,7 @@ def test_clone_auto_gpt_repository(workspace, mock_clone_from, agent: Agent):
     repo = "github.com/Significant-Gravitas/Auto-GPT.git"
     scheme = "https://"
     url = scheme + repo
-    clone_path = str(workspace.get_path("auto-gpt-repo"))
+    clone_path = workspace.get_path("auto-gpt-repo")
 
     expected_output = f"Cloned {url} to {clone_path}"
 
@@ -33,7 +33,7 @@
 
 def test_clone_repository_error(workspace, mock_clone_from, agent: Agent):
     url = "https://github.com/this-repository/does-not-exist.git"
-    clone_path = str(workspace.get_path("does-not-exist"))
+    clone_path = workspace.get_path("does-not-exist")
 
     mock_clone_from.side_effect = GitCommandError(
         "clone", "fatal: repository not found", ""
diff --git a/tests/unit/test_retry_provider_openai.py b/tests/unit/test_retry_provider_openai.py
index 1b23f5d2..f626807c 100644
--- a/tests/unit/test_retry_provider_openai.py
+++ b/tests/unit/test_retry_provider_openai.py
@@ -31,7 +31,7 @@ def error_factory(error_instance, error_count, retry_count, warn_user=True):
     return RaisesError()
 
 
-def test_retry_open_api_no_error(capsys):
+def test_retry_open_api_no_error(caplog: pytest.LogCaptureFixture):
     """Tests the retry functionality with no errors expected"""
 
     @openai.retry_api()
@@ -41,9 +41,9 @@ def f():
     result = f()
     assert result == 1
 
-    output = capsys.readouterr()
-    assert output.out == ""
-    assert output.err == ""
+    output = caplog.text
+    assert output == ""
+    assert output == ""
 
 
 @pytest.mark.parametrize(
@@ -51,7 +51,9 @@
     "error, error_count, retry_count, failure",
     [(2, 10, False), (2, 2, False), (10, 2, True), (3, 2, True), (1, 0, True)],
     ids=["passing", "passing_edge", "failing", "failing_edge", "failing_no_retries"],
 )
-def test_retry_open_api_passing(capsys, error, error_count, retry_count, failure):
+def test_retry_open_api_passing(
+    caplog: pytest.LogCaptureFixture, error, error_count, retry_count, failure
+):
     """Tests the retry with simulated errors [RateLimitError, ServiceUnavailableError, APIError], but should ulimately pass"""
     call_count = min(error_count, retry_count) + 1
 
@@ -65,20 +67,20 @@
 
     assert raises.count == call_count
 
-    output = capsys.readouterr()
+    output = caplog.text
     if error_count and retry_count:
         if type(error) == RateLimitError:
-            assert "Reached rate limit" in output.out
-            assert "Please double check" in output.out
+            assert "Reached rate limit" in output
+            assert "Please double check" in output
         if type(error) == ServiceUnavailableError:
-            assert "The OpenAI API engine is currently overloaded" in output.out
-            assert "Please double check" in output.out
+            assert "The OpenAI API engine is currently overloaded" in output
+            assert "Please double check" in output
     else:
-        assert output.out == ""
+        assert output == ""
 
 
-def test_retry_open_api_rate_limit_no_warn(capsys):
+def test_retry_open_api_rate_limit_no_warn(caplog: pytest.LogCaptureFixture):
     """Tests the retry logic with a rate limit error"""
     error_count = 2
     retry_count = 10
@@ -89,13 +91,13 @@
     assert result == call_count
     assert raises.count == call_count
 
-    output = capsys.readouterr()
+    output = caplog.text
 
-    assert "Reached rate limit" in output.out
-    assert "Please double check" not in output.out
+    assert "Reached rate limit" in output
+    assert "Please double check" not in output
 
 
-def test_retry_open_api_service_unavairable_no_warn(capsys):
+def test_retry_open_api_service_unavairable_no_warn(caplog: pytest.LogCaptureFixture):
     """Tests the retry logic with a service unavairable error"""
     error_count = 2
     retry_count = 10
@@ -108,13 +110,13 @@
     assert result == call_count
     assert raises.count == call_count
 
-    output = capsys.readouterr()
+    output = caplog.text
 
-    assert "The OpenAI API engine is currently overloaded" in output.out
-    assert "Please double check" not in output.out
+    assert "The OpenAI API engine is currently overloaded" in output
+    assert "Please double check" not in output
 
 
-def test_retry_openapi_other_api_error(capsys):
+def test_retry_openapi_other_api_error(caplog: pytest.LogCaptureFixture):
     """Tests the Retry logic with a non rate limit error such as HTTP500"""
     error_count = 2
     retry_count = 10
@@ -126,5 +128,5 @@
     call_count = 1
 
     assert raises.count == call_count
-    output = capsys.readouterr()
-    assert output.out == ""
+    output = caplog.text
+    assert output == ""
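
Note (not part of the patch): the configure_logging() changes send records below WARNING to stdout and WARNING-and-above to stderr via the new BelowLevelFilter, and the retry warnings now flow through logging rather than print(), which is why the tests above switch from pytest's capsys to caplog. The following is a minimal standalone sketch of that handler split; it re-implements the filter locally instead of importing autogpt.logs.filters, and is illustrative only.

import logging
import sys


class BelowLevelFilter(logging.Filter):
    """Only pass records below a given level (mirrors autogpt/logs/filters.py)."""

    def __init__(self, below_level: int):
        super().__init__()
        self.below_level = below_level

    def filter(self, record: logging.LogRecord) -> bool:
        return record.levelno < self.below_level


# DEBUG/INFO go to stdout; WARNING and above go to stderr,
# so no record is ever emitted on both streams.
stdout_handler = logging.StreamHandler(stream=sys.stdout)
stdout_handler.setLevel(logging.DEBUG)
stdout_handler.addFilter(BelowLevelFilter(logging.WARNING))

stderr_handler = logging.StreamHandler()  # defaults to sys.stderr
stderr_handler.setLevel(logging.WARNING)

logging.basicConfig(level=logging.DEBUG, handlers=[stdout_handler, stderr_handler])

logging.getLogger(__name__).info("shows up on stdout")
logging.getLogger(__name__).warning("shows up on stderr")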