mirror of
https://github.com/aljazceru/Auto-GPT.git
synced 2026-02-07 07:14:33 +01:00
Huggingface retry generate_image with delay (#2745)
Co-authored-by: Media <12145726+rihp@users.noreply.github.com> Co-authored-by: Nicholas Tindle <nick@ntindle.com> Co-authored-by: Nicholas Tindle <nicktindle@outlook.com> Co-authored-by: k-boikov <64261260+k-boikov@users.noreply.github.com> Co-authored-by: merwanehamadi <merwanehamadi@gmail.com> Co-authored-by; lc0rp
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
""" Image Generation Module for AutoGPT."""
|
||||
import io
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from base64 import b64decode
|
||||
|
||||
@@ -61,20 +63,42 @@ def generate_image_with_hf(prompt: str, filename: str) -> str:
|
||||
"X-Use-Cache": "false",
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
API_URL,
|
||||
headers=headers,
|
||||
json={
|
||||
"inputs": prompt,
|
||||
},
|
||||
)
|
||||
retry_count = 0
|
||||
while retry_count < 10:
|
||||
response = requests.post(
|
||||
API_URL,
|
||||
headers=headers,
|
||||
json={
|
||||
"inputs": prompt,
|
||||
},
|
||||
)
|
||||
|
||||
image = Image.open(io.BytesIO(response.content))
|
||||
logger.info(f"Image Generated for prompt:{prompt}")
|
||||
if response.ok:
|
||||
try:
|
||||
image = Image.open(io.BytesIO(response.content))
|
||||
logger.info(f"Image Generated for prompt:{prompt}")
|
||||
image.save(filename)
|
||||
return f"Saved to disk:{filename}"
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
break
|
||||
else:
|
||||
try:
|
||||
error = json.loads(response.text)
|
||||
if "estimated_time" in error:
|
||||
delay = error["estimated_time"]
|
||||
logger.debug(response.text)
|
||||
logger.info("Retrying in", delay)
|
||||
time.sleep(delay)
|
||||
else:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
break
|
||||
|
||||
image.save(filename)
|
||||
retry_count += 1
|
||||
|
||||
return f"Saved to disk:{filename}"
|
||||
return f"Error creating image."
|
||||
|
||||
|
||||
def generate_image_with_dalle(prompt: str, filename: str, size: int) -> str:
|
||||
|
||||
@@ -6,6 +6,7 @@ import pytest
|
||||
from PIL import Image
|
||||
|
||||
from autogpt.commands.image_gen import generate_image, generate_image_with_sd_webui
|
||||
from autogpt.config import Config
|
||||
from tests.utils import requires_api_key
|
||||
|
||||
|
||||
@@ -19,7 +20,7 @@ def image_size(request):
|
||||
reason="The image is too big to be put in a cassette for a CI pipeline. We're looking into a solution."
|
||||
)
|
||||
@requires_api_key("OPENAI_API_KEY")
|
||||
def test_dalle(config, workspace, image_size, patched_api_requestor):
|
||||
def test_dalle(config, workspace, image_size):
|
||||
"""Test DALL-E image generation."""
|
||||
generate_and_validate(
|
||||
config,
|
||||
@@ -48,18 +49,18 @@ def test_huggingface(config, workspace, image_size, image_model):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="External SD WebUI may not be available.")
|
||||
@pytest.mark.xfail(reason="SD WebUI call does not work.")
|
||||
def test_sd_webui(config, workspace, image_size):
|
||||
"""Test SD WebUI image generation."""
|
||||
generate_and_validate(
|
||||
config,
|
||||
workspace,
|
||||
image_provider="sdwebui",
|
||||
image_provider="sd_webui",
|
||||
image_size=image_size,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="External SD WebUI may not be available.")
|
||||
@pytest.mark.xfail(reason="SD WebUI call does not work.")
|
||||
def test_sd_webui_negative_prompt(config, workspace, image_size):
|
||||
gen_image = functools.partial(
|
||||
generate_image_with_sd_webui,
|
||||
@@ -103,3 +104,106 @@ def generate_and_validate(
|
||||
assert image_path.exists()
|
||||
with Image.open(image_path) as img:
|
||||
assert img.size == (image_size, image_size)
|
||||
|
||||
|
||||
def test_huggingface_fail_request_with_delay(mocker):
|
||||
config = Config()
|
||||
config.huggingface_api_token = "1"
|
||||
|
||||
# Mock requests.post
|
||||
mock_post = mocker.patch("requests.post")
|
||||
mock_post.return_value.status_code = 500
|
||||
mock_post.return_value.ok = False
|
||||
mock_post.return_value.text = '{"error":"Model CompVis/stable-diffusion-v1-4 is currently loading","estimated_time":0}'
|
||||
|
||||
# Mock time.sleep
|
||||
mock_sleep = mocker.patch("time.sleep")
|
||||
|
||||
config.image_provider = "huggingface"
|
||||
config.huggingface_image_model = "CompVis/stable-diffusion-v1-4"
|
||||
|
||||
result = generate_image("astronaut riding a horse", 512)
|
||||
|
||||
assert result == "Error creating image."
|
||||
|
||||
# Verify retry was called with delay.
|
||||
mock_sleep.assert_called_with(0)
|
||||
|
||||
|
||||
def test_huggingface_fail_request_no_delay(mocker):
|
||||
config = Config()
|
||||
config.huggingface_api_token = "1"
|
||||
|
||||
# Mock requests.post
|
||||
mock_post = mocker.patch("requests.post")
|
||||
mock_post.return_value.status_code = 500
|
||||
mock_post.return_value.ok = False
|
||||
mock_post.return_value.text = (
|
||||
'{"error":"Model CompVis/stable-diffusion-v1-4 is currently loading"}'
|
||||
)
|
||||
|
||||
# Mock time.sleep
|
||||
mock_sleep = mocker.patch("time.sleep")
|
||||
|
||||
config.image_provider = "huggingface"
|
||||
config.huggingface_image_model = "CompVis/stable-diffusion-v1-4"
|
||||
|
||||
result = generate_image("astronaut riding a horse", 512)
|
||||
|
||||
assert result == "Error creating image."
|
||||
|
||||
# Verify retry was not called.
|
||||
mock_sleep.assert_not_called()
|
||||
|
||||
|
||||
def test_huggingface_fail_request_bad_json(mocker):
|
||||
config = Config()
|
||||
config.huggingface_api_token = "1"
|
||||
|
||||
# Mock requests.post
|
||||
mock_post = mocker.patch("requests.post")
|
||||
mock_post.return_value.status_code = 500
|
||||
mock_post.return_value.ok = False
|
||||
mock_post.return_value.text = '{"error:}'
|
||||
|
||||
# Mock time.sleep
|
||||
mock_sleep = mocker.patch("time.sleep")
|
||||
|
||||
config.image_provider = "huggingface"
|
||||
config.huggingface_image_model = "CompVis/stable-diffusion-v1-4"
|
||||
|
||||
result = generate_image("astronaut riding a horse", 512)
|
||||
|
||||
assert result == "Error creating image."
|
||||
|
||||
# Verify retry was not called.
|
||||
mock_sleep.assert_not_called()
|
||||
|
||||
|
||||
def test_huggingface_fail_request_bad_image(mocker):
|
||||
config = Config()
|
||||
config.huggingface_api_token = "1"
|
||||
|
||||
# Mock requests.post
|
||||
mock_post = mocker.patch("requests.post")
|
||||
mock_post.return_value.status_code = 200
|
||||
|
||||
config.image_provider = "huggingface"
|
||||
config.huggingface_image_model = "CompVis/stable-diffusion-v1-4"
|
||||
|
||||
result = generate_image("astronaut riding a horse", 512)
|
||||
|
||||
assert result == "Error creating image."
|
||||
|
||||
|
||||
def test_huggingface_fail_missing_api_token(mocker):
|
||||
config = Config()
|
||||
config.image_provider = "huggingface"
|
||||
config.huggingface_image_model = "CompVis/stable-diffusion-v1-4"
|
||||
|
||||
# Mock requests.post to raise ValueError
|
||||
mock_post = mocker.patch("requests.post", side_effect=ValueError)
|
||||
|
||||
# Verify request raises an error.
|
||||
with pytest.raises(ValueError):
|
||||
generate_image("astronaut riding a horse", 512)
|
||||
|
||||
Reference in New Issue
Block a user