mirror of
https://github.com/aljazceru/Auto-GPT.git
synced 2026-02-21 14:14:40 +01:00
Retry ServiceUnavailableError (#4789)
Co-authored-by: merwanehamadi <merwanehamadi@gmail.com>
This commit is contained in:
@@ -9,7 +9,7 @@ from unittest.mock import patch
|
||||
import openai
|
||||
import openai.api_resources.abstract.engine_api_resource as engine_api_resource
|
||||
from colorama import Fore, Style
|
||||
from openai.error import APIError, RateLimitError, Timeout
|
||||
from openai.error import APIError, RateLimitError, ServiceUnavailableError, Timeout
|
||||
from openai.openai_object import OpenAIObject
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -163,7 +163,10 @@ def retry_api(
|
||||
backoff_base float: Base for exponential backoff. Defaults to 2.
|
||||
warn_user bool: Whether to warn the user. Defaults to True.
|
||||
"""
|
||||
retry_limit_msg = f"{Fore.RED}Error: " f"Reached rate limit, passing...{Fore.RESET}"
|
||||
error_messages = {
|
||||
ServiceUnavailableError: f"{Fore.RED}Error: The OpenAI API engine is currently overloaded, passing...{Fore.RESET}",
|
||||
RateLimitError: f"{Fore.RED}Error: Reached rate limit, passing...{Fore.RESET}",
|
||||
}
|
||||
api_key_error_msg = (
|
||||
f"Please double check that you have setup a "
|
||||
f"{Fore.CYAN + Style.BRIGHT}PAID{Style.RESET_ALL} OpenAI API Account. You can "
|
||||
@@ -182,19 +185,18 @@ def retry_api(
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
except RateLimitError:
|
||||
except (RateLimitError, ServiceUnavailableError) as e:
|
||||
if attempt == num_attempts:
|
||||
raise
|
||||
|
||||
logger.debug(retry_limit_msg)
|
||||
error_msg = error_messages[type(e)]
|
||||
logger.debug(error_msg)
|
||||
if not user_warned:
|
||||
logger.double_check(api_key_error_msg)
|
||||
user_warned = True
|
||||
|
||||
except (APIError, Timeout) as e:
|
||||
if (e.http_status not in [429, 502, 503]) or (
|
||||
attempt == num_attempts
|
||||
):
|
||||
if (e.http_status not in [429, 502]) or (attempt == num_attempts):
|
||||
raise
|
||||
|
||||
backoff = backoff_base ** (attempt + 2)
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import pytest
|
||||
from openai.error import APIError, RateLimitError
|
||||
from openai.error import APIError, RateLimitError, ServiceUnavailableError
|
||||
|
||||
from autogpt.llm.providers import openai
|
||||
|
||||
|
||||
@pytest.fixture(params=[RateLimitError, APIError])
|
||||
@pytest.fixture(params=[RateLimitError, ServiceUnavailableError, APIError])
|
||||
def error(request):
|
||||
if request.param == APIError:
|
||||
return request.param("Error", http_status=502)
|
||||
@@ -52,7 +52,7 @@ def test_retry_open_api_no_error(capsys):
|
||||
ids=["passing", "passing_edge", "failing", "failing_edge", "failing_no_retries"],
|
||||
)
|
||||
def test_retry_open_api_passing(capsys, error, error_count, retry_count, failure):
|
||||
"""Tests the retry with simulated errors [RateLimitError, APIError], but should ulimately pass"""
|
||||
"""Tests the retry with simulated errors [RateLimitError, ServiceUnavailableError, APIError], but should ulimately pass"""
|
||||
call_count = min(error_count, retry_count) + 1
|
||||
|
||||
raises = error_factory(error, error_count, retry_count)
|
||||
@@ -71,6 +71,12 @@ def test_retry_open_api_passing(capsys, error, error_count, retry_count, failure
|
||||
if type(error) == RateLimitError:
|
||||
assert "Reached rate limit, passing..." in output.out
|
||||
assert "Please double check" in output.out
|
||||
if type(error) == ServiceUnavailableError:
|
||||
assert (
|
||||
"The OpenAI API engine is currently overloaded, passing..."
|
||||
in output.out
|
||||
)
|
||||
assert "Please double check" in output.out
|
||||
if type(error) == APIError:
|
||||
assert "API Bad gateway" in output.out
|
||||
else:
|
||||
@@ -94,6 +100,25 @@ def test_retry_open_api_rate_limit_no_warn(capsys):
|
||||
assert "Please double check" not in output.out
|
||||
|
||||
|
||||
def test_retry_open_api_service_unavairable_no_warn(capsys):
|
||||
"""Tests the retry logic with a service unavairable error"""
|
||||
error_count = 2
|
||||
retry_count = 10
|
||||
|
||||
raises = error_factory(
|
||||
ServiceUnavailableError, error_count, retry_count, warn_user=False
|
||||
)
|
||||
result = raises()
|
||||
call_count = min(error_count, retry_count) + 1
|
||||
assert result == call_count
|
||||
assert raises.count == call_count
|
||||
|
||||
output = capsys.readouterr()
|
||||
|
||||
assert "The OpenAI API engine is currently overloaded, passing..." in output.out
|
||||
assert "Please double check" not in output.out
|
||||
|
||||
|
||||
def test_retry_openapi_other_api_error(capsys):
|
||||
"""Tests the Retry logic with a non rate limit error such as HTTP500"""
|
||||
error_count = 2
|
||||
|
||||
Reference in New Issue
Block a user