Add inference helpers & tests (#57)

* Add inference helpers & tests

* Support testing with hatch

* fixes to hatch script

* add inference test action

* change workflow trigger

* widen trigger to test

* revert changes to workflow triggers

* Install local python in action

* Trigger on push again

* fix python version

* add CODEOWNERS and change triggers

* Report tests results

* update action versions

* format

* Fix typo and add refiner helper

* use a shared path loaded from a secret for checkpoints source

* typo fix

* Use device from input and remove duplicated code

* PR feedback

* fix call to load_model_from_config

* Move model to gpu

* Refactor helpers

* cleanup

* test refiner, prep for 1.0, align with metadata

* fix paths on second load

* deduplicate streamlit code

* filenames

* fixes

* add pydantic to requirements

* fix usage of `msg` in demo script

* remove double text

* run black

* fix streamlit sampling when returning latents

* extract function for streamlit output

* another fix for streamlit outputs

* fix img2img in streamlit

* Make fp16 optional and fix device param

* PR feedback

* fix dict cast for dataclass

* run black, update ci script

* cache pip dependencies on hosted runners, remove extra runs

* install package in ci env

* fix cache path

* PR cleanup

* one more cleanup

* don't cache, it filled up
This commit is contained in:
Stephan Auerhahn
2023-07-26 04:37:24 -07:00
committed by GitHub
parent e596332148
commit 931d7a389a
11 changed files with 889 additions and 346 deletions

1
.github/workflows/CODEOWNERS vendored Normal file
View File

@@ -0,0 +1 @@
.github @Stability-AI/infrastructure

View File

@@ -1,5 +1,5 @@
name: Run black name: Run black
on: [push, pull_request] on: [pull_request]
jobs: jobs:
lint: lint:

View File

@@ -2,6 +2,7 @@ name: Build package
on: on:
push: push:
branches: [ main ]
pull_request: pull_request:
jobs: jobs:

34
.github/workflows/test-inference.yml vendored Normal file
View File

@@ -0,0 +1,34 @@
name: Test inference
on:
pull_request:
push:
branches:
- main
jobs:
test:
name: "Test inference"
# This action is designed only to run on the Stability research cluster at this time, so many assumptions are made about the environment
if: github.repository == 'stability-ai/generative-models'
runs-on: [self-hosted, slurm, g40]
steps:
- uses: actions/checkout@v3
- name: "Symlink checkpoints"
run: ln -s ${{secrets.SGM_CHECKPOINTS_PATH}} checkpoints
- name: "Setup python"
uses: actions/setup-python@v4
with:
python-version: "3.10"
- name: "Install Hatch"
run: pip install hatch
- name: "Run inference tests"
run: hatch run ci:test-inference --junit-xml test-results.xml
- name: Surface failing tests
if: always()
uses: pmeier/pytest-results-action@main
with:
path: test-results.xml
summary: true
display-options: fEX
fail-on-empty: true

View File

@@ -32,3 +32,17 @@ include = [
[tool.hatch.build.targets.wheel.force-include] [tool.hatch.build.targets.wheel.force-include]
"./configs" = "sgm/configs" "./configs" = "sgm/configs"
[tool.hatch.envs.ci]
skip-install = false
dependencies = [
"pytest"
]
[tool.hatch.envs.ci.scripts]
test-inference = [
"pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 --index-url https://download.pytorch.org/whl/cu118",
"pip install -r requirements/pt2.txt",
"pytest -v tests/inference/test_inference.py {args}",
]

3
pytest.ini Normal file
View File

@@ -0,0 +1,3 @@
[pytest]
markers =
inference: mark as inference test (deselect with '-m "not inference"')

View File

@@ -1,7 +1,14 @@
import numpy as np
from pytorch_lightning import seed_everything from pytorch_lightning import seed_everything
from scripts.demo.streamlit_helpers import * from scripts.demo.streamlit_helpers import *
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
from sgm.inference.helpers import (
do_img2img,
do_sample,
get_unique_embedder_keys_from_conditioner,
perform_save_locally,
)
SAVE_PATH = "outputs/demo/txt2img/" SAVE_PATH = "outputs/demo/txt2img/"
@@ -131,6 +138,8 @@ def run_txt2img(
if st.button("Sample"): if st.button("Sample"):
st.write(f"**Model I:** {version}") st.write(f"**Model I:** {version}")
outputs = st.empty()
st.text("Sampling")
out = do_sample( out = do_sample(
state["model"], state["model"],
sampler, sampler,
@@ -144,6 +153,8 @@ def run_txt2img(
return_latents=return_latents, return_latents=return_latents,
filter=filter, filter=filter,
) )
show_samples(out, outputs)
return out return out
@@ -175,6 +186,8 @@ def run_img2img(
num_samples = num_rows * num_cols num_samples = num_rows * num_cols
if st.button("Sample"): if st.button("Sample"):
outputs = st.empty()
st.text("Sampling")
out = do_img2img( out = do_img2img(
repeat(img, "1 ... -> n ...", n=num_samples), repeat(img, "1 ... -> n ...", n=num_samples),
state["model"], state["model"],
@@ -185,6 +198,7 @@ def run_img2img(
return_latents=return_latents, return_latents=return_latents,
filter=filter, filter=filter,
) )
show_samples(out, outputs)
return out return out
@@ -249,8 +263,6 @@ if __name__ == "__main__":
save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version)) save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))
state = init_st(version_dict) state = init_st(version_dict)
if state["msg"]:
st.info(state["msg"])
model = state["model"] model = state["model"]
is_legacy = version_dict["is_legacy"] is_legacy = version_dict["is_legacy"]
@@ -275,7 +287,6 @@ if __name__ == "__main__":
version_dict2 = VERSION2SPECS[version2] version_dict2 = VERSION2SPECS[version2]
state2 = init_st(version_dict2) state2 = init_st(version_dict2)
st.info(state2["msg"])
stage2strength = st.number_input( stage2strength = st.number_input(
"**Refinement strength**", value=0.3, min_value=0.0, max_value=1.0 "**Refinement strength**", value=0.3, min_value=0.0, max_value=1.0
@@ -315,6 +326,7 @@ if __name__ == "__main__":
samples_z = None samples_z = None
if add_pipeline and samples_z is not None: if add_pipeline and samples_z is not None:
outputs = st.empty()
st.write("**Running Refinement Stage**") st.write("**Running Refinement Stage**")
samples = apply_refiner( samples = apply_refiner(
samples_z, samples_z,
@@ -325,6 +337,7 @@ if __name__ == "__main__":
negative_prompt=negative_prompt if is_legacy else "", negative_prompt=negative_prompt if is_legacy else "",
filter=filter, filter=filter,
) )
show_samples(samples, outputs)
if save_locally and samples is not None: if save_locally and samples is not None:
perform_save_locally(save_path, samples) perform_save_locally(save_path, samples)

View File

@@ -1,18 +1,11 @@
import math
import os import os
from typing import List, Union
import numpy as np
import streamlit as st import streamlit as st
import torch import torch
from einops import rearrange, repeat
from imwatermark import WatermarkEncoder
from omegaconf import ListConfig, OmegaConf
from PIL import Image from PIL import Image
from safetensors.torch import load_file as load_safetensors from einops import rearrange, repeat
from torch import autocast from omegaconf import OmegaConf
from torchvision import transforms from torchvision import transforms
from torchvision.utils import make_grid
from sgm.modules.diffusionmodules.sampling import ( from sgm.modules.diffusionmodules.sampling import (
DPMPP2MSampler, DPMPP2MSampler,
@@ -22,52 +15,8 @@ from sgm.modules.diffusionmodules.sampling import (
HeunEDMSampler, HeunEDMSampler,
LinearMultistepSampler, LinearMultistepSampler,
) )
from sgm.util import append_dims, instantiate_from_config from sgm.inference.helpers import Img2ImgDiscretizationWrapper, embed_watermark
from sgm.util import load_model_from_config
class WatermarkEmbedder:
def __init__(self, watermark):
self.watermark = watermark
self.num_bits = len(WATERMARK_BITS)
self.encoder = WatermarkEncoder()
self.encoder.set_watermark("bits", self.watermark)
def __call__(self, image: torch.Tensor):
"""
Adds a predefined watermark to the input image
Args:
image: ([N,] B, C, H, W) in range [0, 1]
Returns:
same as input but watermarked
"""
# watermarking libary expects input as cv2 BGR format
squeeze = len(image.shape) == 4
if squeeze:
image = image[None, ...]
n = image.shape[0]
image_np = rearrange(
(255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
).numpy()[:, :, :, ::-1]
# torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
for k in range(image_np.shape[0]):
image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
image = torch.from_numpy(
rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)
).to(image.device)
image = torch.clamp(image / 255, min=0.0, max=1.0)
if squeeze:
image = image[0]
return image
# A fixed 48-bit message that was choosen at random
# WATERMARK_MESSAGE = 0xB3EC907BB19E
WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
embed_watemark = WatermarkEmbedder(WATERMARK_BITS)
@st.cache_resource() @st.cache_resource()
@@ -78,54 +27,17 @@ def init_st(version_dict, load_ckpt=True):
ckpt = version_dict["ckpt"] ckpt = version_dict["ckpt"]
config = OmegaConf.load(config) config = OmegaConf.load(config)
model, msg = load_model_from_config(config, ckpt if load_ckpt else None) model = load_model_from_config(config, ckpt if load_ckpt else None)
model = model.to("cuda")
model.conditioner.half()
model.model.half()
state["msg"] = msg
state["model"] = model state["model"] = model
state["ckpt"] = ckpt if load_ckpt else None state["ckpt"] = ckpt if load_ckpt else None
state["config"] = config state["config"] = config
return state return state
def load_model_from_config(config, ckpt=None, verbose=True):
model = instantiate_from_config(config.model)
if ckpt is not None:
print(f"Loading model from {ckpt}")
if ckpt.endswith("ckpt"):
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
global_step = pl_sd["global_step"]
st.info(f"loaded ckpt from global step {global_step}")
print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
elif ckpt.endswith("safetensors"):
sd = load_safetensors(ckpt)
else:
raise NotImplementedError
msg = None
m, u = model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print("unexpected keys:")
print(u)
else:
msg = None
model.cuda()
model.eval()
return model, msg
def get_unique_embedder_keys_from_conditioner(conditioner):
return list(set([x.input_key for x in conditioner.embedders]))
def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None): def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
# Hardcoded demo settings; might undergo some changes in the future # Hardcoded demo settings; might undergo some changes in the future
@@ -186,18 +98,6 @@ def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
return value_dict return value_dict
def perform_save_locally(save_path, samples):
os.makedirs(os.path.join(save_path), exist_ok=True)
base_count = len(os.listdir(os.path.join(save_path)))
samples = embed_watemark(samples)
for sample in samples:
sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
Image.fromarray(sample.astype(np.uint8)).save(
os.path.join(save_path, f"{base_count:09}.png")
)
base_count += 1
def init_save_locally(_dir, init_value: bool = False): def init_save_locally(_dir, init_value: bool = False):
save_locally = st.sidebar.checkbox("Save images locally", value=init_value) save_locally = st.sidebar.checkbox("Save images locally", value=init_value)
if save_locally: if save_locally:
@@ -208,28 +108,12 @@ def init_save_locally(_dir, init_value: bool = False):
return save_locally, save_path return save_locally, save_path
class Img2ImgDiscretizationWrapper: def show_samples(samples, outputs):
""" if isinstance(samples, tuple):
wraps a discretizer, and prunes the sigmas samples, _ = samples
params: grid = embed_watermark(torch.stack([samples]))
strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned) grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
""" outputs.image(grid.cpu().numpy())
def __init__(self, discretization, strength: float = 1.0):
self.discretization = discretization
self.strength = strength
assert 0.0 <= self.strength <= 1.0
def __call__(self, *args, **kwargs):
# sigmas start large first, and decrease then
sigmas = self.discretization(*args, **kwargs)
print(f"sigmas after discretization, before pruning img2img: ", sigmas)
sigmas = torch.flip(sigmas, (0,))
sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
print("prune index:", max(int(self.strength * len(sigmas)), 1))
sigmas = torch.flip(sigmas, (0,))
print(f"sigmas after pruning: ", sigmas)
return sigmas
def get_guider(key): def get_guider(key):
@@ -452,214 +336,3 @@ def get_init_img(batch_size=1, key=None):
init_image = load_img(key=key).cuda() init_image = load_img(key=key).cuda()
init_image = repeat(init_image, "1 ... -> b ...", b=batch_size) init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
return init_image return init_image
def do_sample(
model,
sampler,
value_dict,
num_samples,
H,
W,
C,
F,
force_uc_zero_embeddings: List = None,
batch2model_input: List = None,
return_latents=False,
filter=None,
):
if force_uc_zero_embeddings is None:
force_uc_zero_embeddings = []
if batch2model_input is None:
batch2model_input = []
st.text("Sampling")
outputs = st.empty()
precision_scope = autocast
with torch.no_grad():
with precision_scope("cuda"):
with model.ema_scope():
num_samples = [num_samples]
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner),
value_dict,
num_samples,
)
for key in batch:
if isinstance(batch[key], torch.Tensor):
print(key, batch[key].shape)
elif isinstance(batch[key], list):
print(key, [len(l) for l in batch[key]])
else:
print(key, batch[key])
c, uc = model.conditioner.get_unconditional_conditioning(
batch,
batch_uc=batch_uc,
force_uc_zero_embeddings=force_uc_zero_embeddings,
)
for k in c:
if not k == "crossattn":
c[k], uc[k] = map(
lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc)
)
additional_model_inputs = {}
for k in batch2model_input:
additional_model_inputs[k] = batch[k]
shape = (math.prod(num_samples), C, H // F, W // F)
randn = torch.randn(shape).to("cuda")
def denoiser(input, sigma, c):
return model.denoiser(
model.model, input, sigma, c, **additional_model_inputs
)
samples_z = sampler(denoiser, randn, cond=c, uc=uc)
samples_x = model.decode_first_stage(samples_z)
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
if filter is not None:
samples = filter(samples)
grid = torch.stack([samples])
grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
outputs.image(grid.cpu().numpy())
if return_latents:
return samples, samples_z
return samples
def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
# Hardcoded demo setups; might undergo some changes in the future
batch = {}
batch_uc = {}
for key in keys:
if key == "txt":
batch["txt"] = (
np.repeat([value_dict["prompt"]], repeats=math.prod(N))
.reshape(N)
.tolist()
)
batch_uc["txt"] = (
np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
.reshape(N)
.tolist()
)
elif key == "original_size_as_tuple":
batch["original_size_as_tuple"] = (
torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
.to(device)
.repeat(*N, 1)
)
elif key == "crop_coords_top_left":
batch["crop_coords_top_left"] = (
torch.tensor(
[value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
)
.to(device)
.repeat(*N, 1)
)
elif key == "aesthetic_score":
batch["aesthetic_score"] = (
torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
)
batch_uc["aesthetic_score"] = (
torch.tensor([value_dict["negative_aesthetic_score"]])
.to(device)
.repeat(*N, 1)
)
elif key == "target_size_as_tuple":
batch["target_size_as_tuple"] = (
torch.tensor([value_dict["target_height"], value_dict["target_width"]])
.to(device)
.repeat(*N, 1)
)
else:
batch[key] = value_dict[key]
for key in batch.keys():
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
batch_uc[key] = torch.clone(batch[key])
return batch, batch_uc
@torch.no_grad()
def do_img2img(
img,
model,
sampler,
value_dict,
num_samples,
force_uc_zero_embeddings=[],
additional_kwargs={},
offset_noise_level: int = 0.0,
return_latents=False,
skip_encode=False,
filter=None,
):
st.text("Sampling")
outputs = st.empty()
precision_scope = autocast
with torch.no_grad():
with precision_scope("cuda"):
with model.ema_scope():
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner),
value_dict,
[num_samples],
)
c, uc = model.conditioner.get_unconditional_conditioning(
batch,
batch_uc=batch_uc,
force_uc_zero_embeddings=force_uc_zero_embeddings,
)
for k in c:
c[k], uc[k] = map(lambda y: y[k][:num_samples].to("cuda"), (c, uc))
for k in additional_kwargs:
c[k] = uc[k] = additional_kwargs[k]
if skip_encode:
z = img
else:
z = model.encode_first_stage(img)
noise = torch.randn_like(z)
sigmas = sampler.discretization(sampler.num_steps)
sigma = sigmas[0]
st.info(f"all sigmas: {sigmas}")
st.info(f"noising sigma: {sigma}")
if offset_noise_level > 0.0:
noise = noise + offset_noise_level * append_dims(
torch.randn(z.shape[0], device=z.device), z.ndim
)
noised_z = z + noise * append_dims(sigma, z.ndim)
noised_z = noised_z / torch.sqrt(
1.0 + sigmas[0] ** 2.0
) # Note: hardcoded to DDPM-like scaling. need to generalize later.
def denoiser(x, sigma, c):
return model.denoiser(model.model, x, sigma, c)
samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
samples_x = model.decode_first_stage(samples_z)
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
if filter is not None:
samples = filter(samples)
grid = embed_watemark(torch.stack([samples]))
grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
outputs.image(grid.cpu().numpy())
if return_latents:
return samples, samples_z
return samples

388
sgm/inference/api.py Normal file
View File

@@ -0,0 +1,388 @@
from dataclasses import dataclass, asdict
from enum import Enum
from omegaconf import OmegaConf
import pathlib
from sgm.inference.helpers import (
do_sample,
do_img2img,
Img2ImgDiscretizationWrapper,
)
from sgm.modules.diffusionmodules.sampling import (
EulerEDMSampler,
HeunEDMSampler,
EulerAncestralSampler,
DPMPP2SAncestralSampler,
DPMPP2MSampler,
LinearMultistepSampler,
)
from sgm.util import load_model_from_config
from typing import Optional
class ModelArchitecture(str, Enum):
SD_2_1 = "stable-diffusion-v2-1"
SD_2_1_768 = "stable-diffusion-v2-1-768"
SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base"
SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner"
SDXL_V1_BASE = "stable-diffusion-xl-v1-base"
SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner"
class Sampler(str, Enum):
EULER_EDM = "EulerEDMSampler"
HEUN_EDM = "HeunEDMSampler"
EULER_ANCESTRAL = "EulerAncestralSampler"
DPMPP2S_ANCESTRAL = "DPMPP2SAncestralSampler"
DPMPP2M = "DPMPP2MSampler"
LINEAR_MULTISTEP = "LinearMultistepSampler"
class Discretization(str, Enum):
LEGACY_DDPM = "LegacyDDPMDiscretization"
EDM = "EDMDiscretization"
class Guider(str, Enum):
VANILLA = "VanillaCFG"
IDENTITY = "IdentityGuider"
class Thresholder(str, Enum):
NONE = "None"
@dataclass
class SamplingParams:
width: int = 1024
height: int = 1024
steps: int = 50
sampler: Sampler = Sampler.DPMPP2M
discretization: Discretization = Discretization.LEGACY_DDPM
guider: Guider = Guider.VANILLA
thresholder: Thresholder = Thresholder.NONE
scale: float = 6.0
aesthetic_score: float = 5.0
negative_aesthetic_score: float = 5.0
img2img_strength: float = 1.0
orig_width: int = 1024
orig_height: int = 1024
crop_coords_top: int = 0
crop_coords_left: int = 0
sigma_min: float = 0.0292
sigma_max: float = 14.6146
rho: float = 3.0
s_churn: float = 0.0
s_tmin: float = 0.0
s_tmax: float = 999.0
s_noise: float = 1.0
eta: float = 1.0
order: int = 4
@dataclass
class SamplingSpec:
width: int
height: int
channels: int
factor: int
is_legacy: bool
config: str
ckpt: str
is_guided: bool
model_specs = {
ModelArchitecture.SD_2_1: SamplingSpec(
height=512,
width=512,
channels=4,
factor=8,
is_legacy=True,
config="sd_2_1.yaml",
ckpt="v2-1_512-ema-pruned.safetensors",
is_guided=True,
),
ModelArchitecture.SD_2_1_768: SamplingSpec(
height=768,
width=768,
channels=4,
factor=8,
is_legacy=True,
config="sd_2_1_768.yaml",
ckpt="v2-1_768-ema-pruned.safetensors",
is_guided=True,
),
ModelArchitecture.SDXL_V0_9_BASE: SamplingSpec(
height=1024,
width=1024,
channels=4,
factor=8,
is_legacy=False,
config="sd_xl_base.yaml",
ckpt="sd_xl_base_0.9.safetensors",
is_guided=True,
),
ModelArchitecture.SDXL_V0_9_REFINER: SamplingSpec(
height=1024,
width=1024,
channels=4,
factor=8,
is_legacy=True,
config="sd_xl_refiner.yaml",
ckpt="sd_xl_refiner_0.9.safetensors",
is_guided=True,
),
ModelArchitecture.SDXL_V1_BASE: SamplingSpec(
height=1024,
width=1024,
channels=4,
factor=8,
is_legacy=False,
config="sd_xl_base.yaml",
ckpt="sd_xl_base_1.0-metadata.safetensors",
is_guided=True,
),
ModelArchitecture.SDXL_V1_REFINER: SamplingSpec(
height=1024,
width=1024,
channels=4,
factor=8,
is_legacy=True,
config="sd_xl_refiner.yaml",
ckpt="sd_xl_refiner_1.0-metadata.safetensors",
is_guided=True,
),
}
class SamplingPipeline:
def __init__(
self,
model_id: ModelArchitecture,
model_path="checkpoints",
config_path="configs/inference",
device="cuda",
use_fp16=True,
) -> None:
if model_id not in model_specs:
raise ValueError(f"Model {model_id} not supported")
self.model_id = model_id
self.specs = model_specs[self.model_id]
self.config = str(pathlib.Path(config_path, self.specs.config))
self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt))
self.device = device
self.model = self._load_model(device=device, use_fp16=use_fp16)
def _load_model(self, device="cuda", use_fp16=True):
config = OmegaConf.load(self.config)
model = load_model_from_config(config, self.ckpt)
if model is None:
raise ValueError(f"Model {self.model_id} could not be loaded")
model.to(device)
if use_fp16:
model.conditioner.half()
model.model.half()
return model
def text_to_image(
self,
params: SamplingParams,
prompt: str,
negative_prompt: str = "",
samples: int = 1,
return_latents: bool = False,
):
sampler = get_sampler_config(params)
value_dict = asdict(params)
value_dict["prompt"] = prompt
value_dict["negative_prompt"] = negative_prompt
value_dict["target_width"] = params.width
value_dict["target_height"] = params.height
return do_sample(
self.model,
sampler,
value_dict,
samples,
params.height,
params.width,
self.specs.channels,
self.specs.factor,
force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
return_latents=return_latents,
filter=None,
)
def image_to_image(
self,
params: SamplingParams,
image,
prompt: str,
negative_prompt: str = "",
samples: int = 1,
return_latents: bool = False,
):
sampler = get_sampler_config(params)
if params.img2img_strength < 1.0:
sampler.discretization = Img2ImgDiscretizationWrapper(
sampler.discretization,
strength=params.img2img_strength,
)
height, width = image.shape[2], image.shape[3]
value_dict = asdict(params)
value_dict["prompt"] = prompt
value_dict["negative_prompt"] = negative_prompt
value_dict["target_width"] = width
value_dict["target_height"] = height
return do_img2img(
image,
self.model,
sampler,
value_dict,
samples,
force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
return_latents=return_latents,
filter=None,
)
def refiner(
self,
params: SamplingParams,
image,
prompt: str,
negative_prompt: Optional[str] = None,
samples: int = 1,
return_latents: bool = False,
):
sampler = get_sampler_config(params)
value_dict = {
"orig_width": image.shape[3] * 8,
"orig_height": image.shape[2] * 8,
"target_width": image.shape[3] * 8,
"target_height": image.shape[2] * 8,
"prompt": prompt,
"negative_prompt": negative_prompt,
"crop_coords_top": 0,
"crop_coords_left": 0,
"aesthetic_score": 6.0,
"negative_aesthetic_score": 2.5,
}
return do_img2img(
image,
self.model,
sampler,
value_dict,
samples,
skip_encode=True,
return_latents=return_latents,
filter=None,
)
def get_guider_config(params: SamplingParams):
if params.guider == Guider.IDENTITY:
guider_config = {
"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
}
elif params.guider == Guider.VANILLA:
scale = params.scale
thresholder = params.thresholder
if thresholder == Thresholder.NONE:
dyn_thresh_config = {
"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
}
else:
raise NotImplementedError
guider_config = {
"target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
"params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
}
else:
raise NotImplementedError
return guider_config
def get_discretization_config(params: SamplingParams):
if params.discretization == Discretization.LEGACY_DDPM:
discretization_config = {
"target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
}
elif params.discretization == Discretization.EDM:
discretization_config = {
"target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
"params": {
"sigma_min": params.sigma_min,
"sigma_max": params.sigma_max,
"rho": params.rho,
},
}
else:
raise ValueError(f"unknown discretization {params.discretization}")
return discretization_config
def get_sampler_config(params: SamplingParams):
discretization_config = get_discretization_config(params)
guider_config = get_guider_config(params)
sampler = None
if params.sampler == Sampler.EULER_EDM:
return EulerEDMSampler(
num_steps=params.steps,
discretization_config=discretization_config,
guider_config=guider_config,
s_churn=params.s_churn,
s_tmin=params.s_tmin,
s_tmax=params.s_tmax,
s_noise=params.s_noise,
verbose=True,
)
if params.sampler == Sampler.HEUN_EDM:
return HeunEDMSampler(
num_steps=params.steps,
discretization_config=discretization_config,
guider_config=guider_config,
s_churn=params.s_churn,
s_tmin=params.s_tmin,
s_tmax=params.s_tmax,
s_noise=params.s_noise,
verbose=True,
)
if params.sampler == Sampler.EULER_ANCESTRAL:
return EulerAncestralSampler(
num_steps=params.steps,
discretization_config=discretization_config,
guider_config=guider_config,
eta=params.eta,
s_noise=params.s_noise,
verbose=True,
)
if params.sampler == Sampler.DPMPP2S_ANCESTRAL:
return DPMPP2SAncestralSampler(
num_steps=params.steps,
discretization_config=discretization_config,
guider_config=guider_config,
eta=params.eta,
s_noise=params.s_noise,
verbose=True,
)
if params.sampler == Sampler.DPMPP2M:
return DPMPP2MSampler(
num_steps=params.steps,
discretization_config=discretization_config,
guider_config=guider_config,
verbose=True,
)
if params.sampler == Sampler.LINEAR_MULTISTEP:
return LinearMultistepSampler(
num_steps=params.steps,
discretization_config=discretization_config,
guider_config=guider_config,
order=params.order,
verbose=True,
)
raise ValueError(f"unknown sampler {params.sampler}!")

