Files
Auto-GPT/tests/unit/test_retry_provider_openai.py
James Collins 6e6e7fcc9a Extract openai API calls and retry at lowest level (#3696)
* Extract open ai api calls and retry at lowest level

* Forgot a test

* Gotta fix my local docker config so I can let pre-commit hooks run, ugh

* fix: merge artifact

* Fix linting

* Update memory.vector.utils

* feat: make sure resp exists

* fix: raise error message if created

* feat: rename file

* fix: partial test fix

* fix: update comments

* fix: linting

* fix: remove broken test

* fix: require a model to exist

* fix: BaseError issue

* fix: runtime error

* Fix mock response in test_make_agent

* add 429 as errors to retry

---------

Co-authored-by: k-boikov <64261260+k-boikov@users.noreply.github.com>
Co-authored-by: Nicholas Tindle <nick@ntindle.com>
Co-authored-by: Reinier van der Leer <github@pwuts.nl>
Co-authored-by: Nicholas Tindle <nicktindle@outlook.com>
Co-authored-by: Luke K (pr-0f3t) <2609441+lc0rp@users.noreply.github.com>
Co-authored-by: Merwane Hamadi <merwanehamadi@gmail.com>
2023-06-14 07:59:26 -07:00

111 lines
3.1 KiB
Python

import pytest
from openai.error import APIError, RateLimitError
from autogpt.llm.providers import openai
@pytest.fixture(params=[RateLimitError, APIError])
def error(request):
    """Provide one instance of each error class the retry tests exercise.

    The fixture is parametrized over ``RateLimitError`` and ``APIError``, so
    any test requesting it runs once per class. ``APIError`` is built with
    ``http_status=502``; ``RateLimitError`` receives only the message.
    """
    error_cls = request.param
    if error_cls is not APIError:
        return error_cls("Error")
    # NOTE(review): 502 is presumably a status the retry decorator treats as
    # retryable -- confirm against retry_api.
    return error_cls("Error", http_status=502)
def error_factory(error_instance, error_count, retry_count, warn_user=True):
    """Build a callable that raises a set number of times, then succeeds.

    Args:
        error_instance: The exception raised on each failing call.
        error_count: How many initial calls raise before a call returns.
        retry_count: Forwarded to ``openai.retry_api`` as ``num_retries``.
        warn_user: Forwarded to ``openai.retry_api`` as ``warn_user``.

    Returns:
        An object whose ``__call__`` is wrapped by ``openai.retry_api`` and
        whose ``count`` attribute records how many times the undecorated
        body has run. A successful call returns that count.
    """
    # A very small backoff base; presumably keeps retry delays negligible so
    # the tests run fast.
    with_retries = openai.retry_api(
        num_retries=retry_count, backoff_base=0.001, warn_user=warn_user
    )

    class RaisesError:
        def __init__(self):
            self.count = 0

        @with_retries
        def __call__(self):
            self.count += 1
            if self.count > error_count:
                return self.count
            raise error_instance

    return RaisesError()
def test_retry_open_api_no_error(capsys):
    """Checks that a decorated function which never raises returns its value
    and writes nothing to stdout or stderr."""

    @openai.retry_api()
    def always_succeeds():
        return 1

    assert always_succeeds() == 1

    captured = capsys.readouterr()
    assert captured.out == ""
    assert captured.err == ""
@pytest.mark.parametrize(
    "error_count, retry_count, failure",
    [(2, 10, False), (2, 2, False), (10, 2, True), (3, 2, True), (1, 0, True)],
    ids=["passing", "passing_edge", "failing", "failing_edge", "failing_no_retries"],
)
def test_retry_open_api_passing(capsys, error, error_count, retry_count, failure):
    """Tests the retry with simulated errors [RateLimitError, APIError].

    Each case makes the wrapped callable raise ``error_count`` times while
    allowing ``retry_count`` retries. When ``failure`` is False the call is
    expected to ultimately succeed; when it is True the error is expected to
    propagate once the retries are used up.
    """
    # One initial attempt plus one further call per retry actually used.
    call_count = min(error_count, retry_count) + 1

    raises = error_factory(error, error_count, retry_count)
    if failure:
        with pytest.raises(type(error)):
            raises()
    else:
        result = raises()
        assert result == call_count

    assert raises.count == call_count

    output = capsys.readouterr()

    if error_count and retry_count:
        # At least one retry happened, so a message must have been printed.
        if isinstance(error, RateLimitError):
            assert "Reached rate limit, passing..." in output.out
            assert "Please double check" in output.out
        elif isinstance(error, APIError):
            assert "API Bad gateway" in output.out
    else:
        # No retry was attempted, so nothing should have been printed.
        assert output.out == ""
def test_retry_open_api_rate_limit_no_warn(capsys):
    """Tests the retry logic with a rate limit error and ``warn_user=False``.

    The rate-limit message must still be printed on each retry, but the
    "Please double check" warning must not appear.
    """
    error_count = 2
    retry_count = 10

    # Pass an instance rather than the class, matching error_factory's
    # ``error_instance`` parameter.
    raises = error_factory(
        RateLimitError("Error"), error_count, retry_count, warn_user=False
    )
    result = raises()

    # One initial attempt plus one further call per retry actually used.
    call_count = min(error_count, retry_count) + 1
    assert result == call_count
    assert raises.count == call_count

    output = capsys.readouterr()

    assert "Reached rate limit, passing..." in output.out
    assert "Please double check" not in output.out
def test_retry_openapi_other_api_error(capsys):
    """Checks that an APIError carrying HTTP status 500 propagates on the
    first call, without any retry and without printing anything."""
    server_error = APIError("Error", http_status=500)
    raises = error_factory(server_error, error_count=2, retry_count=10)

    with pytest.raises(APIError):
        raises()

    # The wrapped body ran exactly once, i.e. no retry was attempted.
    assert raises.count == 1

    assert capsys.readouterr().out == ""