mirror of https://github.com/Stability-AI/generative-models.git
synced 2025-12-19 14:24:21 +01:00
* Add inference helpers & tests
* Support testing with hatch
* fixes to hatch script
* add inference test action
* change workflow trigger
* widen trigger to test
* revert changes to workflow triggers
* Install local python in action
* Trigger on push again
* fix python version
* add CODEOWNERS and change triggers
* Report test results
* update action versions
* format
* Fix typo and add refiner helper
* use a shared path loaded from a secret for checkpoints source
* typo fix
* Use device from input and remove duplicated code
* PR feedback
* fix call to load_model_from_config
* Move model to gpu
* Refactor helpers
* cleanup
* test refiner, prep for 1.0, align with metadata
* fix paths on second load
* deduplicate streamlit code
* filenames
* fixes
* add pydantic to requirements
* fix usage of `msg` in demo script
* remove double text
* run black
* fix streamlit sampling when returning latents
* extract function for streamlit output
* another fix for streamlit outputs
* fix img2img in streamlit
* Make fp16 optional and fix device param
* PR feedback
* fix dict cast for dataclass
* run black, update ci script
* cache pip dependencies on hosted runners, remove extra runs
* install package in ci env
* fix cache path
* PR cleanup
* one more cleanup
* don't cache, it filled up
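For context, the file below exercises the inference API added in these commits. A minimal usage sketch of driving a pipeline directly, outside pytest (assumptions: checkpoints are already downloaded where the package expects them, a CUDA device is available, and the step count is illustrative):

from sgm.inference.api import (
    ModelArchitecture,
    Sampler,
    SamplingParams,
    SamplingPipeline,
)

sampler = next(iter(Sampler))  # any Sampler member works, as in the tests below
pipeline = SamplingPipeline(ModelArchitecture.SDXL_V1_BASE)
samples = pipeline.text_to_image(
    params=SamplingParams(sampler=sampler.value, steps=30),  # steps illustrative
    prompt="A professional photograph of an astronaut riding a pig",
    negative_prompt="",
    samples=1,
)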
112 lines
3.9 KiB
Python
import numpy
from PIL import Image
import pytest
from pytest import fixture
import torch
from typing import Tuple

from sgm.inference.api import (
    model_specs,
    SamplingParams,
    SamplingPipeline,
    Sampler,
    ModelArchitecture,
)
import sgm.inference.helpers as helpers


@pytest.mark.inference
class TestInference:
    @fixture(scope="class", params=model_specs.keys())
    def pipeline(self, request) -> SamplingPipeline:
        pipeline = SamplingPipeline(request.param)
        yield pipeline
        # Teardown: drop the pipeline and release cached GPU memory
        # before the next parametrized model is loaded.
        del pipeline
        torch.cuda.empty_cache()

    @fixture(
        scope="class",
        params=[
            [ModelArchitecture.SDXL_V1_BASE, ModelArchitecture.SDXL_V1_REFINER],
            [ModelArchitecture.SDXL_V0_9_BASE, ModelArchitecture.SDXL_V0_9_REFINER],
        ],
        ids=["SDXL_V1", "SDXL_V0_9"],
    )
    def sdxl_pipelines(self, request) -> Tuple[SamplingPipeline, SamplingPipeline]:
        base_pipeline = SamplingPipeline(request.param[0])
        refiner_pipeline = SamplingPipeline(request.param[1])
        yield base_pipeline, refiner_pipeline
        del base_pipeline
        del refiner_pipeline
        torch.cuda.empty_cache()

    def create_init_image(self, h, w):
        # Random RGB noise stands in for a real init image in the img2img tests.
        image_array = numpy.random.rand(h, w, 3) * 255
        image = Image.fromarray(image_array.astype("uint8")).convert("RGB")
        return helpers.get_input_image_tensor(image)

    @pytest.mark.parametrize("sampler_enum", Sampler)
    def test_txt2img(self, pipeline: SamplingPipeline, sampler_enum):
        output = pipeline.text_to_image(
            params=SamplingParams(sampler=sampler_enum.value, steps=10),
            prompt="A professional photograph of an astronaut riding a pig",
            negative_prompt="",
            samples=1,
        )

        assert output is not None

    @pytest.mark.parametrize("sampler_enum", Sampler)
    def test_img2img(self, pipeline: SamplingPipeline, sampler_enum):
        output = pipeline.image_to_image(
            params=SamplingParams(sampler=sampler_enum.value, steps=10),
            image=self.create_init_image(pipeline.specs.height, pipeline.specs.width),
            prompt="A professional photograph of an astronaut riding a pig",
            negative_prompt="",
            samples=1,
        )
        assert output is not None

    @pytest.mark.parametrize("sampler_enum", Sampler)
    @pytest.mark.parametrize(
        "use_init_image", [True, False], ids=["img2img", "txt2img"]
    )
    def test_sdxl_with_refiner(
        self,
        sdxl_pipelines: Tuple[SamplingPipeline, SamplingPipeline],
        sampler_enum,
        use_init_image,
    ):
        base_pipeline, refiner_pipeline = sdxl_pipelines
        if use_init_image:
            output = base_pipeline.image_to_image(
                params=SamplingParams(sampler=sampler_enum.value, steps=10),
                image=self.create_init_image(
                    base_pipeline.specs.height, base_pipeline.specs.width
                ),
                prompt="A professional photograph of an astronaut riding a pig",
                negative_prompt="",
                samples=1,
                return_latents=True,
            )
        else:
            output = base_pipeline.text_to_image(
                params=SamplingParams(sampler=sampler_enum.value, steps=10),
                prompt="A professional photograph of an astronaut riding a pig",
                negative_prompt="",
                samples=1,
                return_latents=True,
            )

        # With return_latents=True the pipeline returns both the decoded
        # samples and the latents, which the refiner consumes directly.
        assert isinstance(output, (tuple, list))
        samples, samples_z = output
        assert samples is not None
        assert samples_z is not None
        refiner_pipeline.refiner(
            params=SamplingParams(sampler=sampler_enum.value, steps=10),
            image=samples_z,
            prompt="A professional photograph of an astronaut riding a pig",
            negative_prompt="",
            samples=1,
        )
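As test_sdxl_with_refiner above shows, return_latents=True makes the base pipeline return both decoded samples and latents, and the latents feed the refiner stage. A sketch of that handoff under the same assumptions as before (local checkpoints, CUDA device; step counts illustrative):

from sgm.inference.api import (
    ModelArchitecture,
    Sampler,
    SamplingParams,
    SamplingPipeline,
)

sampler = next(iter(Sampler))  # any Sampler member works, as in the tests
base = SamplingPipeline(ModelArchitecture.SDXL_V1_BASE)
refiner = SamplingPipeline(ModelArchitecture.SDXL_V1_REFINER)

# return_latents=True yields (decoded samples, latents).
samples, samples_z = base.text_to_image(
    params=SamplingParams(sampler=sampler.value, steps=30),  # steps illustrative
    prompt="A professional photograph of an astronaut riding a pig",
    negative_prompt="",
    samples=1,
    return_latents=True,
)

# The refiner consumes the base model's latents directly.
refined = refiner.refiner(
    params=SamplingParams(sampler=sampler.value, steps=30),
    image=samples_z,
    prompt="A professional photograph of an astronaut riding a pig",
    negative_prompt="",
    samples=1,
)

The tests themselves are tagged with the inference marker, so they can be selected with pytest -m inference.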