Improved sampling (#69)

* New research features

* Add new model specs

* remove sd1.5 and change default refiner to 1.0

* remove asking second time for output

* adapt model names

* adjusted strength

* Correctly pass prompt

---------

Co-authored-by: Dominik Lorenz <53151171+qp-qp@users.noreply.github.com>
Jonas Müller committed (via GitHub) on 2023-07-26 19:49:23 +02:00
parent f2fa96b7e5
commit e5d714d304
2 changed files with 514 additions and 93 deletions
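
In outline, the squashed commits wire the demo into a two-stage pipeline: the base model returns its latents, and the refiner continues denoising them. A rough sketch of that control flow, using only names that appear in the diff below (Streamlit UI setup and model loading elided):

# Sketch only: state/state2/sampler2/stage2strength are built as in the diff.
out = run_txt2img(
    state, version, version_dict,
    return_latents=True,            # the refiner consumes latents, not images
    filter=state.get("filter"),
    stage2strength=stage2strength,  # truncates the base model's schedule
)
samples, samples_z = out
samples = apply_refiner(
    samples_z, state2, sampler2, samples_z.shape[0],
    prompt=prompt, negative_prompt="",
    filter=state.get("filter"),
    finish_denoising=True,          # add_noise=False: refiner only denoises
)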

scripts/demo/sampling.py

@@ -1,14 +1,6 @@
-import numpy as np
 from pytorch_lightning import seed_everything

 from scripts.demo.streamlit_helpers import *
-from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
-from sgm.inference.helpers import (
-    do_img2img,
-    do_sample,
-    get_unique_embedder_keys_from_conditioner,
-    perform_save_locally,
-)

 SAVE_PATH = "outputs/demo/txt2img/"
@@ -42,7 +34,16 @@ SD_XL_BASE_RATIOS = {
 }

 VERSION2SPECS = {
-    "SD-XL base": {
+    "SDXL-base-1.0": {
+        "H": 1024,
+        "W": 1024,
+        "C": 4,
+        "f": 8,
+        "is_legacy": False,
+        "config": "configs/inference/sd_xl_base.yaml",
+        "ckpt": "checkpoints/sd_xl_base_1.0.safetensors",
+    },
+    "SDXL-base-0.9": {
         "H": 1024,
         "W": 1024,
         "C": 4,
@@ -50,9 +51,8 @@ VERSION2SPECS = {
         "is_legacy": False,
         "config": "configs/inference/sd_xl_base.yaml",
         "ckpt": "checkpoints/sd_xl_base_0.9.safetensors",
-        "is_guided": True,
     },
-    "sd-2.1": {
+    "SD-2.1": {
         "H": 512,
         "W": 512,
         "C": 4,
@@ -60,9 +60,8 @@ VERSION2SPECS = {
         "is_legacy": True,
         "config": "configs/inference/sd_2_1.yaml",
         "ckpt": "checkpoints/v2-1_512-ema-pruned.safetensors",
-        "is_guided": True,
     },
-    "sd-2.1-768": {
+    "SD-2.1-768": {
         "H": 768,
         "W": 768,
         "C": 4,
@@ -71,7 +70,7 @@ VERSION2SPECS = {
         "config": "configs/inference/sd_2_1_768.yaml",
         "ckpt": "checkpoints/v2-1_768-ema-pruned.safetensors",
     },
-    "SDXL-Refiner": {
+    "SDXL-refiner-0.9": {
         "H": 1024,
         "W": 1024,
         "C": 4,
@@ -79,7 +78,15 @@ VERSION2SPECS = {
         "is_legacy": True,
         "config": "configs/inference/sd_xl_refiner.yaml",
         "ckpt": "checkpoints/sd_xl_refiner_0.9.safetensors",
-        "is_guided": True,
+    },
+    "SDXL-refiner-1.0": {
+        "H": 1024,
+        "W": 1024,
+        "C": 4,
+        "f": 8,
+        "is_legacy": True,
+        "config": "configs/inference/sd_xl_refiner.yaml",
+        "ckpt": "checkpoints/sd_xl_refiner_1.0.safetensors",
     },
 }
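
Each VERSION2SPECS entry is a self-describing spec: latent channels C, downsampling factor f, native resolution, and the config/checkpoint paths that init_st() hands to load_model_from_config(). A sketch of how such an entry is consumed, mirroring init_st() in streamlit_helpers.py further down (not a verbatim excerpt):

# Sketch: resolving a spec to a loaded model, as init_st() does below.
spec = VERSION2SPECS["SDXL-base-1.0"]
config = OmegaConf.load(spec["config"])
model, msg = load_model_from_config(config, spec["ckpt"])
latent_shape = (spec["C"], spec["H"] // spec["f"], spec["W"] // spec["f"])  # (4, 128, 128)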
@@ -103,18 +110,19 @@ def load_img(display=True, key=None, device="cuda"):

 def run_txt2img(
-    state, version, version_dict, is_legacy=False, return_latents=False, filter=None
+    state,
+    version,
+    version_dict,
+    is_legacy=False,
+    return_latents=False,
+    filter=None,
+    stage2strength=None,
 ):
-    if version == "SD-XL base":
-        ratio = st.sidebar.selectbox("Ratio:", list(SD_XL_BASE_RATIOS.keys()), 10)
-        W, H = SD_XL_BASE_RATIOS[ratio]
+    if version.startswith("SDXL-base"):
+        W, H = st.selectbox("Resolution:", list(SD_XL_BASE_RATIOS.values()), 10)
     else:
-        H = st.sidebar.number_input(
-            "H", value=version_dict["H"], min_value=64, max_value=2048
-        )
-        W = st.sidebar.number_input(
-            "W", value=version_dict["W"], min_value=64, max_value=2048
-        )
+        H = st.number_input("H", value=version_dict["H"], min_value=64, max_value=2048)
+        W = st.number_input("W", value=version_dict["W"], min_value=64, max_value=2048)
     C = version_dict["C"]
     F = version_dict["f"]
@@ -130,16 +138,11 @@ def run_txt2img(
         prompt=prompt,
         negative_prompt=negative_prompt,
     )
-    num_rows, num_cols, sampler = init_sampling(
-        use_identity_guider=not version_dict["is_guided"]
-    )
+    sampler, num_rows, num_cols = init_sampling(stage2strength=stage2strength)
     num_samples = num_rows * num_cols

     if st.button("Sample"):
         st.write(f"**Model I:** {version}")
-        outputs = st.empty()
-        st.text("Sampling")
         out = do_sample(
             state["model"],
             sampler,
@@ -153,13 +156,16 @@ def run_txt2img(
             return_latents=return_latents,
             filter=filter,
         )
-        show_samples(out, outputs)
     return out


 def run_img2img(
-    state, version_dict, is_legacy=False, return_latents=False, filter=None
+    state,
+    version_dict,
+    is_legacy=False,
+    return_latents=False,
+    filter=None,
+    stage2strength=None,
 ):
     img = load_img()
     if img is None:
@@ -175,19 +181,19 @@ def run_img2img(
     value_dict = init_embedder_options(
         get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
         init_dict,
+        prompt=prompt,
+        negative_prompt=negative_prompt,
     )
     strength = st.number_input(
-        "**Img2Img Strength**", value=0.5, min_value=0.0, max_value=1.0
+        "**Img2Img Strength**", value=0.75, min_value=0.0, max_value=1.0
     )
-    num_rows, num_cols, sampler = init_sampling(
+    sampler, num_rows, num_cols = init_sampling(
         img2img_strength=strength,
-        use_identity_guider=not version_dict["is_guided"],
+        stage2strength=stage2strength,
     )
     num_samples = num_rows * num_cols

     if st.button("Sample"):
-        outputs = st.empty()
-        st.text("Sampling")
         out = do_img2img(
             repeat(img, "1 ... -> n ...", n=num_samples),
             state["model"],
@@ -198,7 +204,6 @@ def run_img2img(
             return_latents=return_latents,
             filter=filter,
         )
-        show_samples(out, outputs)
     return out
@@ -210,6 +215,7 @@ def apply_refiner(
     prompt,
     negative_prompt,
     filter=None,
+    finish_denoising=False,
 ):
     init_dict = {
         "orig_width": input.shape[3] * 8,
@@ -237,6 +243,7 @@ def apply_refiner(
         num_samples,
         skip_encode=True,
         filter=filter,
+        add_noise=not finish_denoising,
     )
     return samples
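
The new finish_denoising flag flips the refiner between two regimes: with add_noise=True the refiner re-noises the base latents up to its own starting sigma before sampling, while with add_noise=False it only rescales them and continues denoising from where the truncated base schedule stopped. A condensed sketch of the distinction, matching the do_img2img code in streamlit_helpers.py below (sigma0 is the first sigma of the refiner's discretization):

# Sketch of do_img2img's two entry modes.
if add_noise:  # classic img2img: push latents back up to noise level sigma0
    noised_z = (z + sigma0 * torch.randn_like(z)) / torch.sqrt(1.0 + sigma0**2)
else:          # finish_denoising: base already stopped at sigma0, just rescale
    noised_z = z / torch.sqrt(1.0 + sigma0**2)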
@@ -249,20 +256,22 @@ if __name__ == "__main__":
     mode = st.radio("Mode", ("txt2img", "img2img"), 0)
     st.write("__________________________")

-    if version == "SD-XL base":
-        add_pipeline = st.checkbox("Load SDXL-Refiner?", False)
+    set_lowvram_mode(st.checkbox("Low vram mode", True))
+
+    if version.startswith("SDXL-base"):
+        add_pipeline = st.checkbox("Load SDXL-refiner?", False)
         st.write("__________________________")
     else:
         add_pipeline = False

-    filter = DeepFloydDataFiltering(verbose=False)
-
     seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
     seed_everything(seed)

     save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))

-    state = init_st(version_dict)
+    state = init_st(version_dict, load_filter=True)
+    if state["msg"]:
+        st.info(state["msg"])
     model = state["model"]

     is_legacy = version_dict["is_legacy"]
@@ -276,29 +285,34 @@ if __name__ == "__main__":
     else:
         negative_prompt = ""  # which is unused

+    stage2strength = None
+    finish_denoising = False
+
     if add_pipeline:
         st.write("__________________________")
-
-        version2 = "SDXL-Refiner"
+        version2 = st.selectbox("Refiner:", ["SDXL-refiner-1.0", "SDXL-refiner-0.9"])
         st.warning(
             f"Running with {version2} as the second stage model. Make sure to provide (V)RAM :) "
         )
         st.write("**Refiner Options:**")

         version_dict2 = VERSION2SPECS[version2]
-        state2 = init_st(version_dict2)
+        state2 = init_st(version_dict2, load_filter=False)
+        st.info(state2["msg"])

         stage2strength = st.number_input(
-            "**Refinement strength**", value=0.3, min_value=0.0, max_value=1.0
+            "**Refinement strength**", value=0.15, min_value=0.0, max_value=1.0
         )

-        sampler2 = init_sampling(
+        sampler2, *_ = init_sampling(
             key=2,
             img2img_strength=stage2strength,
-            use_identity_guider=not version_dict2["is_guided"],
-            get_num_samples=False,
+            specify_num_samples=False,
         )
         st.write("__________________________")
+        finish_denoising = st.checkbox("Finish denoising with refiner.", True)
+        if not finish_denoising:
+            stage2strength = None
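
Concretely, with the defaults above (40 steps, refinement strength 0.15, finish_denoising on), Txt2NoisyDiscretizationWrapper truncates the base sampler's schedule so that roughly the lowest-noise 15% of it is left for the refiner, which then runs with add_noise=False. A small numeric illustration of the split (mirroring the wrapper's arithmetic, not taken from a real schedule):

# Hypothetical 40-step schedule split at stage2strength = 0.15.
steps, stage2strength = 40, 0.15
prune_index = max(min(int(stage2strength * (steps + 1)) - 1, steps), 0)  # -> 5
# The base model skips the ~5 lowest-noise sigmas; the refiner picks up there.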
     if mode == "txt2img":
         out = run_txt2img(
@@ -307,7 +321,8 @@ if __name__ == "__main__":
             version_dict,
             is_legacy=is_legacy,
             return_latents=add_pipeline,
-            filter=filter,
+            filter=state.get("filter"),
+            stage2strength=stage2strength,
         )
     elif mode == "img2img":
         out = run_img2img(
@@ -315,7 +330,8 @@ if __name__ == "__main__":
             version_dict,
             is_legacy=is_legacy,
             return_latents=add_pipeline,
-            filter=filter,
+            filter=state.get("filter"),
+            stage2strength=stage2strength,
         )
     else:
         raise ValueError(f"unknown mode {mode}")
@@ -326,7 +342,6 @@ if __name__ == "__main__":
         samples_z = None

     if add_pipeline and samples_z is not None:
-        outputs = st.empty()
         st.write("**Running Refinement Stage**")
         samples = apply_refiner(
             samples_z,
@@ -335,9 +350,9 @@ if __name__ == "__main__":
             samples_z.shape[0],
             prompt=prompt,
             negative_prompt=negative_prompt if is_legacy else "",
-            filter=filter,
+            filter=state.get("filter"),
+            finish_denoising=finish_denoising,
         )
-        show_samples(samples, outputs)

     if save_locally and samples is not None:
         perform_save_locally(save_path, samples)

scripts/demo/streamlit_helpers.py

@@ -1,12 +1,20 @@
+import math
 import os
+from typing import List, Union

+import numpy as np
 import streamlit as st
 import torch
-from PIL import Image
 from einops import rearrange, repeat
-from omegaconf import OmegaConf
+from imwatermark import WatermarkEncoder
+from omegaconf import ListConfig, OmegaConf
+from PIL import Image
+from safetensors.torch import load_file as load_safetensors
+from torch import autocast
 from torchvision import transforms
-from torchvision.utils import make_grid

+from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
 from sgm.modules.diffusionmodules.sampling import (
     DPMPP2MSampler,
     DPMPP2SAncestralSampler,
@@ -15,29 +23,140 @@ from sgm.modules.diffusionmodules.sampling import (
     HeunEDMSampler,
     LinearMultistepSampler,
 )
-from sgm.inference.helpers import Img2ImgDiscretizationWrapper, embed_watermark
-from sgm.util import load_model_from_config
+from sgm.util import append_dims, instantiate_from_config
+
+
+class WatermarkEmbedder:
+    def __init__(self, watermark):
+        self.watermark = watermark
+        self.num_bits = len(WATERMARK_BITS)
+        self.encoder = WatermarkEncoder()
+        self.encoder.set_watermark("bits", self.watermark)
+
+    def __call__(self, image: torch.Tensor):
+        """
+        Adds a predefined watermark to the input image
+
+        Args:
+            image: ([N,] B, C, H, W) in range [0, 1]
+
+        Returns:
+            same as input but watermarked
+        """
+        # watermarking library expects input as cv2 BGR format
+        squeeze = len(image.shape) == 4
+        if squeeze:
+            image = image[None, ...]
+        n = image.shape[0]
+        image_np = rearrange(
+            (255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
+        ).numpy()[:, :, :, ::-1]
+        # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
+        for k in range(image_np.shape[0]):
+            image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
+        image = torch.from_numpy(
+            rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)
+        ).to(image.device)
+        image = torch.clamp(image / 255, min=0.0, max=1.0)
+        if squeeze:
+            image = image[0]
+        return image
+
+
+# A fixed 48-bit message that was chosen at random
+# WATERMARK_MESSAGE = 0xB3EC907BB19E
+WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
+# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
+WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
+embed_watemark = WatermarkEmbedder(WATERMARK_BITS)
+
+
 @st.cache_resource()
-def init_st(version_dict, load_ckpt=True):
+def init_st(version_dict, load_ckpt=True, load_filter=True):
     state = dict()
     if not "model" in state:
         config = version_dict["config"]
         ckpt = version_dict["ckpt"]

         config = OmegaConf.load(config)
-        model = load_model_from_config(config, ckpt if load_ckpt else None)
-        model = model.to("cuda")
-        model.conditioner.half()
-        model.model.half()
+        model, msg = load_model_from_config(config, ckpt if load_ckpt else None)

+        state["msg"] = msg
         state["model"] = model
         state["ckpt"] = ckpt if load_ckpt else None
         state["config"] = config
+        if load_filter:
+            state["filter"] = DeepFloydDataFiltering(verbose=False)
     return state
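
The module-level embed_watemark callable (identifier spelled as in the source) accepts either (B, C, H, W) or (N, B, C, H, W) tensors in [0, 1] and returns the same shape. A minimal shape-preservation check, assuming the helpers above are importable and imwatermark is installed (this check is illustrative, not part of the commit):

# Hypothetical round-trip check for WatermarkEmbedder.
import torch
img = torch.rand(1, 3, 256, 256)  # B, C, H, W in [0, 1]
marked = embed_watemark(img)      # same shape, watermark embedded via dwtDct
assert marked.shape == img.shape
assert 0.0 <= marked.min() and marked.max() <= 1.0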
+def load_model(model):
+    model.cuda()
+
+
+lowvram_mode = False
+
+
+def set_lowvram_mode(mode):
+    global lowvram_mode
+    lowvram_mode = mode
+
+
+def initial_model_load(model):
+    global lowvram_mode
+    if lowvram_mode:
+        model.model.half()
+    else:
+        model.cuda()
+    return model
+
+
+def unload_model(model):
+    global lowvram_mode
+    if lowvram_mode:
+        model.cpu()
+        torch.cuda.empty_cache()
+
+
+def load_model_from_config(config, ckpt=None, verbose=True):
+    model = instantiate_from_config(config.model)
+
+    if ckpt is not None:
+        print(f"Loading model from {ckpt}")
+        if ckpt.endswith("ckpt"):
+            pl_sd = torch.load(ckpt, map_location="cpu")
+            if "global_step" in pl_sd:
+                global_step = pl_sd["global_step"]
+                st.info(f"loaded ckpt from global step {global_step}")
+                print(f"Global Step: {pl_sd['global_step']}")
+            sd = pl_sd["state_dict"]
+        elif ckpt.endswith("safetensors"):
+            sd = load_safetensors(ckpt)
+        else:
+            raise NotImplementedError
+
+        msg = None
+
+        m, u = model.load_state_dict(sd, strict=False)
+
+        if len(m) > 0 and verbose:
+            print("missing keys:")
+            print(m)
+        if len(u) > 0 and verbose:
+            print("unexpected keys:")
+            print(u)
+    else:
+        msg = None
+
+    model = initial_model_load(model)
+    model.eval()
+    return model, msg
+
+
+def get_unique_embedder_keys_from_conditioner(conditioner):
+    return list(set([x.input_key for x in conditioner.embedders]))
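
The load_model/unload_model pair is the low-VRAM strategy used throughout do_sample and do_img2img further down: each submodule is moved to the GPU only for the phase that needs it, then moved back to the CPU and the cache cleared. A minimal sketch of the pattern (model and batch are assumed to exist; names taken from the functions below):

# Sketch of the staged-execution pattern used in do_sample()/do_img2img().
set_lowvram_mode(True)
load_model(model.conditioner)    # conditioner on GPU only while embedding
c, uc = model.conditioner.get_unconditional_conditioning(batch, batch_uc=batch_uc)
unload_model(model.conditioner)  # back to CPU, CUDA cache freed
load_model(model.model)          # UNet on GPU only for the sampling loop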
 def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
     # Hardcoded demo settings; might undergo some changes in the future
@@ -81,23 +200,24 @@ def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
         value_dict["negative_aesthetic_score"] = 2.5

         if key == "target_size_as_tuple":
-            target_width = st.number_input(
-                "target_width",
-                value=init_dict["target_width"],
-                min_value=16,
-            )
-            target_height = st.number_input(
-                "target_height",
-                value=init_dict["target_height"],
-                min_value=16,
-            )
-
-            value_dict["target_width"] = target_width
-            value_dict["target_height"] = target_height
+            value_dict["target_width"] = init_dict["target_width"]
+            value_dict["target_height"] = init_dict["target_height"]

     return value_dict


+def perform_save_locally(save_path, samples):
+    os.makedirs(os.path.join(save_path), exist_ok=True)
+    base_count = len(os.listdir(os.path.join(save_path)))
+    samples = embed_watemark(samples)
+    for sample in samples:
+        sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
+        Image.fromarray(sample.astype(np.uint8)).save(
+            os.path.join(save_path, f"{base_count:09}.png")
+        )
+        base_count += 1
 def init_save_locally(_dir, init_value: bool = False):
     save_locally = st.sidebar.checkbox("Save images locally", value=init_value)
     if save_locally:
@@ -108,12 +228,58 @@ def init_save_locally(_dir, init_value: bool = False):
     return save_locally, save_path


-def show_samples(samples, outputs):
-    if isinstance(samples, tuple):
-        samples, _ = samples
-    grid = embed_watermark(torch.stack([samples]))
-    grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
-    outputs.image(grid.cpu().numpy())
+class Img2ImgDiscretizationWrapper:
+    """
+    wraps a discretizer, and prunes the sigmas
+    params:
+        strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
+    """
+
+    def __init__(self, discretization, strength: float = 1.0):
+        self.discretization = discretization
+        self.strength = strength
+        assert 0.0 <= self.strength <= 1.0
+
+    def __call__(self, *args, **kwargs):
+        # sigmas start large first, and decrease then
+        sigmas = self.discretization(*args, **kwargs)
+        print(f"sigmas after discretization, before pruning img2img: ", sigmas)
+        sigmas = torch.flip(sigmas, (0,))
+        sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
+        print("prune index:", max(int(self.strength * len(sigmas)), 1))
+        sigmas = torch.flip(sigmas, (0,))
+        print(f"sigmas after pruning: ", sigmas)
+        return sigmas
+
+
+class Txt2NoisyDiscretizationWrapper:
+    """
+    wraps a discretizer, and prunes the sigmas
+    params:
+        strength: float between 0.0 and 1.0. 0.0 means full sampling (all sigmas are returned)
+    """
+
+    def __init__(self, discretization, strength: float = 0.0, original_steps=None):
+        self.discretization = discretization
+        self.strength = strength
+        self.original_steps = original_steps
+        assert 0.0 <= self.strength <= 1.0
+
+    def __call__(self, *args, **kwargs):
+        # sigmas start large first, and decrease then
+        sigmas = self.discretization(*args, **kwargs)
+        print(f"sigmas after discretization, before pruning img2img: ", sigmas)
+        sigmas = torch.flip(sigmas, (0,))
+        if self.original_steps is None:
+            steps = len(sigmas)
+        else:
+            steps = self.original_steps + 1
+        prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0)
+        sigmas = sigmas[prune_index:]
+        print("prune index:", prune_index)
+        sigmas = torch.flip(sigmas, (0,))
+        print(f"sigmas after pruning: ", sigmas)
+        return sigmas
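
Both wrappers can be exercised without a model by handing them a stub discretization; a minimal sketch (the sigma range is invented for illustration):

# Stub: a descending sigma schedule, as the real discretizations return.
import torch

def stub_discretization(n):
    return torch.linspace(14.6, 0.03, n)

img2img = Img2ImgDiscretizationWrapper(stub_discretization, strength=0.75)
print(len(img2img(10)))  # 7: keeps only the 7 lowest-noise sigmas for img2img

txt2noisy = Txt2NoisyDiscretizationWrapper(stub_discretization, strength=0.15, original_steps=40)
print(len(txt2noisy(41)))  # 36: drops the 5 lowest-noise sigmas for the base pass

Img2ImgDiscretizationWrapper trims from the high-noise end (start later, preserve more of the input image), while Txt2NoisyDiscretizationWrapper trims from the low-noise end (stop early, leave the fine detail to the refiner).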
 def get_guider(key):
@@ -158,16 +324,19 @@ def get_guider(key):
 def init_sampling(
-    key=1, img2img_strength=1.0, use_identity_guider=False, get_num_samples=True
+    key=1,
+    img2img_strength=1.0,
+    specify_num_samples=True,
+    stage2strength=None,
 ):
-    if get_num_samples:
-        num_rows = 1
+    num_rows, num_cols = 1, 1
+    if specify_num_samples:
         num_cols = st.number_input(
             f"num cols #{key}", value=2, min_value=1, max_value=10
         )

     steps = st.sidebar.number_input(
-        f"steps #{key}", value=50, min_value=1, max_value=1000
+        f"steps #{key}", value=40, min_value=1, max_value=1000
     )
     sampler = st.sidebar.selectbox(
         f"Sampler #{key}",
@@ -201,9 +370,11 @@ def init_sampling(
         sampler.discretization = Img2ImgDiscretizationWrapper(
             sampler.discretization, strength=img2img_strength
         )
-    if get_num_samples:
-        return num_rows, num_cols, sampler
-    return sampler
+    if stage2strength is not None:
+        sampler.discretization = Txt2NoisyDiscretizationWrapper(
+            sampler.discretization, strength=stage2strength, original_steps=steps
+        )
+    return sampler, num_rows, num_cols


 def get_discretization(discretization, key=1):
@@ -336,3 +507,238 @@ def get_init_img(batch_size=1, key=None):
     init_image = load_img(key=key).cuda()
     init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
     return init_image
+
+
+def do_sample(
+    model,
+    sampler,
+    value_dict,
+    num_samples,
+    H,
+    W,
+    C,
+    F,
+    force_uc_zero_embeddings: List = None,
+    batch2model_input: List = None,
+    return_latents=False,
+    filter=None,
+):
+    if force_uc_zero_embeddings is None:
+        force_uc_zero_embeddings = []
+    if batch2model_input is None:
+        batch2model_input = []
+
+    st.text("Sampling")
+
+    outputs = st.empty()
+    precision_scope = autocast
+    with torch.no_grad():
+        with precision_scope("cuda"):
+            with model.ema_scope():
+                num_samples = [num_samples]
+                load_model(model.conditioner)
+                batch, batch_uc = get_batch(
+                    get_unique_embedder_keys_from_conditioner(model.conditioner),
+                    value_dict,
+                    num_samples,
+                )
+                for key in batch:
+                    if isinstance(batch[key], torch.Tensor):
+                        print(key, batch[key].shape)
+                    elif isinstance(batch[key], list):
+                        print(key, [len(l) for l in batch[key]])
+                    else:
+                        print(key, batch[key])
+                c, uc = model.conditioner.get_unconditional_conditioning(
+                    batch,
+                    batch_uc=batch_uc,
+                    force_uc_zero_embeddings=force_uc_zero_embeddings,
+                )
+                unload_model(model.conditioner)
+
+                for k in c:
+                    if not k == "crossattn":
+                        c[k], uc[k] = map(
+                            lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc)
+                        )
+
+                additional_model_inputs = {}
+                for k in batch2model_input:
+                    additional_model_inputs[k] = batch[k]
+
+                shape = (math.prod(num_samples), C, H // F, W // F)
+                randn = torch.randn(shape).to("cuda")
+
+                def denoiser(input, sigma, c):
+                    return model.denoiser(
+                        model.model, input, sigma, c, **additional_model_inputs
+                    )
+
+                load_model(model.denoiser)
+                load_model(model.model)
+                samples_z = sampler(denoiser, randn, cond=c, uc=uc)
+                unload_model(model.model)
+                unload_model(model.denoiser)
+
+                load_model(model.first_stage_model)
+                samples_x = model.decode_first_stage(samples_z)
+                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
+                unload_model(model.first_stage_model)
+
+                if filter is not None:
+                    samples = filter(samples)
+
+                grid = torch.stack([samples])
+                grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
+                outputs.image(grid.cpu().numpy())
+
+                if return_latents:
+                    return samples, samples_z
+                return samples
+
+
+def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
+    # Hardcoded demo setups; might undergo some changes in the future
+    batch = {}
+    batch_uc = {}
+
+    for key in keys:
+        if key == "txt":
+            batch["txt"] = (
+                np.repeat([value_dict["prompt"]], repeats=math.prod(N))
+                .reshape(N)
+                .tolist()
+            )
+            batch_uc["txt"] = (
+                np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
+                .reshape(N)
+                .tolist()
+            )
+        elif key == "original_size_as_tuple":
+            batch["original_size_as_tuple"] = (
+                torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
+                .to(device)
+                .repeat(*N, 1)
+            )
+        elif key == "crop_coords_top_left":
+            batch["crop_coords_top_left"] = (
+                torch.tensor(
+                    [value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
+                )
+                .to(device)
+                .repeat(*N, 1)
+            )
+        elif key == "aesthetic_score":
+            batch["aesthetic_score"] = (
+                torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
+            )
+            batch_uc["aesthetic_score"] = (
+                torch.tensor([value_dict["negative_aesthetic_score"]])
+                .to(device)
+                .repeat(*N, 1)
+            )
+        elif key == "target_size_as_tuple":
+            batch["target_size_as_tuple"] = (
+                torch.tensor([value_dict["target_height"], value_dict["target_width"]])
+                .to(device)
+                .repeat(*N, 1)
+            )
+        else:
+            batch[key] = value_dict[key]
+
+    for key in batch.keys():
+        if key not in batch_uc and isinstance(batch[key], torch.Tensor):
+            batch_uc[key] = torch.clone(batch[key])
+    return batch, batch_uc
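
get_batch duplicates each conditioning value across the sample batch, building separate unconditional values only for txt and aesthetic_score (everything else is cloned into batch_uc). A quick illustration of the output for two SDXL keys, with device set to "cpu" so it runs anywhere; the value_dict keys mirror init_embedder_options() above:

# Illustrative call, not part of the commit.
batch, batch_uc = get_batch(
    ["txt", "original_size_as_tuple"],
    {"prompt": "a cat", "negative_prompt": "", "orig_height": 1024, "orig_width": 1024},
    [2],
    device="cpu",
)
print(batch["txt"])                           # ['a cat', 'a cat']
print(batch_uc["txt"])                        # ['', '']
print(batch["original_size_as_tuple"].shape)  # torch.Size([2, 2])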
+
+
+@torch.no_grad()
+def do_img2img(
+    img,
+    model,
+    sampler,
+    value_dict,
+    num_samples,
+    force_uc_zero_embeddings=[],
+    additional_kwargs={},
+    offset_noise_level: int = 0.0,
+    return_latents=False,
+    skip_encode=False,
+    filter=None,
+    add_noise=True,
+):
+    st.text("Sampling")
+
+    outputs = st.empty()
+    precision_scope = autocast
+    with torch.no_grad():
+        with precision_scope("cuda"):
+            with model.ema_scope():
+                load_model(model.conditioner)
+                batch, batch_uc = get_batch(
+                    get_unique_embedder_keys_from_conditioner(model.conditioner),
+                    value_dict,
+                    [num_samples],
+                )
+                c, uc = model.conditioner.get_unconditional_conditioning(
+                    batch,
+                    batch_uc=batch_uc,
+                    force_uc_zero_embeddings=force_uc_zero_embeddings,
+                )
+                unload_model(model.conditioner)
+                for k in c:
+                    c[k], uc[k] = map(lambda y: y[k][:num_samples].to("cuda"), (c, uc))
+
+                for k in additional_kwargs:
+                    c[k] = uc[k] = additional_kwargs[k]
+                if skip_encode:
+                    z = img
+                else:
+                    load_model(model.first_stage_model)
+                    z = model.encode_first_stage(img)
+                    unload_model(model.first_stage_model)
+
+                noise = torch.randn_like(z)
+
+                sigmas = sampler.discretization(sampler.num_steps).cuda()
+                sigma = sigmas[0]
+                st.info(f"all sigmas: {sigmas}")
+                st.info(f"noising sigma: {sigma}")
+
+                if offset_noise_level > 0.0:
+                    noise = noise + offset_noise_level * append_dims(
+                        torch.randn(z.shape[0], device=z.device), z.ndim
+                    )
+                if add_noise:
+                    noised_z = z + noise * append_dims(sigma, z.ndim).cuda()
+                    noised_z = noised_z / torch.sqrt(
+                        1.0 + sigmas[0] ** 2.0
+                    )  # Note: hardcoded to DDPM-like scaling. need to generalize later.
+                else:
+                    noised_z = z / torch.sqrt(1.0 + sigmas[0] ** 2.0)
+
+                def denoiser(x, sigma, c):
+                    return model.denoiser(model.model, x, sigma, c)
+
+                load_model(model.denoiser)
+                load_model(model.model)
+                samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
+                unload_model(model.model)
+                unload_model(model.denoiser)
+
+                load_model(model.first_stage_model)
+                samples_x = model.decode_first_stage(samples_z)
+                unload_model(model.first_stage_model)
+                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
+
+                if filter is not None:
+                    samples = filter(samples)
+
+                grid = embed_watemark(torch.stack([samples]))
+                grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
+                outputs.image(grid.cpu().numpy())
+                if return_latents:
+                    return samples, samples_z
+                return samples
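
The "DDPM-like scaling" comment refers to the division by sqrt(1 + sigma0^2): for z and eps with unit variance, (z + sigma0 * eps) / sqrt(1 + sigma0^2) again has unit variance, which keeps the sampler's input in the scale it expects. A tiny self-contained sanity check:

# Variance check for the DDPM-like scaling used above.
import torch

sigma0 = 14.6
z, eps = torch.randn(1_000_000), torch.randn(1_000_000)
noised = (z + sigma0 * eps) / torch.sqrt(torch.tensor(1.0 + sigma0**2))
print(noised.var())  # ~1.0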