From ee986412108933b525b9b09089d92d7e8154fe8c Mon Sep 17 00:00:00 2001 From: "Luke K (pr-0f3t)" <2609441+lc0rp@users.noreply.github.com> Date: Fri, 19 May 2023 13:19:39 -0400 Subject: [PATCH] Imagegen delay retry huggingface (#4194) Co-authored-by: Kory Becker Co-authored-by: Nicholas Tindle Co-authored-by: Nicholas Tindle Co-authored-by: k-boikov <64261260+k-boikov@users.noreply.github.com> --- tests/test_image_gen.py | 50 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/tests/test_image_gen.py b/tests/test_image_gen.py index a48f26b9..a18855e7 100644 --- a/tests/test_image_gen.py +++ b/tests/test_image_gen.py @@ -1,6 +1,7 @@ import functools import hashlib from pathlib import Path +from unittest.mock import patch import pytest from PIL import Image @@ -106,6 +107,55 @@ def generate_and_validate( assert img.size == (image_size, image_size) +@pytest.mark.parametrize( + "return_text", + [ + '{"error":"Model [model] is currently loading","estimated_time": [delay]}', # Delay + '{"error":"Model [model] is currently loading"}', # No delay + '{"error:}', # Bad JSON + "", # Bad Image + ], +) +@pytest.mark.parametrize( + "image_model", + ["CompVis/stable-diffusion-v1-4", "stabilityai/stable-diffusion-2-1"], +) +@pytest.mark.parametrize("delay", [10, 0]) +def test_huggingface_fail_request_with_delay( + config, workspace, image_size, image_model, return_text, delay +): + return_text = return_text.replace("[model]", image_model).replace( + "[delay]", str(delay) + ) + + with patch("requests.post") as mock_post: + if return_text == "": + # Test bad image + mock_post.return_value.status_code = 200 + mock_post.return_value.ok = True + mock_post.return_value.content = b"bad image" + else: + # Test delay and bad json + mock_post.return_value.status_code = 500 + mock_post.return_value.ok = False + mock_post.return_value.text = return_text + + config.image_provider = "huggingface" + config.huggingface_image_model = image_model + prompt = "astronaut riding a horse" + + with patch("time.sleep") as mock_sleep: + # Verify request fails. + result = generate_image(prompt, image_size) + assert result == "Error creating image." + + # Verify retry was called with delay if delay is in return_text + if "estimated_time" in return_text: + mock_sleep.assert_called_with(delay) + else: + mock_sleep.assert_not_called() + + def test_huggingface_fail_request_with_delay(mocker): config = Config() config.huggingface_api_token = "1"