305
sgm/inference/helpers.py Normal file
View File

@@ -0,0 +1,305 @@
import os
from typing import Union, List, Optional
import math
import numpy as np
import torch
from PIL import Image
from einops import rearrange
from imwatermark import WatermarkEncoder
from omegaconf import ListConfig
from torch import autocast
from sgm.util import append_dims
class WatermarkEmbedder:
def __init__(self, watermark):
self.watermark = watermark
self.num_bits = len(WATERMARK_BITS)
self.encoder = WatermarkEncoder()
self.encoder.set_watermark("bits", self.watermark)
def __call__(self, image: torch.Tensor):
"""
Adds a predefined watermark to the input image
Args:
image: ([N,] B, C, H, W) in range [0, 1]
Returns:
same as input but watermarked
"""
# watermarking libary expects input as cv2 BGR format
squeeze = len(image.shape) == 4
if squeeze:
image = image[None, ...]
n = image.shape[0]
image_np = rearrange(
(255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
).numpy()[:, :, :, ::-1]
# torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
for k in range(image_np.shape[0]):
image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
image = torch.from_numpy(
rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)
).to(image.device)
image = torch.clamp(image / 255, min=0.0, max=1.0)
if squeeze:
image = image[0]
return image
# A fixed 48-bit message that was choosen at random
# WATERMARK_MESSAGE = 0xB3EC907BB19E
WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
def get_unique_embedder_keys_from_conditioner(conditioner):
return list({x.input_key for x in conditioner.embedders})
def perform_save_locally(save_path, samples):
os.makedirs(os.path.join(save_path), exist_ok=True)
base_count = len(os.listdir(os.path.join(save_path)))
samples = embed_watermark(samples)
for sample in samples:
sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
Image.fromarray(sample.astype(np.uint8)).save(
os.path.join(save_path, f"{base_count:09}.png")
)
base_count += 1
class Img2ImgDiscretizationWrapper:
"""
wraps a discretizer, and prunes the sigmas
params:
strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
"""
def __init__(self, discretization, strength: float = 1.0):
self.discretization = discretization
self.strength = strength
assert 0.0 <= self.strength <= 1.0
def __call__(self, *args, **kwargs):
# sigmas start large first, and decrease then
sigmas = self.discretization(*args, **kwargs)
print(f"sigmas after discretization, before pruning img2img: ", sigmas)
sigmas = torch.flip(sigmas, (0,))
sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
print("prune index:", max(int(self.strength * len(sigmas)), 1))
sigmas = torch.flip(sigmas, (0,))
print(f"sigmas after pruning: ", sigmas)
return sigmas
def do_sample(
model,
sampler,
value_dict,
num_samples,
H,
W,
C,
F,
force_uc_zero_embeddings: Optional[List] = None,
batch2model_input: Optional[List] = None,
return_latents=False,
filter=None,
device="cuda",
):
if force_uc_zero_embeddings is None:
force_uc_zero_embeddings = []
if batch2model_input is None:
batch2model_input = []
with torch.no_grad():
with autocast(device) as precision_scope:
with model.ema_scope():
num_samples = [num_samples]
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner),
value_dict,
num_samples,
)
for key in batch:
if isinstance(batch[key], torch.Tensor):
print(key, batch[key].shape)
elif isinstance(batch[key], list):
print(key, [len(l) for l in batch[key]])
else:
print(key, batch[key])
c, uc = model.conditioner.get_unconditional_conditioning(
batch,
batch_uc=batch_uc,
force_uc_zero_embeddings=force_uc_zero_embeddings,
)
for k in c:
if not k == "crossattn":
c[k], uc[k] = map(
lambda y: y[k][: math.prod(num_samples)].to(device), (c, uc)
)
additional_model_inputs = {}
for k in batch2model_input:
additional_model_inputs[k] = batch[k]
shape = (math.prod(num_samples), C, H // F, W // F)
randn = torch.randn(shape).to(device)
def denoiser(input, sigma, c):
return model.denoiser(
model.model, input, sigma, c, **additional_model_inputs
)
samples_z = sampler(denoiser, randn, cond=c, uc=uc)
samples_x = model.decode_first_stage(samples_z)
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
if filter is not None:
samples = filter(samples)
if return_latents:
return samples, samples_z
return samples
def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
# Hardcoded demo setups; might undergo some changes in the future
batch = {}
batch_uc = {}
for key in keys:
if key == "txt":
batch["txt"] = (
np.repeat([value_dict["prompt"]], repeats=math.prod(N))
.reshape(N)
.tolist()
)
batch_uc["txt"] = (
np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
.reshape(N)
.tolist()
)
elif key == "original_size_as_tuple":
batch["original_size_as_tuple"] = (
torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
.to(device)
.repeat(*N, 1)
)
elif key == "crop_coords_top_left":
batch["crop_coords_top_left"] = (
torch.tensor(
[value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
)
.to(device)
.repeat(*N, 1)
)
elif key == "aesthetic_score":
batch["aesthetic_score"] = (
torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
)
batch_uc["aesthetic_score"] = (
torch.tensor([value_dict["negative_aesthetic_score"]])
.to(device)
.repeat(*N, 1)
)
elif key == "target_size_as_tuple":
batch["target_size_as_tuple"] = (
torch.tensor([value_dict["target_height"], value_dict["target_width"]])
.to(device)
.repeat(*N, 1)
)
else:
batch[key] = value_dict[key]
for key in batch.keys():
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
batch_uc[key] = torch.clone(batch[key])
return batch, batch_uc
def get_input_image_tensor(image: Image.Image, device="cuda"):
w, h = image.size
print(f"loaded input image of size ({w}, {h})")
width, height = map(
lambda x: x - x % 64, (w, h)
) # resize to integer multiple of 64
image = image.resize((width, height))
image_array = np.array(image.convert("RGB"))
image_array = image_array[None].transpose(0, 3, 1, 2)
image_tensor = torch.from_numpy(image_array).to(dtype=torch.float32) / 127.5 - 1.0
return image_tensor.to(device)
def do_img2img(
img,
model,
sampler,
value_dict,
num_samples,
force_uc_zero_embeddings=[],
additional_kwargs={},
offset_noise_level: float = 0.0,
return_latents=False,
skip_encode=False,
filter=None,
device="cuda",
):
with torch.no_grad():
with autocast(device) as precision_scope:
with model.ema_scope():
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner),
value_dict,
[num_samples],
)
c, uc = model.conditioner.get_unconditional_conditioning(
batch,
batch_uc=batch_uc,
force_uc_zero_embeddings=force_uc_zero_embeddings,
)
for k in c:
c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc))
for k in additional_kwargs:
c[k] = uc[k] = additional_kwargs[k]
if skip_encode:
z = img
else:
z = model.encode_first_stage(img)
noise = torch.randn_like(z)
sigmas = sampler.discretization(sampler.num_steps)
sigma = sigmas[0].to(z.device)
if offset_noise_level > 0.0:
noise = noise + offset_noise_level * append_dims(
torch.randn(z.shape[0], device=z.device), z.ndim
)
noised_z = z + noise * append_dims(sigma, z.ndim)
noised_z = noised_z / torch.sqrt(
1.0 + sigmas[0] ** 2.0
) # Note: hardcoded to DDPM-like scaling. need to generalize later.
def denoiser(x, sigma, c):
return model.denoiser(model.model, x, sigma, c)
samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
samples_x = model.decode_first_stage(samples_z)
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
if filter is not None:
samples = filter(samples)
if return_latents:
return samples, samples_z
return samples

