Huggingface retry generate_image with delay (#2745)

Co-authored-by: Media <12145726+rihp@users.noreply.github.com>
Co-authored-by: Nicholas Tindle <nick@ntindle.com>
Co-authored-by: Nicholas Tindle <nicktindle@outlook.com>
Co-authored-by: k-boikov <64261260+k-boikov@users.noreply.github.com>
Co-authored-by: merwanehamadi <merwanehamadi@gmail.com>
Co-authored-by; lc0rp
This commit is contained in:
Kory Becker
2023-05-16 11:02:55 -04:00
committed by GitHub
parent c1cd54d1ea
commit f424fac1d8
2 changed files with 143 additions and 15 deletions

View File

@@ -1,5 +1,7 @@
""" Image Generation Module for AutoGPT."""
import io
import json
import time
import uuid
from base64 import b64decode
@@ -61,20 +63,42 @@ def generate_image_with_hf(prompt: str, filename: str) -> str:
"X-Use-Cache": "false",
}
response = requests.post(
API_URL,
headers=headers,
json={
"inputs": prompt,
},
)
retry_count = 0
while retry_count < 10:
response = requests.post(
API_URL,
headers=headers,
json={
"inputs": prompt,
},
)
image = Image.open(io.BytesIO(response.content))
logger.info(f"Image Generated for prompt:{prompt}")
if response.ok:
try:
image = Image.open(io.BytesIO(response.content))
logger.info(f"Image Generated for prompt:{prompt}")
image.save(filename)
return f"Saved to disk:{filename}"
except Exception as e:
logger.error(e)
break
else:
try:
error = json.loads(response.text)
if "estimated_time" in error:
delay = error["estimated_time"]
logger.debug(response.text)
logger.info("Retrying in", delay)
time.sleep(delay)
else:
break
except Exception as e:
logger.error(e)
break
image.save(filename)
retry_count += 1
return f"Saved to disk:{filename}"
return f"Error creating image."
def generate_image_with_dalle(prompt: str, filename: str, size: int) -> str:

View File

@@ -6,6 +6,7 @@ import pytest
from PIL import Image
from autogpt.commands.image_gen import generate_image, generate_image_with_sd_webui
from autogpt.config import Config
from tests.utils import requires_api_key
@@ -19,7 +20,7 @@ def image_size(request):
reason="The image is too big to be put in a cassette for a CI pipeline. We're looking into a solution."
)
@requires_api_key("OPENAI_API_KEY")
def test_dalle(config, workspace, image_size, patched_api_requestor):
def test_dalle(config, workspace, image_size):
"""Test DALL-E image generation."""
generate_and_validate(
config,
@@ -48,18 +49,18 @@ def test_huggingface(config, workspace, image_size, image_model):
)
@pytest.mark.skip(reason="External SD WebUI may not be available.")
@pytest.mark.xfail(reason="SD WebUI call does not work.")
def test_sd_webui(config, workspace, image_size):
"""Test SD WebUI image generation."""
generate_and_validate(
config,
workspace,
image_provider="sdwebui",
image_provider="sd_webui",
image_size=image_size,
)
@pytest.mark.skip(reason="External SD WebUI may not be available.")
@pytest.mark.xfail(reason="SD WebUI call does not work.")
def test_sd_webui_negative_prompt(config, workspace, image_size):
gen_image = functools.partial(
generate_image_with_sd_webui,
@@ -103,3 +104,106 @@ def generate_and_validate(
assert image_path.exists()
with Image.open(image_path) as img:
assert img.size == (image_size, image_size)
def test_huggingface_fail_request_with_delay(mocker):
config = Config()
config.huggingface_api_token = "1"
# Mock requests.post
mock_post = mocker.patch("requests.post")
mock_post.return_value.status_code = 500
mock_post.return_value.ok = False
mock_post.return_value.text = '{"error":"Model CompVis/stable-diffusion-v1-4 is currently loading","estimated_time":0}'
# Mock time.sleep
mock_sleep = mocker.patch("time.sleep")
config.image_provider = "huggingface"
config.huggingface_image_model = "CompVis/stable-diffusion-v1-4"
result = generate_image("astronaut riding a horse", 512)
assert result == "Error creating image."
# Verify retry was called with delay.
mock_sleep.assert_called_with(0)
def test_huggingface_fail_request_no_delay(mocker):
config = Config()
config.huggingface_api_token = "1"
# Mock requests.post
mock_post = mocker.patch("requests.post")
mock_post.return_value.status_code = 500
mock_post.return_value.ok = False
mock_post.return_value.text = (
'{"error":"Model CompVis/stable-diffusion-v1-4 is currently loading"}'
)
# Mock time.sleep
mock_sleep = mocker.patch("time.sleep")
config.image_provider = "huggingface"
config.huggingface_image_model = "CompVis/stable-diffusion-v1-4"
result = generate_image("astronaut riding a horse", 512)
assert result == "Error creating image."
# Verify retry was not called.
mock_sleep.assert_not_called()
def test_huggingface_fail_request_bad_json(mocker):
config = Config()
config.huggingface_api_token = "1"
# Mock requests.post
mock_post = mocker.patch("requests.post")
mock_post.return_value.status_code = 500
mock_post.return_value.ok = False
mock_post.return_value.text = '{"error:}'
# Mock time.sleep
mock_sleep = mocker.patch("time.sleep")
config.image_provider = "huggingface"
config.huggingface_image_model = "CompVis/stable-diffusion-v1-4"
result = generate_image("astronaut riding a horse", 512)
assert result == "Error creating image."
# Verify retry was not called.
mock_sleep.assert_not_called()
def test_huggingface_fail_request_bad_image(mocker):
config = Config()
config.huggingface_api_token = "1"
# Mock requests.post
mock_post = mocker.patch("requests.post")
mock_post.return_value.status_code = 200
config.image_provider = "huggingface"
config.huggingface_image_model = "CompVis/stable-diffusion-v1-4"
result = generate_image("astronaut riding a horse", 512)
assert result == "Error creating image."
def test_huggingface_fail_missing_api_token(mocker):
config = Config()
config.image_provider = "huggingface"
config.huggingface_image_model = "CompVis/stable-diffusion-v1-4"
# Mock requests.post to raise ValueError
mock_post = mocker.patch("requests.post", side_effect=ValueError)
# Verify request raises an error.
with pytest.raises(ValueError):
generate_image("astronaut riding a horse", 512)