diff --git a/autogpt/commands/image_gen.py b/autogpt/commands/image_gen.py index fbed067f..5326cd52 100644 --- a/autogpt/commands/image_gen.py +++ b/autogpt/commands/image_gen.py @@ -1,5 +1,7 @@ """ Image Generation Module for AutoGPT.""" import io +import json +import time import uuid from base64 import b64decode @@ -61,20 +63,42 @@ def generate_image_with_hf(prompt: str, filename: str) -> str: "X-Use-Cache": "false", } - response = requests.post( - API_URL, - headers=headers, - json={ - "inputs": prompt, - }, - ) + retry_count = 0 + while retry_count < 10: + response = requests.post( + API_URL, + headers=headers, + json={ + "inputs": prompt, + }, + ) - image = Image.open(io.BytesIO(response.content)) - logger.info(f"Image Generated for prompt:{prompt}") + if response.ok: + try: + image = Image.open(io.BytesIO(response.content)) + logger.info(f"Image Generated for prompt:{prompt}") + image.save(filename) + return f"Saved to disk:{filename}" + except Exception as e: + logger.error(e) + break + else: + try: + error = json.loads(response.text) + if "estimated_time" in error: + delay = error["estimated_time"] + logger.debug(response.text) + logger.info("Retrying in", delay) + time.sleep(delay) + else: + break + except Exception as e: + logger.error(e) + break - image.save(filename) + retry_count += 1 - return f"Saved to disk:{filename}" + return f"Error creating image." def generate_image_with_dalle(prompt: str, filename: str, size: int) -> str: diff --git a/tests/test_image_gen.py b/tests/test_image_gen.py index 136fb510..a48f26b9 100644 --- a/tests/test_image_gen.py +++ b/tests/test_image_gen.py @@ -6,6 +6,7 @@ import pytest from PIL import Image from autogpt.commands.image_gen import generate_image, generate_image_with_sd_webui +from autogpt.config import Config from tests.utils import requires_api_key @@ -19,7 +20,7 @@ def image_size(request): reason="The image is too big to be put in a cassette for a CI pipeline. We're looking into a solution." ) @requires_api_key("OPENAI_API_KEY") -def test_dalle(config, workspace, image_size, patched_api_requestor): +def test_dalle(config, workspace, image_size): """Test DALL-E image generation.""" generate_and_validate( config, @@ -48,18 +49,18 @@ def test_huggingface(config, workspace, image_size, image_model): ) -@pytest.mark.skip(reason="External SD WebUI may not be available.") +@pytest.mark.xfail(reason="SD WebUI call does not work.") def test_sd_webui(config, workspace, image_size): """Test SD WebUI image generation.""" generate_and_validate( config, workspace, - image_provider="sdwebui", + image_provider="sd_webui", image_size=image_size, ) -@pytest.mark.skip(reason="External SD WebUI may not be available.") +@pytest.mark.xfail(reason="SD WebUI call does not work.") def test_sd_webui_negative_prompt(config, workspace, image_size): gen_image = functools.partial( generate_image_with_sd_webui, @@ -103,3 +104,106 @@ def generate_and_validate( assert image_path.exists() with Image.open(image_path) as img: assert img.size == (image_size, image_size) + + +def test_huggingface_fail_request_with_delay(mocker): + config = Config() + config.huggingface_api_token = "1" + + # Mock requests.post + mock_post = mocker.patch("requests.post") + mock_post.return_value.status_code = 500 + mock_post.return_value.ok = False + mock_post.return_value.text = '{"error":"Model CompVis/stable-diffusion-v1-4 is currently loading","estimated_time":0}' + + # Mock time.sleep + mock_sleep = mocker.patch("time.sleep") + + config.image_provider = "huggingface" + config.huggingface_image_model = "CompVis/stable-diffusion-v1-4" + + result = generate_image("astronaut riding a horse", 512) + + assert result == "Error creating image." + + # Verify retry was called with delay. + mock_sleep.assert_called_with(0) + + +def test_huggingface_fail_request_no_delay(mocker): + config = Config() + config.huggingface_api_token = "1" + + # Mock requests.post + mock_post = mocker.patch("requests.post") + mock_post.return_value.status_code = 500 + mock_post.return_value.ok = False + mock_post.return_value.text = ( + '{"error":"Model CompVis/stable-diffusion-v1-4 is currently loading"}' + ) + + # Mock time.sleep + mock_sleep = mocker.patch("time.sleep") + + config.image_provider = "huggingface" + config.huggingface_image_model = "CompVis/stable-diffusion-v1-4" + + result = generate_image("astronaut riding a horse", 512) + + assert result == "Error creating image." + + # Verify retry was not called. + mock_sleep.assert_not_called() + + +def test_huggingface_fail_request_bad_json(mocker): + config = Config() + config.huggingface_api_token = "1" + + # Mock requests.post + mock_post = mocker.patch("requests.post") + mock_post.return_value.status_code = 500 + mock_post.return_value.ok = False + mock_post.return_value.text = '{"error:}' + + # Mock time.sleep + mock_sleep = mocker.patch("time.sleep") + + config.image_provider = "huggingface" + config.huggingface_image_model = "CompVis/stable-diffusion-v1-4" + + result = generate_image("astronaut riding a horse", 512) + + assert result == "Error creating image." + + # Verify retry was not called. + mock_sleep.assert_not_called() + + +def test_huggingface_fail_request_bad_image(mocker): + config = Config() + config.huggingface_api_token = "1" + + # Mock requests.post + mock_post = mocker.patch("requests.post") + mock_post.return_value.status_code = 200 + + config.image_provider = "huggingface" + config.huggingface_image_model = "CompVis/stable-diffusion-v1-4" + + result = generate_image("astronaut riding a horse", 512) + + assert result == "Error creating image." + + +def test_huggingface_fail_missing_api_token(mocker): + config = Config() + config.image_provider = "huggingface" + config.huggingface_image_model = "CompVis/stable-diffusion-v1-4" + + # Mock requests.post to raise ValueError + mock_post = mocker.patch("requests.post", side_effect=ValueError) + + # Verify request raises an error. + with pytest.raises(ValueError): + generate_image("astronaut riding a horse", 512)