Fix broken tests (casualties from the past few days)
@@ -347,7 +347,7 @@ def execute_command(
        raise CommandExecutionError(str(e))

    # Handle non-native commands (e.g. from plugins)
-    for command in agent.ai_config.prompt_generator.commands:
+    for command in agent.prompt_generator.commands:
        if (
            command_name == command.label.lower()
            or command_name == command.name.lower()
@@ -53,7 +53,7 @@ class ContextMixin:
                0, Message("system", "# Context\n" + self.context.format_numbered())
            )

-        return super(ContextMixin, self).construct_base_prompt(*args, **kwargs)
+        return super(ContextMixin, self).construct_base_prompt(*args, **kwargs)  # type: ignore


def get_agent_context(agent: BaseAgent) -> AgentContext | None:
@@ -19,7 +19,7 @@ from autogpt.command_decorator import command
from autogpt.memory.vector import MemoryItem, VectorMemory

from .decorators import sanitize_path_arg
-from .file_context import open_file, open_folder # NOQA
+from .file_context import open_file, open_folder # NOQA
from .file_operations_utils import read_textual_file

logger = logging.getLogger(__name__)
@@ -169,15 +169,15 @@ def retry_api(
        warn_user bool: Whether to warn the user. Defaults to True.
    """
    error_messages = {
-        ServiceUnavailableError: f"{Fore.RED}Error: The OpenAI API engine is currently overloaded{Fore.RESET}",
-        RateLimitError: f"{Fore.RED}Error: Reached rate limit{Fore.RESET}",
+        ServiceUnavailableError: "The OpenAI API engine is currently overloaded",
+        RateLimitError: "Reached rate limit",
    }
    api_key_error_msg = (
        f"Please double check that you have setup a "
        f"{Fore.CYAN + Style.BRIGHT}PAID{Style.RESET_ALL} OpenAI API Account. You can "
        f"read more here: {Fore.CYAN}https://docs.agpt.co/setup/#getting-an-api-key{Fore.RESET}"
    )
-    backoff_msg = f"{Fore.RED}Waiting {{backoff}} seconds...{Fore.RESET}"
+    backoff_msg = "Waiting {backoff} seconds..."

    def _wrapper(func: Callable):
        @functools.wraps(func)
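Note: for orientation, a minimal, self-contained sketch of the retry-with-backoff pattern this hunk touches, with the now-uncolored messages logged instead of printed. retry_on, its signature, and the backoff formula are illustrative assumptions, not the actual retry_api implementation.

    import functools
    import logging
    import time

    logger = logging.getLogger(__name__)

    def retry_on(errors: tuple[type[Exception], ...], num_retries: int = 10):
        def _wrapper(func):
            @functools.wraps(func)
            def _wrapped(*args, **kwargs):
                for attempt in range(1, num_retries + 2):
                    try:
                        return func(*args, **kwargs)
                    except errors as e:
                        if attempt > num_retries:
                            raise
                        # Plain message; any coloring is left to the log formatter.
                        logger.warning(str(e))
                        backoff = 2 ** (attempt + 2)  # assumed backoff schedule
                        logger.debug(f"Waiting {backoff} seconds...")
                        time.sleep(backoff)
            return _wrapped
        return _wrapper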
@@ -12,6 +12,7 @@ from openai.util import logger as openai_logger
if TYPE_CHECKING:
    from autogpt.config import Config

+from .filters import BelowLevelFilter
from .formatters import AutoGptFormatter
from .handlers import TTSHandler, TypingConsoleHandler
@@ -42,10 +43,14 @@ def configure_logging(config: Config, log_dir: Path = LOG_DIR) -> None:
    log_format = DEBUG_LOG_FORMAT if config.debug_mode else SIMPLE_LOG_FORMAT
    console_formatter = AutoGptFormatter(log_format)

-    # Console output handler
-    console_handler = logging.StreamHandler(stream=sys.stdout)
-    console_handler.setLevel(log_level)
-    console_handler.setFormatter(console_formatter)
+    # Console output handlers
+    stdout = logging.StreamHandler(stream=sys.stdout)
+    stdout.setLevel(log_level)
+    stdout.addFilter(BelowLevelFilter(logging.WARNING))
+    stdout.setFormatter(console_formatter)
+    stderr = logging.StreamHandler()
+    stderr.setLevel(logging.WARNING)
+    stderr.setFormatter(console_formatter)

    # INFO log file handler
    activity_log_handler = logging.FileHandler(log_dir / LOG_FILE, "a", "utf-8")
@@ -68,7 +73,7 @@ def configure_logging(config: Config, log_dir: Path = LOG_DIR) -> None:
        format=log_format,
        level=log_level,
        handlers=(
-            [console_handler, activity_log_handler, error_log_handler]
+            [stdout, stderr, activity_log_handler, error_log_handler]
            + ([debug_log_handler] if config.debug_mode else [])
        ),
    )
@@ -81,13 +86,14 @@ def configure_logging(config: Config, log_dir: Path = LOG_DIR) -> None:
        typing_console_handler.setFormatter(console_formatter)

    user_friendly_output_logger = logging.getLogger(USER_FRIENDLY_OUTPUT_LOGGER)
+    user_friendly_output_logger.setLevel(logging.INFO)
    user_friendly_output_logger.addHandler(
-        typing_console_handler if not config.plain_output else console_handler
+        typing_console_handler if not config.plain_output else stdout
    )
    user_friendly_output_logger.addHandler(TTSHandler(config))
    user_friendly_output_logger.addHandler(activity_log_handler)
    user_friendly_output_logger.addHandler(error_log_handler)
-    user_friendly_output_logger.setLevel(logging.INFO)
+    user_friendly_output_logger.addHandler(stderr)
    user_friendly_output_logger.propagate = False

    # JSON logger with better formatting
autogpt/logs/filters.py (new file, 12 lines)
@@ -0,0 +1,12 @@
+import logging
+
+
+class BelowLevelFilter(logging.Filter):
+    """Filter for logging levels below a certain threshold."""
+
+    def __init__(self, below_level: int):
+        super().__init__()
+        self.below_level = below_level
+
+    def filter(self, record: logging.LogRecord):
+        return record.levelno < self.below_level
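Note: a self-contained sketch of how this filter is used by the configure_logging change above to split console output between stdout and stderr; the logger name and messages are illustrative.

    import logging
    import sys

    from autogpt.logs.filters import BelowLevelFilter

    logger = logging.getLogger("demo")
    logger.setLevel(logging.DEBUG)

    # Records below WARNING go to stdout ...
    stdout = logging.StreamHandler(stream=sys.stdout)
    stdout.addFilter(BelowLevelFilter(logging.WARNING))
    logger.addHandler(stdout)

    # ... while WARNING and above go to stderr.
    stderr = logging.StreamHandler()  # defaults to sys.stderr
    stderr.setLevel(logging.WARNING)
    logger.addHandler(stderr)

    logger.info("appears on stdout only")
    logger.error("appears on stderr only")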
@@ -11,6 +11,7 @@ def user_friendly_output(
    level: int = logging.INFO,
    title: str = "",
    title_color: str = "",
+    preserve_message_color: bool = False,
) -> None:
    """Outputs a message to the user in a user-friendly way.

@@ -24,7 +25,15 @@ def user_friendly_output(
        for plugin in _chat_plugins:
            plugin.report(f"{title}: {message}")

-    logger.log(level, message, extra={"title": title, "title_color": title_color})
+    logger.log(
+        level,
+        message,
+        extra={
+            "title": title,
+            "title_color": title_color,
+            "preserve_color": preserve_message_color,
+        },
+    )


def print_attribute(
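Note: this hunk only changes what is passed via extra=; values passed that way become attributes on the LogRecord, which a formatter can read back. The TitledFormatter below is an illustrative stand-in, not AutoGptFormatter, whose actual handling of title, title_color, and preserve_color is not shown in this diff.

    import logging

    class TitledFormatter(logging.Formatter):
        """Illustrative formatter: prefixes the message with the 'title' extra, if present."""

        def format(self, record: logging.LogRecord) -> str:
            title = getattr(record, "title", "")        # set via extra={"title": ...}
            msg = super().format(record)                # color handling omitted
            return f"{title}: {msg}" if title else msg

    logger = logging.getLogger("demo")
    logger.setLevel(logging.INFO)
    handler = logging.StreamHandler()
    handler.setFormatter(TitledFormatter("%(message)s"))
    logger.addHandler(handler)
    logger.info("all systems go", extra={"title": "STATUS", "title_color": ""})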
@@ -51,5 +60,8 @@ def request_user_double_check(additionalText: Optional[str] = None) -> None:
        )

    user_friendly_output(
-        additionalText, level=logging.WARN, title="DOUBLE CHECK CONFIGURATION"
+        additionalText,
+        level=logging.WARN,
+        title="DOUBLE CHECK CONFIGURATION",
+        preserve_message_color=True,
    )
@@ -93,9 +93,9 @@ class CommandRegistry:
        if name in self.commands_aliases:
            return self.commands_aliases[name]

-    def call(self, command_name: str, **kwargs) -> Any:
+    def call(self, command_name: str, agent: BaseAgent, **kwargs) -> Any:
        if command := self.get_command(command_name):
-            return command(**kwargs)
+            return command(**kwargs, agent=agent)
        raise KeyError(f"Command '{command_name}' not found in registry")

    def list_available_commands(self, agent: BaseAgent) -> Iterator[Command]:
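Note: a reduced, standalone sketch of the dispatch behavior after this change: the registry forwards the agent to every command it calls, so command functions receive it as a keyword argument. SimpleRegistry, the greet command, and the fake agent are made up for illustration.

    from typing import Any, Callable

    class SimpleRegistry:
        """Illustrative stand-in for CommandRegistry.call's dispatch behavior."""

        def __init__(self) -> None:
            self.commands: dict[str, Callable[..., Any]] = {}

        def call(self, command_name: str, agent: Any, **kwargs) -> Any:
            if command := self.commands.get(command_name):
                return command(**kwargs, agent=agent)  # agent is injected for every call
            raise KeyError(f"Command '{command_name}' not found in registry")

    def greet(name: str, agent: Any) -> str:
        return f"{agent!r} says hello to {name}"

    registry = SimpleRegistry()
    registry.commands["greet"] = greet
    print(registry.call("greet", agent="fake-agent", name="world"))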
@@ -57,7 +57,6 @@ def config(
    config.plugins_dir = "tests/unit/data/test_plugins"
    config.plugins_config_file = temp_plugins_config_file

    # HACK: this is necessary to ensure PLAIN_OUTPUT takes effect
    config.plain_output = True
    configure_logging(config, Path(__file__).parent / "logs")
@@ -95,7 +94,6 @@ def agent(config: Config) -> Agent:
    )

    command_registry = CommandRegistry()
-    ai_config.command_registry = command_registry
    config.memory_backend = "json_file"
    memory_json_file = get_memory(config)
    memory_json_file.clear()
@@ -11,7 +11,7 @@ def test_agent_initialization(agent: Agent):
def test_execute_command_plugin(agent: Agent):
    """Test that executing a command that came from a plugin works as expected"""
    command_name = "check_plan"
-    agent.ai_config.prompt_generator.add_command(
+    agent.prompt_generator.add_command(
        command_name,
        "Read the plan.md with the next goals to achieve",
        {},
@@ -55,8 +55,6 @@ def test_ai_config_file_not_exists(workspace):
    assert ai_config.ai_role == ""
    assert ai_config.ai_goals == []
    assert ai_config.api_budget == 0.0
-    assert ai_config.prompt_generator is None
-    assert ai_config.command_registry is None


def test_ai_config_file_is_empty(workspace):
@@ -70,5 +68,3 @@ def test_ai_config_file_is_empty(workspace):
    assert ai_config.ai_role == ""
    assert ai_config.ai_goals == []
    assert ai_config.api_budget == 0.0
-    assert ai_config.prompt_generator is None
-    assert ai_config.command_registry is None
@@ -1,10 +1,16 @@
+from __future__ import annotations
+
import os
import shutil
import sys
from pathlib import Path
+from typing import TYPE_CHECKING

import pytest

+if TYPE_CHECKING:
+    from autogpt.agents import Agent, BaseAgent
+
from autogpt.models.command import Command, CommandParameter
from autogpt.models.command_registry import CommandRegistry

@@ -14,7 +20,7 @@ PARAMETERS = [
]


-def example_command_method(arg1: int, arg2: str) -> str:
+def example_command_method(arg1: int, arg2: str, agent: BaseAgent) -> str:
    """Example function for testing the Command class."""
    # This function is static because it is not used by any other test cases.
    return f"{arg1} - {arg2}"
@@ -47,16 +53,16 @@ def example_command():
    )


-def test_command_call(example_command: Command):
+def test_command_call(example_command: Command, agent: Agent):
    """Test that Command(*args) calls and returns the result of method(*args)."""
-    result = example_command(arg1=1, arg2="test")
+    result = example_command(arg1=1, arg2="test", agent=agent)
    assert result == "1 - test"


-def test_command_call_with_invalid_arguments(example_command: Command):
+def test_command_call_with_invalid_arguments(example_command: Command, agent: Agent):
    """Test that calling a Command object with invalid arguments raises a TypeError."""
    with pytest.raises(TypeError):
-        example_command(arg1="invalid", does_not_exist="test")
+        example_command(arg1="invalid", does_not_exist="test", agent=agent)


def test_register_command(example_command: Command):
@@ -148,7 +154,7 @@ def test_get_nonexistent_command():
    assert "nonexistent_command" not in registry


-def test_call_command():
+def test_call_command(agent: Agent):
    """Test that a command can be called through the registry."""
    registry = CommandRegistry()
    cmd = Command(
@@ -159,17 +165,17 @@ def test_call_command():
    )

    registry.register(cmd)
-    result = registry.call("example", arg1=1, arg2="test")
+    result = registry.call("example", arg1=1, arg2="test", agent=agent)

    assert result == "1 - test"


-def test_call_nonexistent_command():
+def test_call_nonexistent_command(agent: Agent):
    """Test that attempting to call a nonexistent command raises a KeyError."""
    registry = CommandRegistry()

    with pytest.raises(KeyError):
-        registry.call("nonexistent_command", arg1=1, arg2="test")
+        registry.call("nonexistent_command", arg1=1, arg2="test", agent=agent)


def test_import_mock_commands_module():
@@ -18,7 +18,7 @@ def test_clone_auto_gpt_repository(workspace, mock_clone_from, agent: Agent):
    repo = "github.com/Significant-Gravitas/Auto-GPT.git"
    scheme = "https://"
    url = scheme + repo
-    clone_path = str(workspace.get_path("auto-gpt-repo"))
+    clone_path = workspace.get_path("auto-gpt-repo")

    expected_output = f"Cloned {url} to {clone_path}"

@@ -33,7 +33,7 @@ def test_clone_auto_gpt_repository(workspace, mock_clone_from, agent: Agent):

def test_clone_repository_error(workspace, mock_clone_from, agent: Agent):
    url = "https://github.com/this-repository/does-not-exist.git"
-    clone_path = str(workspace.get_path("does-not-exist"))
+    clone_path = workspace.get_path("does-not-exist")

    mock_clone_from.side_effect = GitCommandError(
        "clone", "fatal: repository not found", ""
@@ -31,7 +31,7 @@ def error_factory(error_instance, error_count, retry_count, warn_user=True):
    return RaisesError()


-def test_retry_open_api_no_error(capsys):
+def test_retry_open_api_no_error(caplog: pytest.LogCaptureFixture):
    """Tests the retry functionality with no errors expected"""

    @openai.retry_api()
@@ -41,9 +41,9 @@ def test_retry_open_api_no_error(capsys):
    result = f()
    assert result == 1

-    output = capsys.readouterr()
-    assert output.out == ""
-    assert output.err == ""
+    output = caplog.text
+    assert output == ""
+    assert output == ""


@pytest.mark.parametrize(
@@ -51,7 +51,9 @@ def test_retry_open_api_no_error(capsys):
    [(2, 10, False), (2, 2, False), (10, 2, True), (3, 2, True), (1, 0, True)],
    ids=["passing", "passing_edge", "failing", "failing_edge", "failing_no_retries"],
)
-def test_retry_open_api_passing(capsys, error, error_count, retry_count, failure):
+def test_retry_open_api_passing(
+    caplog: pytest.LogCaptureFixture, error, error_count, retry_count, failure
+):
    """Tests the retry with simulated errors [RateLimitError, ServiceUnavailableError, APIError], but should ulimately pass"""
    call_count = min(error_count, retry_count) + 1
@@ -65,20 +67,20 @@ def test_retry_open_api_passing(capsys, error, error_count, retry_count, failure

    assert raises.count == call_count

-    output = capsys.readouterr()
+    output = caplog.text

    if error_count and retry_count:
        if type(error) == RateLimitError:
-            assert "Reached rate limit" in output.out
-            assert "Please double check" in output.out
+            assert "Reached rate limit" in output
+            assert "Please double check" in output
        if type(error) == ServiceUnavailableError:
-            assert "The OpenAI API engine is currently overloaded" in output.out
-            assert "Please double check" in output.out
+            assert "The OpenAI API engine is currently overloaded" in output
+            assert "Please double check" in output
    else:
-        assert output.out == ""
+        assert output == ""


-def test_retry_open_api_rate_limit_no_warn(capsys):
+def test_retry_open_api_rate_limit_no_warn(caplog: pytest.LogCaptureFixture):
    """Tests the retry logic with a rate limit error"""
    error_count = 2
    retry_count = 10
@@ -89,13 +91,13 @@ def test_retry_open_api_rate_limit_no_warn(capsys):
    assert result == call_count
    assert raises.count == call_count

-    output = capsys.readouterr()
+    output = caplog.text

-    assert "Reached rate limit" in output.out
-    assert "Please double check" not in output.out
+    assert "Reached rate limit" in output
+    assert "Please double check" not in output


-def test_retry_open_api_service_unavairable_no_warn(capsys):
+def test_retry_open_api_service_unavairable_no_warn(caplog: pytest.LogCaptureFixture):
    """Tests the retry logic with a service unavairable error"""
    error_count = 2
    retry_count = 10
@@ -108,13 +110,13 @@ def test_retry_open_api_service_unavairable_no_warn(capsys):
    assert result == call_count
    assert raises.count == call_count

-    output = capsys.readouterr()
+    output = caplog.text

-    assert "The OpenAI API engine is currently overloaded" in output.out
-    assert "Please double check" not in output.out
+    assert "The OpenAI API engine is currently overloaded" in output
+    assert "Please double check" not in output


-def test_retry_openapi_other_api_error(capsys):
+def test_retry_openapi_other_api_error(caplog: pytest.LogCaptureFixture):
    """Tests the Retry logic with a non rate limit error such as HTTP500"""
    error_count = 2
    retry_count = 10
@@ -126,5 +128,5 @@ def test_retry_openapi_other_api_error(capsys):
    call_count = 1
    assert raises.count == call_count

-    output = capsys.readouterr()
-    assert output.out == ""
+    output = caplog.text
+    assert output == ""
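Note: these test changes swap capsys (captured stdout/stderr) for pytest's caplog fixture, which captures log records instead; since the retry warnings are now emitted through the logging module rather than printed, capsys no longer sees them. A minimal illustration of the caplog pattern, with a made-up function under test:

    import logging

    logger = logging.getLogger(__name__)

    def do_work() -> int:
        logger.warning("Reached rate limit")  # stand-in for the retry decorator's warning
        return 1

    def test_do_work_logs_warning(caplog):
        with caplog.at_level(logging.WARNING):
            assert do_work() == 1
        assert "Reached rate limit" in caplog.text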