From b1570543c872c342f23d15f574c75b3fbcdaed59 Mon Sep 17 00:00:00 2001 From: uta <122957026+uta0x89@users.noreply.github.com> Date: Tue, 27 Jun 2023 06:38:16 +0900 Subject: [PATCH] Retry ServiceUnavailableError (#4789) Co-authored-by: merwanehamadi --- autogpt/llm/providers/openai.py | 16 ++++++------ tests/unit/test_retry_provider_openai.py | 31 +++++++++++++++++++++--- 2 files changed, 37 insertions(+), 10 deletions(-) diff --git a/autogpt/llm/providers/openai.py b/autogpt/llm/providers/openai.py index 3c16f5cf..397b4791 100644 --- a/autogpt/llm/providers/openai.py +++ b/autogpt/llm/providers/openai.py @@ -9,7 +9,7 @@ from unittest.mock import patch import openai import openai.api_resources.abstract.engine_api_resource as engine_api_resource from colorama import Fore, Style -from openai.error import APIError, RateLimitError, Timeout +from openai.error import APIError, RateLimitError, ServiceUnavailableError, Timeout from openai.openai_object import OpenAIObject if TYPE_CHECKING: @@ -163,7 +163,10 @@ def retry_api( backoff_base float: Base for exponential backoff. Defaults to 2. warn_user bool: Whether to warn the user. Defaults to True. """ - retry_limit_msg = f"{Fore.RED}Error: " f"Reached rate limit, passing...{Fore.RESET}" + error_messages = { + ServiceUnavailableError: f"{Fore.RED}Error: The OpenAI API engine is currently overloaded, passing...{Fore.RESET}", + RateLimitError: f"{Fore.RED}Error: Reached rate limit, passing...{Fore.RESET}", + } api_key_error_msg = ( f"Please double check that you have setup a " f"{Fore.CYAN + Style.BRIGHT}PAID{Style.RESET_ALL} OpenAI API Account. You can " @@ -182,19 +185,18 @@ def retry_api( try: return func(*args, **kwargs) - except RateLimitError: + except (RateLimitError, ServiceUnavailableError) as e: if attempt == num_attempts: raise - logger.debug(retry_limit_msg) + error_msg = error_messages[type(e)] + logger.debug(error_msg) if not user_warned: logger.double_check(api_key_error_msg) user_warned = True except (APIError, Timeout) as e: - if (e.http_status not in [429, 502, 503]) or ( - attempt == num_attempts - ): + if (e.http_status not in [429, 502]) or (attempt == num_attempts): raise backoff = backoff_base ** (attempt + 2) diff --git a/tests/unit/test_retry_provider_openai.py b/tests/unit/test_retry_provider_openai.py index f8162eb8..b2c2d04a 100644 --- a/tests/unit/test_retry_provider_openai.py +++ b/tests/unit/test_retry_provider_openai.py @@ -1,10 +1,10 @@ import pytest -from openai.error import APIError, RateLimitError +from openai.error import APIError, RateLimitError, ServiceUnavailableError from autogpt.llm.providers import openai -@pytest.fixture(params=[RateLimitError, APIError]) +@pytest.fixture(params=[RateLimitError, ServiceUnavailableError, APIError]) def error(request): if request.param == APIError: return request.param("Error", http_status=502) @@ -52,7 +52,7 @@ def test_retry_open_api_no_error(capsys): ids=["passing", "passing_edge", "failing", "failing_edge", "failing_no_retries"], ) def test_retry_open_api_passing(capsys, error, error_count, retry_count, failure): - """Tests the retry with simulated errors [RateLimitError, APIError], but should ulimately pass""" + """Tests the retry with simulated errors [RateLimitError, ServiceUnavailableError, APIError], but should ulimately pass""" call_count = min(error_count, retry_count) + 1 raises = error_factory(error, error_count, retry_count) @@ -71,6 +71,12 @@ def test_retry_open_api_passing(capsys, error, error_count, retry_count, failure if type(error) == RateLimitError: assert "Reached rate limit, passing..." in output.out assert "Please double check" in output.out + if type(error) == ServiceUnavailableError: + assert ( + "The OpenAI API engine is currently overloaded, passing..." + in output.out + ) + assert "Please double check" in output.out if type(error) == APIError: assert "API Bad gateway" in output.out else: @@ -94,6 +100,25 @@ def test_retry_open_api_rate_limit_no_warn(capsys): assert "Please double check" not in output.out +def test_retry_open_api_service_unavairable_no_warn(capsys): + """Tests the retry logic with a service unavairable error""" + error_count = 2 + retry_count = 10 + + raises = error_factory( + ServiceUnavailableError, error_count, retry_count, warn_user=False + ) + result = raises() + call_count = min(error_count, retry_count) + 1 + assert result == call_count + assert raises.count == call_count + + output = capsys.readouterr() + + assert "The OpenAI API engine is currently overloaded, passing..." in output.out + assert "Please double check" not in output.out + + def test_retry_openapi_other_api_error(capsys): """Tests the Retry logic with a non rate limit error such as HTTP500""" error_count = 2