mirror of
https://github.com/Stability-AI/generative-models.git
synced 2025-12-22 15:54:21 +01:00
Add inference helpers & tests (#57)
* Add inference helpers & tests * Support testing with hatch * fixes to hatch script * add inference test action * change workflow trigger * widen trigger to test * revert changes to workflow triggers * Install local python in action * Trigger on push again * fix python version * add CODEOWNERS and change triggers * Report tests results * update action versions * format * Fix typo and add refiner helper * use a shared path loaded from a secret for checkpoints source * typo fix * Use device from input and remove duplicated code * PR feedback * fix call to load_model_from_config * Move model to gpu * Refactor helpers * cleanup * test refiner, prep for 1.0, align with metadata * fix paths on second load * deduplicate streamlit code * filenames * fixes * add pydantic to requirements * fix usage of `msg` in demo script * remove double text * run black * fix streamlit sampling when returning latents * extract function for streamlit output * another fix for streamlit outputs * fix img2img in streamlit * Make fp16 optional and fix device param * PR feedback * fix dict cast for dataclass * run black, update ci script * cache pip dependencies on hosted runners, remove extra runs * install package in ci env * fix cache path * PR cleanup * one more cleanup * don't cache, it filled up
This commit is contained in:
1
.github/workflows/CODEOWNERS
vendored
Normal file
1
.github/workflows/CODEOWNERS
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
.github @Stability-AI/infrastructure
|
||||||
2
.github/workflows/black.yml
vendored
2
.github/workflows/black.yml
vendored
@@ -1,5 +1,5 @@
|
|||||||
name: Run black
|
name: Run black
|
||||||
on: [push, pull_request]
|
on: [pull_request]
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
lint:
|
lint:
|
||||||
|
|||||||
1
.github/workflows/test-build.yaml
vendored
1
.github/workflows/test-build.yaml
vendored
@@ -2,6 +2,7 @@ name: Build package
|
|||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
|
branches: [ main ]
|
||||||
pull_request:
|
pull_request:
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
|
|||||||
34
.github/workflows/test-inference.yml
vendored
Normal file
34
.github/workflows/test-inference.yml
vendored
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
name: Test inference
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
name: "Test inference"
|
||||||
|
# This action is designed only to run on the Stability research cluster at this time, so many assumptions are made about the environment
|
||||||
|
if: github.repository == 'stability-ai/generative-models'
|
||||||
|
runs-on: [self-hosted, slurm, g40]
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v3
|
||||||
|
- name: "Symlink checkpoints"
|
||||||
|
run: ln -s ${{secrets.SGM_CHECKPOINTS_PATH}} checkpoints
|
||||||
|
- name: "Setup python"
|
||||||
|
uses: actions/setup-python@v4
|
||||||
|
with:
|
||||||
|
python-version: "3.10"
|
||||||
|
- name: "Install Hatch"
|
||||||
|
run: pip install hatch
|
||||||
|
- name: "Run inference tests"
|
||||||
|
run: hatch run ci:test-inference --junit-xml test-results.xml
|
||||||
|
- name: Surface failing tests
|
||||||
|
if: always()
|
||||||
|
uses: pmeier/pytest-results-action@main
|
||||||
|
with:
|
||||||
|
path: test-results.xml
|
||||||
|
summary: true
|
||||||
|
display-options: fEX
|
||||||
|
fail-on-empty: true
|
||||||
@@ -32,3 +32,17 @@ include = [
|
|||||||
|
|
||||||
[tool.hatch.build.targets.wheel.force-include]
|
[tool.hatch.build.targets.wheel.force-include]
|
||||||
"./configs" = "sgm/configs"
|
"./configs" = "sgm/configs"
|
||||||
|
|
||||||
|
[tool.hatch.envs.ci]
|
||||||
|
skip-install = false
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
"pytest"
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.hatch.envs.ci.scripts]
|
||||||
|
test-inference = [
|
||||||
|
"pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 --index-url https://download.pytorch.org/whl/cu118",
|
||||||
|
"pip install -r requirements/pt2.txt",
|
||||||
|
"pytest -v tests/inference/test_inference.py {args}",
|
||||||
|
]
|
||||||
3
pytest.ini
Normal file
3
pytest.ini
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
[pytest]
|
||||||
|
markers =
|
||||||
|
inference: mark as inference test (deselect with '-m "not inference"')
|
||||||
@@ -1,7 +1,14 @@
|
|||||||
|
import numpy as np
|
||||||
from pytorch_lightning import seed_everything
|
from pytorch_lightning import seed_everything
|
||||||
|
|
||||||
from scripts.demo.streamlit_helpers import *
|
from scripts.demo.streamlit_helpers import *
|
||||||
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
|
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
|
||||||
|
from sgm.inference.helpers import (
|
||||||
|
do_img2img,
|
||||||
|
do_sample,
|
||||||
|
get_unique_embedder_keys_from_conditioner,
|
||||||
|
perform_save_locally,
|
||||||
|
)
|
||||||
|
|
||||||
SAVE_PATH = "outputs/demo/txt2img/"
|
SAVE_PATH = "outputs/demo/txt2img/"
|
||||||
|
|
||||||
@@ -131,6 +138,8 @@ def run_txt2img(
|
|||||||
|
|
||||||
if st.button("Sample"):
|
if st.button("Sample"):
|
||||||
st.write(f"**Model I:** {version}")
|
st.write(f"**Model I:** {version}")
|
||||||
|
outputs = st.empty()
|
||||||
|
st.text("Sampling")
|
||||||
out = do_sample(
|
out = do_sample(
|
||||||
state["model"],
|
state["model"],
|
||||||
sampler,
|
sampler,
|
||||||
@@ -144,6 +153,8 @@ def run_txt2img(
|
|||||||
return_latents=return_latents,
|
return_latents=return_latents,
|
||||||
filter=filter,
|
filter=filter,
|
||||||
)
|
)
|
||||||
|
show_samples(out, outputs)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
@@ -175,6 +186,8 @@ def run_img2img(
|
|||||||
num_samples = num_rows * num_cols
|
num_samples = num_rows * num_cols
|
||||||
|
|
||||||
if st.button("Sample"):
|
if st.button("Sample"):
|
||||||
|
outputs = st.empty()
|
||||||
|
st.text("Sampling")
|
||||||
out = do_img2img(
|
out = do_img2img(
|
||||||
repeat(img, "1 ... -> n ...", n=num_samples),
|
repeat(img, "1 ... -> n ...", n=num_samples),
|
||||||
state["model"],
|
state["model"],
|
||||||
@@ -185,6 +198,7 @@ def run_img2img(
|
|||||||
return_latents=return_latents,
|
return_latents=return_latents,
|
||||||
filter=filter,
|
filter=filter,
|
||||||
)
|
)
|
||||||
|
show_samples(out, outputs)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
@@ -249,8 +263,6 @@ if __name__ == "__main__":
|
|||||||
save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))
|
save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))
|
||||||
|
|
||||||
state = init_st(version_dict)
|
state = init_st(version_dict)
|
||||||
if state["msg"]:
|
|
||||||
st.info(state["msg"])
|
|
||||||
model = state["model"]
|
model = state["model"]
|
||||||
|
|
||||||
is_legacy = version_dict["is_legacy"]
|
is_legacy = version_dict["is_legacy"]
|
||||||
@@ -275,7 +287,6 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
version_dict2 = VERSION2SPECS[version2]
|
version_dict2 = VERSION2SPECS[version2]
|
||||||
state2 = init_st(version_dict2)
|
state2 = init_st(version_dict2)
|
||||||
st.info(state2["msg"])
|
|
||||||
|
|
||||||
stage2strength = st.number_input(
|
stage2strength = st.number_input(
|
||||||
"**Refinement strength**", value=0.3, min_value=0.0, max_value=1.0
|
"**Refinement strength**", value=0.3, min_value=0.0, max_value=1.0
|
||||||
@@ -315,6 +326,7 @@ if __name__ == "__main__":
|
|||||||
samples_z = None
|
samples_z = None
|
||||||
|
|
||||||
if add_pipeline and samples_z is not None:
|
if add_pipeline and samples_z is not None:
|
||||||
|
outputs = st.empty()
|
||||||
st.write("**Running Refinement Stage**")
|
st.write("**Running Refinement Stage**")
|
||||||
samples = apply_refiner(
|
samples = apply_refiner(
|
||||||
samples_z,
|
samples_z,
|
||||||
@@ -325,6 +337,7 @@ if __name__ == "__main__":
|
|||||||
negative_prompt=negative_prompt if is_legacy else "",
|
negative_prompt=negative_prompt if is_legacy else "",
|
||||||
filter=filter,
|
filter=filter,
|
||||||
)
|
)
|
||||||
|
show_samples(samples, outputs)
|
||||||
|
|
||||||
if save_locally and samples is not None:
|
if save_locally and samples is not None:
|
||||||
perform_save_locally(save_path, samples)
|
perform_save_locally(save_path, samples)
|
||||||
|
|||||||
@@ -1,18 +1,11 @@
|
|||||||
import math
|
|
||||||
import os
|
import os
|
||||||
from typing import List, Union
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
import torch
|
import torch
|
||||||
from einops import rearrange, repeat
|
|
||||||
from imwatermark import WatermarkEncoder
|
|
||||||
from omegaconf import ListConfig, OmegaConf
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from safetensors.torch import load_file as load_safetensors
|
from einops import rearrange, repeat
|
||||||
from torch import autocast
|
from omegaconf import OmegaConf
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from torchvision.utils import make_grid
|
|
||||||
|
|
||||||
from sgm.modules.diffusionmodules.sampling import (
|
from sgm.modules.diffusionmodules.sampling import (
|
||||||
DPMPP2MSampler,
|
DPMPP2MSampler,
|
||||||
@@ -22,52 +15,8 @@ from sgm.modules.diffusionmodules.sampling import (
|
|||||||
HeunEDMSampler,
|
HeunEDMSampler,
|
||||||
LinearMultistepSampler,
|
LinearMultistepSampler,
|
||||||
)
|
)
|
||||||
from sgm.util import append_dims, instantiate_from_config
|
from sgm.inference.helpers import Img2ImgDiscretizationWrapper, embed_watermark
|
||||||
|
from sgm.util import load_model_from_config
|
||||||
|
|
||||||
class WatermarkEmbedder:
|
|
||||||
def __init__(self, watermark):
|
|
||||||
self.watermark = watermark
|
|
||||||
self.num_bits = len(WATERMARK_BITS)
|
|
||||||
self.encoder = WatermarkEncoder()
|
|
||||||
self.encoder.set_watermark("bits", self.watermark)
|
|
||||||
|
|
||||||
def __call__(self, image: torch.Tensor):
|
|
||||||
"""
|
|
||||||
Adds a predefined watermark to the input image
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image: ([N,] B, C, H, W) in range [0, 1]
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
same as input but watermarked
|
|
||||||
"""
|
|
||||||
# watermarking libary expects input as cv2 BGR format
|
|
||||||
squeeze = len(image.shape) == 4
|
|
||||||
if squeeze:
|
|
||||||
image = image[None, ...]
|
|
||||||
n = image.shape[0]
|
|
||||||
image_np = rearrange(
|
|
||||||
(255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
|
|
||||||
).numpy()[:, :, :, ::-1]
|
|
||||||
# torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
|
|
||||||
for k in range(image_np.shape[0]):
|
|
||||||
image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
|
|
||||||
image = torch.from_numpy(
|
|
||||||
rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)
|
|
||||||
).to(image.device)
|
|
||||||
image = torch.clamp(image / 255, min=0.0, max=1.0)
|
|
||||||
if squeeze:
|
|
||||||
image = image[0]
|
|
||||||
return image
|
|
||||||
|
|
||||||
|
|
||||||
# A fixed 48-bit message that was choosen at random
|
|
||||||
# WATERMARK_MESSAGE = 0xB3EC907BB19E
|
|
||||||
WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
|
|
||||||
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
|
|
||||||
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
|
|
||||||
embed_watemark = WatermarkEmbedder(WATERMARK_BITS)
|
|
||||||
|
|
||||||
|
|
||||||
@st.cache_resource()
|
@st.cache_resource()
|
||||||
@@ -78,54 +27,17 @@ def init_st(version_dict, load_ckpt=True):
|
|||||||
ckpt = version_dict["ckpt"]
|
ckpt = version_dict["ckpt"]
|
||||||
|
|
||||||
config = OmegaConf.load(config)
|
config = OmegaConf.load(config)
|
||||||
model, msg = load_model_from_config(config, ckpt if load_ckpt else None)
|
model = load_model_from_config(config, ckpt if load_ckpt else None)
|
||||||
|
model = model.to("cuda")
|
||||||
|
model.conditioner.half()
|
||||||
|
model.model.half()
|
||||||
|
|
||||||
state["msg"] = msg
|
|
||||||
state["model"] = model
|
state["model"] = model
|
||||||
state["ckpt"] = ckpt if load_ckpt else None
|
state["ckpt"] = ckpt if load_ckpt else None
|
||||||
state["config"] = config
|
state["config"] = config
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
||||||
def load_model_from_config(config, ckpt=None, verbose=True):
|
|
||||||
model = instantiate_from_config(config.model)
|
|
||||||
|
|
||||||
if ckpt is not None:
|
|
||||||
print(f"Loading model from {ckpt}")
|
|
||||||
if ckpt.endswith("ckpt"):
|
|
||||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
|
||||||
if "global_step" in pl_sd:
|
|
||||||
global_step = pl_sd["global_step"]
|
|
||||||
st.info(f"loaded ckpt from global step {global_step}")
|
|
||||||
print(f"Global Step: {pl_sd['global_step']}")
|
|
||||||
sd = pl_sd["state_dict"]
|
|
||||||
elif ckpt.endswith("safetensors"):
|
|
||||||
sd = load_safetensors(ckpt)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
msg = None
|
|
||||||
|
|
||||||
m, u = model.load_state_dict(sd, strict=False)
|
|
||||||
|
|
||||||
if len(m) > 0 and verbose:
|
|
||||||
print("missing keys:")
|
|
||||||
print(m)
|
|
||||||
if len(u) > 0 and verbose:
|
|
||||||
print("unexpected keys:")
|
|
||||||
print(u)
|
|
||||||
else:
|
|
||||||
msg = None
|
|
||||||
|
|
||||||
model.cuda()
|
|
||||||
model.eval()
|
|
||||||
return model, msg
|
|
||||||
|
|
||||||
|
|
||||||
def get_unique_embedder_keys_from_conditioner(conditioner):
|
|
||||||
return list(set([x.input_key for x in conditioner.embedders]))
|
|
||||||
|
|
||||||
|
|
||||||
def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
|
def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
|
||||||
# Hardcoded demo settings; might undergo some changes in the future
|
# Hardcoded demo settings; might undergo some changes in the future
|
||||||
|
|
||||||
@@ -186,18 +98,6 @@ def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
|
|||||||
return value_dict
|
return value_dict
|
||||||
|
|
||||||
|
|
||||||
def perform_save_locally(save_path, samples):
|
|
||||||
os.makedirs(os.path.join(save_path), exist_ok=True)
|
|
||||||
base_count = len(os.listdir(os.path.join(save_path)))
|
|
||||||
samples = embed_watemark(samples)
|
|
||||||
for sample in samples:
|
|
||||||
sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
|
|
||||||
Image.fromarray(sample.astype(np.uint8)).save(
|
|
||||||
os.path.join(save_path, f"{base_count:09}.png")
|
|
||||||
)
|
|
||||||
base_count += 1
|
|
||||||
|
|
||||||
|
|
||||||
def init_save_locally(_dir, init_value: bool = False):
|
def init_save_locally(_dir, init_value: bool = False):
|
||||||
save_locally = st.sidebar.checkbox("Save images locally", value=init_value)
|
save_locally = st.sidebar.checkbox("Save images locally", value=init_value)
|
||||||
if save_locally:
|
if save_locally:
|
||||||
@@ -208,28 +108,12 @@ def init_save_locally(_dir, init_value: bool = False):
|
|||||||
return save_locally, save_path
|
return save_locally, save_path
|
||||||
|
|
||||||
|
|
||||||
class Img2ImgDiscretizationWrapper:
|
def show_samples(samples, outputs):
|
||||||
"""
|
if isinstance(samples, tuple):
|
||||||
wraps a discretizer, and prunes the sigmas
|
samples, _ = samples
|
||||||
params:
|
grid = embed_watermark(torch.stack([samples]))
|
||||||
strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
|
grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
|
||||||
"""
|
outputs.image(grid.cpu().numpy())
|
||||||
|
|
||||||
def __init__(self, discretization, strength: float = 1.0):
|
|
||||||
self.discretization = discretization
|
|
||||||
self.strength = strength
|
|
||||||
assert 0.0 <= self.strength <= 1.0
|
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
|
||||||
# sigmas start large first, and decrease then
|
|
||||||
sigmas = self.discretization(*args, **kwargs)
|
|
||||||
print(f"sigmas after discretization, before pruning img2img: ", sigmas)
|
|
||||||
sigmas = torch.flip(sigmas, (0,))
|
|
||||||
sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
|
|
||||||
print("prune index:", max(int(self.strength * len(sigmas)), 1))
|
|
||||||
sigmas = torch.flip(sigmas, (0,))
|
|
||||||
print(f"sigmas after pruning: ", sigmas)
|
|
||||||
return sigmas
|
|
||||||
|
|
||||||
|
|
||||||
def get_guider(key):
|
def get_guider(key):
|
||||||
@@ -452,214 +336,3 @@ def get_init_img(batch_size=1, key=None):
|
|||||||
init_image = load_img(key=key).cuda()
|
init_image = load_img(key=key).cuda()
|
||||||
init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
|
init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
|
||||||
return init_image
|
return init_image
|
||||||
|
|
||||||
|
|
||||||
def do_sample(
|
|
||||||
model,
|
|
||||||
sampler,
|
|
||||||
value_dict,
|
|
||||||
num_samples,
|
|
||||||
H,
|
|
||||||
W,
|
|
||||||
C,
|
|
||||||
F,
|
|
||||||
force_uc_zero_embeddings: List = None,
|
|
||||||
batch2model_input: List = None,
|
|
||||||
return_latents=False,
|
|
||||||
filter=None,
|
|
||||||
):
|
|
||||||
if force_uc_zero_embeddings is None:
|
|
||||||
force_uc_zero_embeddings = []
|
|
||||||
if batch2model_input is None:
|
|
||||||
batch2model_input = []
|
|
||||||
|
|
||||||
st.text("Sampling")
|
|
||||||
|
|
||||||
outputs = st.empty()
|
|
||||||
precision_scope = autocast
|
|
||||||
with torch.no_grad():
|
|
||||||
with precision_scope("cuda"):
|
|
||||||
with model.ema_scope():
|
|
||||||
num_samples = [num_samples]
|
|
||||||
batch, batch_uc = get_batch(
|
|
||||||
get_unique_embedder_keys_from_conditioner(model.conditioner),
|
|
||||||
value_dict,
|
|
||||||
num_samples,
|
|
||||||
)
|
|
||||||
for key in batch:
|
|
||||||
if isinstance(batch[key], torch.Tensor):
|
|
||||||
print(key, batch[key].shape)
|
|
||||||
elif isinstance(batch[key], list):
|
|
||||||
print(key, [len(l) for l in batch[key]])
|
|
||||||
else:
|
|
||||||
print(key, batch[key])
|
|
||||||
c, uc = model.conditioner.get_unconditional_conditioning(
|
|
||||||
batch,
|
|
||||||
batch_uc=batch_uc,
|
|
||||||
force_uc_zero_embeddings=force_uc_zero_embeddings,
|
|
||||||
)
|
|
||||||
|
|
||||||
for k in c:
|
|
||||||
if not k == "crossattn":
|
|
||||||
c[k], uc[k] = map(
|
|
||||||
lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc)
|
|
||||||
)
|
|
||||||
|
|
||||||
additional_model_inputs = {}
|
|
||||||
for k in batch2model_input:
|
|
||||||
additional_model_inputs[k] = batch[k]
|
|
||||||
|
|
||||||
shape = (math.prod(num_samples), C, H // F, W // F)
|
|
||||||
randn = torch.randn(shape).to("cuda")
|
|
||||||
|
|
||||||
def denoiser(input, sigma, c):
|
|
||||||
return model.denoiser(
|
|
||||||
model.model, input, sigma, c, **additional_model_inputs
|
|
||||||
)
|
|
||||||
|
|
||||||
samples_z = sampler(denoiser, randn, cond=c, uc=uc)
|
|
||||||
samples_x = model.decode_first_stage(samples_z)
|
|
||||||
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
|
|
||||||
|
|
||||||
if filter is not None:
|
|
||||||
samples = filter(samples)
|
|
||||||
|
|
||||||
grid = torch.stack([samples])
|
|
||||||
grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
|
|
||||||
outputs.image(grid.cpu().numpy())
|
|
||||||
|
|
||||||
if return_latents:
|
|
||||||
return samples, samples_z
|
|
||||||
return samples
|
|
||||||
|
|
||||||
|
|
||||||
def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
|
|
||||||
# Hardcoded demo setups; might undergo some changes in the future
|
|
||||||
|
|
||||||
batch = {}
|
|
||||||
batch_uc = {}
|
|
||||||
|
|
||||||
for key in keys:
|
|
||||||
if key == "txt":
|
|
||||||
batch["txt"] = (
|
|
||||||
np.repeat([value_dict["prompt"]], repeats=math.prod(N))
|
|
||||||
.reshape(N)
|
|
||||||
.tolist()
|
|
||||||
)
|
|
||||||
batch_uc["txt"] = (
|
|
||||||
np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
|
|
||||||
.reshape(N)
|
|
||||||
.tolist()
|
|
||||||
)
|
|
||||||
elif key == "original_size_as_tuple":
|
|
||||||
batch["original_size_as_tuple"] = (
|
|
||||||
torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
|
|
||||||
.to(device)
|
|
||||||
.repeat(*N, 1)
|
|
||||||
)
|
|
||||||
elif key == "crop_coords_top_left":
|
|
||||||
batch["crop_coords_top_left"] = (
|
|
||||||
torch.tensor(
|
|
||||||
[value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
|
|
||||||
)
|
|
||||||
.to(device)
|
|
||||||
.repeat(*N, 1)
|
|
||||||
)
|
|
||||||
elif key == "aesthetic_score":
|
|
||||||
batch["aesthetic_score"] = (
|
|
||||||
torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
|
|
||||||
)
|
|
||||||
batch_uc["aesthetic_score"] = (
|
|
||||||
torch.tensor([value_dict["negative_aesthetic_score"]])
|
|
||||||
.to(device)
|
|
||||||
.repeat(*N, 1)
|
|
||||||
)
|
|
||||||
|
|
||||||
elif key == "target_size_as_tuple":
|
|
||||||
batch["target_size_as_tuple"] = (
|
|
||||||
torch.tensor([value_dict["target_height"], value_dict["target_width"]])
|
|
||||||
.to(device)
|
|
||||||
.repeat(*N, 1)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
batch[key] = value_dict[key]
|
|
||||||
|
|
||||||
for key in batch.keys():
|
|
||||||
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
|
|
||||||
batch_uc[key] = torch.clone(batch[key])
|
|
||||||
return batch, batch_uc
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def do_img2img(
|
|
||||||
img,
|
|
||||||
model,
|
|
||||||
sampler,
|
|
||||||
value_dict,
|
|
||||||
num_samples,
|
|
||||||
force_uc_zero_embeddings=[],
|
|
||||||
additional_kwargs={},
|
|
||||||
offset_noise_level: int = 0.0,
|
|
||||||
return_latents=False,
|
|
||||||
skip_encode=False,
|
|
||||||
filter=None,
|
|
||||||
):
|
|
||||||
st.text("Sampling")
|
|
||||||
|
|
||||||
outputs = st.empty()
|
|
||||||
precision_scope = autocast
|
|
||||||
with torch.no_grad():
|
|
||||||
with precision_scope("cuda"):
|
|
||||||
with model.ema_scope():
|
|
||||||
batch, batch_uc = get_batch(
|
|
||||||
get_unique_embedder_keys_from_conditioner(model.conditioner),
|
|
||||||
value_dict,
|
|
||||||
[num_samples],
|
|
||||||
)
|
|
||||||
c, uc = model.conditioner.get_unconditional_conditioning(
|
|
||||||
batch,
|
|
||||||
batch_uc=batch_uc,
|
|
||||||
force_uc_zero_embeddings=force_uc_zero_embeddings,
|
|
||||||
)
|
|
||||||
|
|
||||||
for k in c:
|
|
||||||
c[k], uc[k] = map(lambda y: y[k][:num_samples].to("cuda"), (c, uc))
|
|
||||||
|
|
||||||
for k in additional_kwargs:
|
|
||||||
c[k] = uc[k] = additional_kwargs[k]
|
|
||||||
if skip_encode:
|
|
||||||
z = img
|
|
||||||
else:
|
|
||||||
z = model.encode_first_stage(img)
|
|
||||||
noise = torch.randn_like(z)
|
|
||||||
sigmas = sampler.discretization(sampler.num_steps)
|
|
||||||
sigma = sigmas[0]
|
|
||||||
|
|
||||||
st.info(f"all sigmas: {sigmas}")
|
|
||||||
st.info(f"noising sigma: {sigma}")
|
|
||||||
|
|
||||||
if offset_noise_level > 0.0:
|
|
||||||
noise = noise + offset_noise_level * append_dims(
|
|
||||||
torch.randn(z.shape[0], device=z.device), z.ndim
|
|
||||||
)
|
|
||||||
noised_z = z + noise * append_dims(sigma, z.ndim)
|
|
||||||
noised_z = noised_z / torch.sqrt(
|
|
||||||
1.0 + sigmas[0] ** 2.0
|
|
||||||
) # Note: hardcoded to DDPM-like scaling. need to generalize later.
|
|
||||||
|
|
||||||
def denoiser(x, sigma, c):
|
|
||||||
return model.denoiser(model.model, x, sigma, c)
|
|
||||||
|
|
||||||
samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
|
|
||||||
samples_x = model.decode_first_stage(samples_z)
|
|
||||||
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
|
|
||||||
|
|
||||||
if filter is not None:
|
|
||||||
samples = filter(samples)
|
|
||||||
|
|
||||||
grid = embed_watemark(torch.stack([samples]))
|
|
||||||
grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
|
|
||||||
outputs.image(grid.cpu().numpy())
|
|
||||||
if return_latents:
|
|
||||||
return samples, samples_z
|
|
||||||
return samples
|
|
||||||
|
|||||||
388
sgm/inference/api.py
Normal file
388
sgm/inference/api.py
Normal file
@@ -0,0 +1,388 @@
|
|||||||
|
from dataclasses import dataclass, asdict
|
||||||
|
from enum import Enum
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
import pathlib
|
||||||
|
from sgm.inference.helpers import (
|
||||||
|
do_sample,
|
||||||
|
do_img2img,
|
||||||
|
Img2ImgDiscretizationWrapper,
|
||||||
|
)
|
||||||
|
from sgm.modules.diffusionmodules.sampling import (
|
||||||
|
EulerEDMSampler,
|
||||||
|
HeunEDMSampler,
|
||||||
|
EulerAncestralSampler,
|
||||||
|
DPMPP2SAncestralSampler,
|
||||||
|
DPMPP2MSampler,
|
||||||
|
LinearMultistepSampler,
|
||||||
|
)
|
||||||
|
from sgm.util import load_model_from_config
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class ModelArchitecture(str, Enum):
|
||||||
|
SD_2_1 = "stable-diffusion-v2-1"
|
||||||
|
SD_2_1_768 = "stable-diffusion-v2-1-768"
|
||||||
|
SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base"
|
||||||
|
SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner"
|
||||||
|
SDXL_V1_BASE = "stable-diffusion-xl-v1-base"
|
||||||
|
SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner"
|
||||||
|
|
||||||
|
|
||||||
|
class Sampler(str, Enum):
|
||||||
|
EULER_EDM = "EulerEDMSampler"
|
||||||
|
HEUN_EDM = "HeunEDMSampler"
|
||||||
|
EULER_ANCESTRAL = "EulerAncestralSampler"
|
||||||
|
DPMPP2S_ANCESTRAL = "DPMPP2SAncestralSampler"
|
||||||
|
DPMPP2M = "DPMPP2MSampler"
|
||||||
|
LINEAR_MULTISTEP = "LinearMultistepSampler"
|
||||||
|
|
||||||
|
|
||||||
|
class Discretization(str, Enum):
|
||||||
|
LEGACY_DDPM = "LegacyDDPMDiscretization"
|
||||||
|
EDM = "EDMDiscretization"
|
||||||
|
|
||||||
|
|
||||||
|
class Guider(str, Enum):
|
||||||
|
VANILLA = "VanillaCFG"
|
||||||
|
IDENTITY = "IdentityGuider"
|
||||||
|
|
||||||
|
|
||||||
|
class Thresholder(str, Enum):
|
||||||
|
NONE = "None"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SamplingParams:
|
||||||
|
width: int = 1024
|
||||||
|
height: int = 1024
|
||||||
|
steps: int = 50
|
||||||
|
sampler: Sampler = Sampler.DPMPP2M
|
||||||
|
discretization: Discretization = Discretization.LEGACY_DDPM
|
||||||
|
guider: Guider = Guider.VANILLA
|
||||||
|
thresholder: Thresholder = Thresholder.NONE
|
||||||
|
scale: float = 6.0
|
||||||
|
aesthetic_score: float = 5.0
|
||||||
|
negative_aesthetic_score: float = 5.0
|
||||||
|
img2img_strength: float = 1.0
|
||||||
|
orig_width: int = 1024
|
||||||
|
orig_height: int = 1024
|
||||||
|
crop_coords_top: int = 0
|
||||||
|
crop_coords_left: int = 0
|
||||||
|
sigma_min: float = 0.0292
|
||||||
|
sigma_max: float = 14.6146
|
||||||
|
rho: float = 3.0
|
||||||
|
s_churn: float = 0.0
|
||||||
|
s_tmin: float = 0.0
|
||||||
|
s_tmax: float = 999.0
|
||||||
|
s_noise: float = 1.0
|
||||||
|
eta: float = 1.0
|
||||||
|
order: int = 4
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SamplingSpec:
|
||||||
|
width: int
|
||||||
|
height: int
|
||||||
|
channels: int
|
||||||
|
factor: int
|
||||||
|
is_legacy: bool
|
||||||
|
config: str
|
||||||
|
ckpt: str
|
||||||
|
is_guided: bool
|
||||||
|
|
||||||
|
|
||||||
|
model_specs = {
|
||||||
|
ModelArchitecture.SD_2_1: SamplingSpec(
|
||||||
|
height=512,
|
||||||
|
width=512,
|
||||||
|
channels=4,
|
||||||
|
factor=8,
|
||||||
|
is_legacy=True,
|
||||||
|
config="sd_2_1.yaml",
|
||||||
|
ckpt="v2-1_512-ema-pruned.safetensors",
|
||||||
|
is_guided=True,
|
||||||
|
),
|
||||||
|
ModelArchitecture.SD_2_1_768: SamplingSpec(
|
||||||
|
height=768,
|
||||||
|
width=768,
|
||||||
|
channels=4,
|
||||||
|
factor=8,
|
||||||
|
is_legacy=True,
|
||||||
|
config="sd_2_1_768.yaml",
|
||||||
|
ckpt="v2-1_768-ema-pruned.safetensors",
|
||||||
|
is_guided=True,
|
||||||
|
),
|
||||||
|
ModelArchitecture.SDXL_V0_9_BASE: SamplingSpec(
|
||||||
|
height=1024,
|
||||||
|
width=1024,
|
||||||
|
channels=4,
|
||||||
|
factor=8,
|
||||||
|
is_legacy=False,
|
||||||
|
config="sd_xl_base.yaml",
|
||||||
|
ckpt="sd_xl_base_0.9.safetensors",
|
||||||
|
is_guided=True,
|
||||||
|
),
|
||||||
|
ModelArchitecture.SDXL_V0_9_REFINER: SamplingSpec(
|
||||||
|
height=1024,
|
||||||
|
width=1024,
|
||||||
|
channels=4,
|
||||||
|
factor=8,
|
||||||
|
is_legacy=True,
|
||||||
|
config="sd_xl_refiner.yaml",
|
||||||
|
ckpt="sd_xl_refiner_0.9.safetensors",
|
||||||
|
is_guided=True,
|
||||||
|
),
|
||||||
|
ModelArchitecture.SDXL_V1_BASE: SamplingSpec(
|
||||||
|
height=1024,
|
||||||
|
width=1024,
|
||||||
|
channels=4,
|
||||||
|
factor=8,
|
||||||
|
is_legacy=False,
|
||||||
|
config="sd_xl_base.yaml",
|
||||||
|
ckpt="sd_xl_base_1.0-metadata.safetensors",
|
||||||
|
is_guided=True,
|
||||||
|
),
|
||||||
|
ModelArchitecture.SDXL_V1_REFINER: SamplingSpec(
|
||||||
|
height=1024,
|
||||||
|
width=1024,
|
||||||
|
channels=4,
|
||||||
|
factor=8,
|
||||||
|
is_legacy=True,
|
||||||
|
config="sd_xl_refiner.yaml",
|
||||||
|
ckpt="sd_xl_refiner_1.0-metadata.safetensors",
|
||||||
|
is_guided=True,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class SamplingPipeline:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_id: ModelArchitecture,
|
||||||
|
model_path="checkpoints",
|
||||||
|
config_path="configs/inference",
|
||||||
|
device="cuda",
|
||||||
|
use_fp16=True,
|
||||||
|
) -> None:
|
||||||
|
if model_id not in model_specs:
|
||||||
|
raise ValueError(f"Model {model_id} not supported")
|
||||||
|
self.model_id = model_id
|
||||||
|
self.specs = model_specs[self.model_id]
|
||||||
|
self.config = str(pathlib.Path(config_path, self.specs.config))
|
||||||
|
self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt))
|
||||||
|
self.device = device
|
||||||
|
self.model = self._load_model(device=device, use_fp16=use_fp16)
|
||||||
|
|
||||||
|
def _load_model(self, device="cuda", use_fp16=True):
|
||||||
|
config = OmegaConf.load(self.config)
|
||||||
|
model = load_model_from_config(config, self.ckpt)
|
||||||
|
if model is None:
|
||||||
|
raise ValueError(f"Model {self.model_id} could not be loaded")
|
||||||
|
model.to(device)
|
||||||
|
if use_fp16:
|
||||||
|
model.conditioner.half()
|
||||||
|
model.model.half()
|
||||||
|
return model
|
||||||
|
|
||||||
|
def text_to_image(
|
||||||
|
self,
|
||||||
|
params: SamplingParams,
|
||||||
|
prompt: str,
|
||||||
|
negative_prompt: str = "",
|
||||||
|
samples: int = 1,
|
||||||
|
return_latents: bool = False,
|
||||||
|
):
|
||||||
|
sampler = get_sampler_config(params)
|
||||||
|
value_dict = asdict(params)
|
||||||
|
value_dict["prompt"] = prompt
|
||||||
|
value_dict["negative_prompt"] = negative_prompt
|
||||||
|
value_dict["target_width"] = params.width
|
||||||
|
value_dict["target_height"] = params.height
|
||||||
|
return do_sample(
|
||||||
|
self.model,
|
||||||
|
sampler,
|
||||||
|
value_dict,
|
||||||
|
samples,
|
||||||
|
params.height,
|
||||||
|
params.width,
|
||||||
|
self.specs.channels,
|
||||||
|
self.specs.factor,
|
||||||
|
force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
|
||||||
|
return_latents=return_latents,
|
||||||
|
filter=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
def image_to_image(
|
||||||
|
self,
|
||||||
|
params: SamplingParams,
|
||||||
|
image,
|
||||||
|
prompt: str,
|
||||||
|
negative_prompt: str = "",
|
||||||
|
samples: int = 1,
|
||||||
|
return_latents: bool = False,
|
||||||
|
):
|
||||||
|
sampler = get_sampler_config(params)
|
||||||
|
|
||||||
|
if params.img2img_strength < 1.0:
|
||||||
|
sampler.discretization = Img2ImgDiscretizationWrapper(
|
||||||
|
sampler.discretization,
|
||||||
|
strength=params.img2img_strength,
|
||||||
|
)
|
||||||
|
height, width = image.shape[2], image.shape[3]
|
||||||
|
value_dict = asdict(params)
|
||||||
|
value_dict["prompt"] = prompt
|
||||||
|
value_dict["negative_prompt"] = negative_prompt
|
||||||
|
value_dict["target_width"] = width
|
||||||
|
value_dict["target_height"] = height
|
||||||
|
return do_img2img(
|
||||||
|
image,
|
||||||
|
self.model,
|
||||||
|
sampler,
|
||||||
|
value_dict,
|
||||||
|
samples,
|
||||||
|
force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
|
||||||
|
return_latents=return_latents,
|
||||||
|
filter=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
def refiner(
|
||||||
|
self,
|
||||||
|
params: SamplingParams,
|
||||||
|
image,
|
||||||
|
prompt: str,
|
||||||
|
negative_prompt: Optional[str] = None,
|
||||||
|
samples: int = 1,
|
||||||
|
return_latents: bool = False,
|
||||||
|
):
|
||||||
|
sampler = get_sampler_config(params)
|
||||||
|
value_dict = {
|
||||||
|
"orig_width": image.shape[3] * 8,
|
||||||
|
"orig_height": image.shape[2] * 8,
|
||||||
|
"target_width": image.shape[3] * 8,
|
||||||
|
"target_height": image.shape[2] * 8,
|
||||||
|
"prompt": prompt,
|
||||||
|
"negative_prompt": negative_prompt,
|
||||||
|
"crop_coords_top": 0,
|
||||||
|
"crop_coords_left": 0,
|
||||||
|
"aesthetic_score": 6.0,
|
||||||
|
"negative_aesthetic_score": 2.5,
|
||||||
|
}
|
||||||
|
|
||||||
|
return do_img2img(
|
||||||
|
image,
|
||||||
|
self.model,
|
||||||
|
sampler,
|
||||||
|
value_dict,
|
||||||
|
samples,
|
||||||
|
skip_encode=True,
|
||||||
|
return_latents=return_latents,
|
||||||
|
filter=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_guider_config(params: SamplingParams):
|
||||||
|
if params.guider == Guider.IDENTITY:
|
||||||
|
guider_config = {
|
||||||
|
"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
|
||||||
|
}
|
||||||
|
elif params.guider == Guider.VANILLA:
|
||||||
|
scale = params.scale
|
||||||
|
|
||||||
|
thresholder = params.thresholder
|
||||||
|
|
||||||
|
if thresholder == Thresholder.NONE:
|
||||||
|
dyn_thresh_config = {
|
||||||
|
"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
guider_config = {
|
||||||
|
"target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
|
||||||
|
"params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
return guider_config
|
||||||
|
|
||||||
|
|
||||||
|
def get_discretization_config(params: SamplingParams):
|
||||||
|
if params.discretization == Discretization.LEGACY_DDPM:
|
||||||
|
discretization_config = {
|
||||||
|
"target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
|
||||||
|
}
|
||||||
|
elif params.discretization == Discretization.EDM:
|
||||||
|
discretization_config = {
|
||||||
|
"target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
|
||||||
|
"params": {
|
||||||
|
"sigma_min": params.sigma_min,
|
||||||
|
"sigma_max": params.sigma_max,
|
||||||
|
"rho": params.rho,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise ValueError(f"unknown discretization {params.discretization}")
|
||||||
|
return discretization_config
|
||||||
|
|
||||||
|
|
||||||
|
def get_sampler_config(params: SamplingParams):
|
||||||
|
discretization_config = get_discretization_config(params)
|
||||||
|
guider_config = get_guider_config(params)
|
||||||
|
sampler = None
|
||||||
|
if params.sampler == Sampler.EULER_EDM:
|
||||||
|
return EulerEDMSampler(
|
||||||
|
num_steps=params.steps,
|
||||||
|
discretization_config=discretization_config,
|
||||||
|
guider_config=guider_config,
|
||||||
|
s_churn=params.s_churn,
|
||||||
|
s_tmin=params.s_tmin,
|
||||||
|
s_tmax=params.s_tmax,
|
||||||
|
s_noise=params.s_noise,
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
if params.sampler == Sampler.HEUN_EDM:
|
||||||
|
return HeunEDMSampler(
|
||||||
|
num_steps=params.steps,
|
||||||
|
discretization_config=discretization_config,
|
||||||
|
guider_config=guider_config,
|
||||||
|
s_churn=params.s_churn,
|
||||||
|
s_tmin=params.s_tmin,
|
||||||
|
s_tmax=params.s_tmax,
|
||||||
|
s_noise=params.s_noise,
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
if params.sampler == Sampler.EULER_ANCESTRAL:
|
||||||
|
return EulerAncestralSampler(
|
||||||
|
num_steps=params.steps,
|
||||||
|
discretization_config=discretization_config,
|
||||||
|
guider_config=guider_config,
|
||||||
|
eta=params.eta,
|
||||||
|
s_noise=params.s_noise,
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
if params.sampler == Sampler.DPMPP2S_ANCESTRAL:
|
||||||
|
return DPMPP2SAncestralSampler(
|
||||||
|
num_steps=params.steps,
|
||||||
|
discretization_config=discretization_config,
|
||||||
|
guider_config=guider_config,
|
||||||
|
eta=params.eta,
|
||||||
|
s_noise=params.s_noise,
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
if params.sampler == Sampler.DPMPP2M:
|
||||||
|
return DPMPP2MSampler(
|
||||||
|
num_steps=params.steps,
|
||||||
|
discretization_config=discretization_config,
|
||||||
|
guider_config=guider_config,
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
if params.sampler == Sampler.LINEAR_MULTISTEP:
|
||||||
|
return LinearMultistepSampler(
|
||||||
|
num_steps=params.steps,
|
||||||
|
discretization_config=discretization_config,
|
||||||
|
guider_config=guider_config,
|
||||||
|
order=params.order,
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
raise ValueError(f"unknown sampler {params.sampler}!")
|
||||||
305
sgm/inference/helpers.py
Normal file
305
sgm/inference/helpers.py
Normal file
@@ -0,0 +1,305 @@
|
|||||||
|
import os
|
||||||
|
from typing import Union, List, Optional
|
||||||
|
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from einops import rearrange
|
||||||
|
from imwatermark import WatermarkEncoder
|
||||||
|
from omegaconf import ListConfig
|
||||||
|
from torch import autocast
|
||||||
|
|
||||||
|
from sgm.util import append_dims
|
||||||
|
|
||||||
|
|
||||||
|
class WatermarkEmbedder:
|
||||||
|
def __init__(self, watermark):
|
||||||
|
self.watermark = watermark
|
||||||
|
self.num_bits = len(WATERMARK_BITS)
|
||||||
|
self.encoder = WatermarkEncoder()
|
||||||
|
self.encoder.set_watermark("bits", self.watermark)
|
||||||
|
|
||||||
|
def __call__(self, image: torch.Tensor):
|
||||||
|
"""
|
||||||
|
Adds a predefined watermark to the input image
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: ([N,] B, C, H, W) in range [0, 1]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
same as input but watermarked
|
||||||
|
"""
|
||||||
|
# watermarking libary expects input as cv2 BGR format
|
||||||
|
squeeze = len(image.shape) == 4
|
||||||
|
if squeeze:
|
||||||
|
image = image[None, ...]
|
||||||
|
n = image.shape[0]
|
||||||
|
image_np = rearrange(
|
||||||
|
(255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
|
||||||
|
).numpy()[:, :, :, ::-1]
|
||||||
|
# torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
|
||||||
|
for k in range(image_np.shape[0]):
|
||||||
|
image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
|
||||||
|
image = torch.from_numpy(
|
||||||
|
rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)
|
||||||
|
).to(image.device)
|
||||||
|
image = torch.clamp(image / 255, min=0.0, max=1.0)
|
||||||
|
if squeeze:
|
||||||
|
image = image[0]
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
# A fixed 48-bit message that was choosen at random
|
||||||
|
# WATERMARK_MESSAGE = 0xB3EC907BB19E
|
||||||
|
WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
|
||||||
|
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
|
||||||
|
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
|
||||||
|
embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
|
||||||
|
|
||||||
|
|
||||||
|
def get_unique_embedder_keys_from_conditioner(conditioner):
|
||||||
|
return list({x.input_key for x in conditioner.embedders})
|
||||||
|
|
||||||
|
|
||||||
|
def perform_save_locally(save_path, samples):
|
||||||
|
os.makedirs(os.path.join(save_path), exist_ok=True)
|
||||||
|
base_count = len(os.listdir(os.path.join(save_path)))
|
||||||
|
samples = embed_watermark(samples)
|
||||||
|
for sample in samples:
|
||||||
|
sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
|
||||||
|
Image.fromarray(sample.astype(np.uint8)).save(
|
||||||
|
os.path.join(save_path, f"{base_count:09}.png")
|
||||||
|
)
|
||||||
|
base_count += 1
|
||||||
|
|
||||||
|
|
||||||
|
class Img2ImgDiscretizationWrapper:
|
||||||
|
"""
|
||||||
|
wraps a discretizer, and prunes the sigmas
|
||||||
|
params:
|
||||||
|
strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, discretization, strength: float = 1.0):
|
||||||
|
self.discretization = discretization
|
||||||
|
self.strength = strength
|
||||||
|
assert 0.0 <= self.strength <= 1.0
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
# sigmas start large first, and decrease then
|
||||||
|
sigmas = self.discretization(*args, **kwargs)
|
||||||
|
print(f"sigmas after discretization, before pruning img2img: ", sigmas)
|
||||||
|
sigmas = torch.flip(sigmas, (0,))
|
||||||
|
sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
|
||||||
|
print("prune index:", max(int(self.strength * len(sigmas)), 1))
|
||||||
|
sigmas = torch.flip(sigmas, (0,))
|
||||||
|
print(f"sigmas after pruning: ", sigmas)
|
||||||
|
return sigmas
|
||||||
|
|
||||||
|
|
||||||
|
def do_sample(
|
||||||
|
model,
|
||||||
|
sampler,
|
||||||
|
value_dict,
|
||||||
|
num_samples,
|
||||||
|
H,
|
||||||
|
W,
|
||||||
|
C,
|
||||||
|
F,
|
||||||
|
force_uc_zero_embeddings: Optional[List] = None,
|
||||||
|
batch2model_input: Optional[List] = None,
|
||||||
|
return_latents=False,
|
||||||
|
filter=None,
|
||||||
|
device="cuda",
|
||||||
|
):
|
||||||
|
if force_uc_zero_embeddings is None:
|
||||||
|
force_uc_zero_embeddings = []
|
||||||
|
if batch2model_input is None:
|
||||||
|
batch2model_input = []
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
with autocast(device) as precision_scope:
|
||||||
|
with model.ema_scope():
|
||||||
|
num_samples = [num_samples]
|
||||||
|
batch, batch_uc = get_batch(
|
||||||
|
get_unique_embedder_keys_from_conditioner(model.conditioner),
|
||||||
|
value_dict,
|
||||||
|
num_samples,
|
||||||
|
)
|
||||||
|
for key in batch:
|
||||||
|
if isinstance(batch[key], torch.Tensor):
|
||||||
|
print(key, batch[key].shape)
|
||||||
|
elif isinstance(batch[key], list):
|
||||||
|
print(key, [len(l) for l in batch[key]])
|
||||||
|
else:
|
||||||
|
print(key, batch[key])
|
||||||
|
c, uc = model.conditioner.get_unconditional_conditioning(
|
||||||
|
batch,
|
||||||
|
batch_uc=batch_uc,
|
||||||
|
force_uc_zero_embeddings=force_uc_zero_embeddings,
|
||||||
|
)
|
||||||
|
|
||||||
|
for k in c:
|
||||||
|
if not k == "crossattn":
|
||||||
|
c[k], uc[k] = map(
|
||||||
|
lambda y: y[k][: math.prod(num_samples)].to(device), (c, uc)
|
||||||
|
)
|
||||||
|
|
||||||
|
additional_model_inputs = {}
|
||||||
|
for k in batch2model_input:
|
||||||
|
additional_model_inputs[k] = batch[k]
|
||||||
|
|
||||||
|
shape = (math.prod(num_samples), C, H // F, W // F)
|
||||||
|
randn = torch.randn(shape).to(device)
|
||||||
|
|
||||||
|
def denoiser(input, sigma, c):
|
||||||
|
return model.denoiser(
|
||||||
|
model.model, input, sigma, c, **additional_model_inputs
|
||||||
|
)
|
||||||
|
|
||||||
|
samples_z = sampler(denoiser, randn, cond=c, uc=uc)
|
||||||
|
samples_x = model.decode_first_stage(samples_z)
|
||||||
|
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
|
||||||
|
if filter is not None:
|
||||||
|
samples = filter(samples)
|
||||||
|
|
||||||
|
if return_latents:
|
||||||
|
return samples, samples_z
|
||||||
|
return samples
|
||||||
|
|
||||||
|
|
||||||
|
def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
|
||||||
|
# Hardcoded demo setups; might undergo some changes in the future
|
||||||
|
|
||||||
|
batch = {}
|
||||||
|
batch_uc = {}
|
||||||
|
|
||||||
|
for key in keys:
|
||||||
|
if key == "txt":
|
||||||
|
batch["txt"] = (
|
||||||
|
np.repeat([value_dict["prompt"]], repeats=math.prod(N))
|
||||||
|
.reshape(N)
|
||||||
|
.tolist()
|
||||||
|
)
|
||||||
|
batch_uc["txt"] = (
|
||||||
|
np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
|
||||||
|
.reshape(N)
|
||||||
|
.tolist()
|
||||||
|
)
|
||||||
|
elif key == "original_size_as_tuple":
|
||||||
|
batch["original_size_as_tuple"] = (
|
||||||
|
torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
|
||||||
|
.to(device)
|
||||||
|
.repeat(*N, 1)
|
||||||
|
)
|
||||||
|
elif key == "crop_coords_top_left":
|
||||||
|
batch["crop_coords_top_left"] = (
|
||||||
|
torch.tensor(
|
||||||
|
[value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
|
||||||
|
)
|
||||||
|
.to(device)
|
||||||
|
.repeat(*N, 1)
|
||||||
|
)
|
||||||
|
elif key == "aesthetic_score":
|
||||||
|
batch["aesthetic_score"] = (
|
||||||
|
torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
|
||||||
|
)
|
||||||
|
batch_uc["aesthetic_score"] = (
|
||||||
|
torch.tensor([value_dict["negative_aesthetic_score"]])
|
||||||
|
.to(device)
|
||||||
|
.repeat(*N, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif key == "target_size_as_tuple":
|
||||||
|
batch["target_size_as_tuple"] = (
|
||||||
|
torch.tensor([value_dict["target_height"], value_dict["target_width"]])
|
||||||
|
.to(device)
|
||||||
|
.repeat(*N, 1)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
batch[key] = value_dict[key]
|
||||||
|
|
||||||
|
for key in batch.keys():
|
||||||
|
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
|
||||||
|
batch_uc[key] = torch.clone(batch[key])
|
||||||
|
return batch, batch_uc
|
||||||
|
|
||||||
|
|
||||||
|
def get_input_image_tensor(image: Image.Image, device="cuda"):
|
||||||
|
w, h = image.size
|
||||||
|
print(f"loaded input image of size ({w}, {h})")
|
||||||
|
width, height = map(
|
||||||
|
lambda x: x - x % 64, (w, h)
|
||||||
|
) # resize to integer multiple of 64
|
||||||
|
image = image.resize((width, height))
|
||||||
|
image_array = np.array(image.convert("RGB"))
|
||||||
|
image_array = image_array[None].transpose(0, 3, 1, 2)
|
||||||
|
image_tensor = torch.from_numpy(image_array).to(dtype=torch.float32) / 127.5 - 1.0
|
||||||
|
return image_tensor.to(device)
|
||||||
|
|
||||||
|
|
||||||
|
def do_img2img(
|
||||||
|
img,
|
||||||
|
model,
|
||||||
|
sampler,
|
||||||
|
value_dict,
|
||||||
|
num_samples,
|
||||||
|
force_uc_zero_embeddings=[],
|
||||||
|
additional_kwargs={},
|
||||||
|
offset_noise_level: float = 0.0,
|
||||||
|
return_latents=False,
|
||||||
|
skip_encode=False,
|
||||||
|
filter=None,
|
||||||
|
device="cuda",
|
||||||
|
):
|
||||||
|
with torch.no_grad():
|
||||||
|
with autocast(device) as precision_scope:
|
||||||
|
with model.ema_scope():
|
||||||
|
batch, batch_uc = get_batch(
|
||||||
|
get_unique_embedder_keys_from_conditioner(model.conditioner),
|
||||||
|
value_dict,
|
||||||
|
[num_samples],
|
||||||
|
)
|
||||||
|
c, uc = model.conditioner.get_unconditional_conditioning(
|
||||||
|
batch,
|
||||||
|
batch_uc=batch_uc,
|
||||||
|
force_uc_zero_embeddings=force_uc_zero_embeddings,
|
||||||
|
)
|
||||||
|
|
||||||
|
for k in c:
|
||||||
|
c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc))
|
||||||
|
|
||||||
|
for k in additional_kwargs:
|
||||||
|
c[k] = uc[k] = additional_kwargs[k]
|
||||||
|
if skip_encode:
|
||||||
|
z = img
|
||||||
|
else:
|
||||||
|
z = model.encode_first_stage(img)
|
||||||
|
noise = torch.randn_like(z)
|
||||||
|
sigmas = sampler.discretization(sampler.num_steps)
|
||||||
|
sigma = sigmas[0].to(z.device)
|
||||||
|
|
||||||
|
if offset_noise_level > 0.0:
|
||||||
|
noise = noise + offset_noise_level * append_dims(
|
||||||
|
torch.randn(z.shape[0], device=z.device), z.ndim
|
||||||
|
)
|
||||||
|
noised_z = z + noise * append_dims(sigma, z.ndim)
|
||||||
|
noised_z = noised_z / torch.sqrt(
|
||||||
|
1.0 + sigmas[0] ** 2.0
|
||||||
|
) # Note: hardcoded to DDPM-like scaling. need to generalize later.
|
||||||
|
|
||||||
|
def denoiser(x, sigma, c):
|
||||||
|
return model.denoiser(model.model, x, sigma, c)
|
||||||
|
|
||||||
|
samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
|
||||||
|
samples_x = model.decode_first_stage(samples_z)
|
||||||
|
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
|
||||||
|
if filter is not None:
|
||||||
|
samples = filter(samples)
|
||||||
|
|
||||||
|
if return_latents:
|
||||||
|
return samples, samples_z
|
||||||
|
return samples
|
||||||
111
tests/inference/test_inference.py
Normal file
111
tests/inference/test_inference.py
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
import numpy
|
||||||
|
from PIL import Image
|
||||||
|
import pytest
|
||||||
|
from pytest import fixture
|
||||||
|
import torch
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
from sgm.inference.api import (
|
||||||
|
model_specs,
|
||||||
|
SamplingParams,
|
||||||
|
SamplingPipeline,
|
||||||
|
Sampler,
|
||||||
|
ModelArchitecture,
|
||||||
|
)
|
||||||
|
import sgm.inference.helpers as helpers
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.inference
|
||||||
|
class TestInference:
|
||||||
|
@fixture(scope="class", params=model_specs.keys())
|
||||||
|
def pipeline(self, request) -> SamplingPipeline:
|
||||||
|
pipeline = SamplingPipeline(request.param)
|
||||||
|
yield pipeline
|
||||||
|
del pipeline
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
@fixture(
|
||||||
|
scope="class",
|
||||||
|
params=[
|
||||||
|
[ModelArchitecture.SDXL_V1_BASE, ModelArchitecture.SDXL_V1_REFINER],
|
||||||
|
[ModelArchitecture.SDXL_V0_9_BASE, ModelArchitecture.SDXL_V0_9_REFINER],
|
||||||
|
],
|
||||||
|
ids=["SDXL_V1", "SDXL_V0_9"],
|
||||||
|
)
|
||||||
|
def sdxl_pipelines(self, request) -> Tuple[SamplingPipeline, SamplingPipeline]:
|
||||||
|
base_pipeline = SamplingPipeline(request.param[0])
|
||||||
|
refiner_pipeline = SamplingPipeline(request.param[1])
|
||||||
|
yield base_pipeline, refiner_pipeline
|
||||||
|
del base_pipeline
|
||||||
|
del refiner_pipeline
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
def create_init_image(self, h, w):
|
||||||
|
image_array = numpy.random.rand(h, w, 3) * 255
|
||||||
|
image = Image.fromarray(image_array.astype("uint8")).convert("RGB")
|
||||||
|
return helpers.get_input_image_tensor(image)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("sampler_enum", Sampler)
|
||||||
|
def test_txt2img(self, pipeline: SamplingPipeline, sampler_enum):
|
||||||
|
output = pipeline.text_to_image(
|
||||||
|
params=SamplingParams(sampler=sampler_enum.value, steps=10),
|
||||||
|
prompt="A professional photograph of an astronaut riding a pig",
|
||||||
|
negative_prompt="",
|
||||||
|
samples=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert output is not None
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("sampler_enum", Sampler)
|
||||||
|
def test_img2img(self, pipeline: SamplingPipeline, sampler_enum):
|
||||||
|
output = pipeline.image_to_image(
|
||||||
|
params=SamplingParams(sampler=sampler_enum.value, steps=10),
|
||||||
|
image=self.create_init_image(pipeline.specs.height, pipeline.specs.width),
|
||||||
|
prompt="A professional photograph of an astronaut riding a pig",
|
||||||
|
negative_prompt="",
|
||||||
|
samples=1,
|
||||||
|
)
|
||||||
|
assert output is not None
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("sampler_enum", Sampler)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"use_init_image", [True, False], ids=["img2img", "txt2img"]
|
||||||
|
)
|
||||||
|
def test_sdxl_with_refiner(
|
||||||
|
self,
|
||||||
|
sdxl_pipelines: Tuple[SamplingPipeline, SamplingPipeline],
|
||||||
|
sampler_enum,
|
||||||
|
use_init_image,
|
||||||
|
):
|
||||||
|
base_pipeline, refiner_pipeline = sdxl_pipelines
|
||||||
|
if use_init_image:
|
||||||
|
output = base_pipeline.image_to_image(
|
||||||
|
params=SamplingParams(sampler=sampler_enum.value, steps=10),
|
||||||
|
image=self.create_init_image(
|
||||||
|
base_pipeline.specs.height, base_pipeline.specs.width
|
||||||
|
),
|
||||||
|
prompt="A professional photograph of an astronaut riding a pig",
|
||||||
|
negative_prompt="",
|
||||||
|
samples=1,
|
||||||
|
return_latents=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
output = base_pipeline.text_to_image(
|
||||||
|
params=SamplingParams(sampler=sampler_enum.value, steps=10),
|
||||||
|
prompt="A professional photograph of an astronaut riding a pig",
|
||||||
|
negative_prompt="",
|
||||||
|
samples=1,
|
||||||
|
return_latents=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(output, (tuple, list))
|
||||||
|
samples, samples_z = output
|
||||||
|
assert samples is not None
|
||||||
|
assert samples_z is not None
|
||||||
|
refiner_pipeline.refiner(
|
||||||
|
params=SamplingParams(sampler=sampler_enum.value, steps=10),
|
||||||
|
image=samples_z,
|
||||||
|
prompt="A professional photograph of an astronaut riding a pig",
|
||||||
|
negative_prompt="",
|
||||||
|
samples=1,
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user