View File

@@ -0,0 +1,111 @@
import numpy
from PIL import Image
import pytest
from pytest import fixture
import torch
from typing import Tuple
from sgm.inference.api import (
model_specs,
SamplingParams,
SamplingPipeline,
Sampler,
ModelArchitecture,
)
import sgm.inference.helpers as helpers
@pytest.mark.inference
class TestInference:
@fixture(scope="class", params=model_specs.keys())
def pipeline(self, request) -> SamplingPipeline:
pipeline = SamplingPipeline(request.param)
yield pipeline
del pipeline
torch.cuda.empty_cache()
@fixture(
scope="class",
params=[
[ModelArchitecture.SDXL_V1_BASE, ModelArchitecture.SDXL_V1_REFINER],
[ModelArchitecture.SDXL_V0_9_BASE, ModelArchitecture.SDXL_V0_9_REFINER],
],
ids=["SDXL_V1", "SDXL_V0_9"],
)
def sdxl_pipelines(self, request) -> Tuple[SamplingPipeline, SamplingPipeline]:
base_pipeline = SamplingPipeline(request.param[0])
refiner_pipeline = SamplingPipeline(request.param[1])
yield base_pipeline, refiner_pipeline
del base_pipeline
del refiner_pipeline
torch.cuda.empty_cache()
def create_init_image(self, h, w):
image_array = numpy.random.rand(h, w, 3) * 255
image = Image.fromarray(image_array.astype("uint8")).convert("RGB")
return helpers.get_input_image_tensor(image)
@pytest.mark.parametrize("sampler_enum", Sampler)
def test_txt2img(self, pipeline: SamplingPipeline, sampler_enum):
output = pipeline.text_to_image(
params=SamplingParams(sampler=sampler_enum.value, steps=10),
prompt="A professional photograph of an astronaut riding a pig",
negative_prompt="",
samples=1,
)
assert output is not None
@pytest.mark.parametrize("sampler_enum", Sampler)
def test_img2img(self, pipeline: SamplingPipeline, sampler_enum):
output = pipeline.image_to_image(
params=SamplingParams(sampler=sampler_enum.value, steps=10),
image=self.create_init_image(pipeline.specs.height, pipeline.specs.width),
prompt="A professional photograph of an astronaut riding a pig",
negative_prompt="",
samples=1,
)
assert output is not None
@pytest.mark.parametrize("sampler_enum", Sampler)
@pytest.mark.parametrize(
"use_init_image", [True, False], ids=["img2img", "txt2img"]
)
def test_sdxl_with_refiner(
self,
sdxl_pipelines: Tuple[SamplingPipeline, SamplingPipeline],
sampler_enum,
use_init_image,
):
base_pipeline, refiner_pipeline = sdxl_pipelines
if use_init_image:
output = base_pipeline.image_to_image(
params=SamplingParams(sampler=sampler_enum.value, steps=10),
image=self.create_init_image(
base_pipeline.specs.height, base_pipeline.specs.width
),
prompt="A professional photograph of an astronaut riding a pig",
negative_prompt="",
samples=1,
return_latents=True,
)
else:
output = base_pipeline.text_to_image(
params=SamplingParams(sampler=sampler_enum.value, steps=10),
prompt="A professional photograph of an astronaut riding a pig",
negative_prompt="",
samples=1,
return_latents=True,
)
assert isinstance(output, (tuple, list))
samples, samples_z = output
assert samples is not None
assert samples_z is not None
refiner_pipeline.refiner(
params=SamplingParams(sampler=sampler_enum.value, steps=10),
image=samples_z,
prompt="A professional photograph of an astronaut riding a pig",
negative_prompt="",
samples=1,
)