mirror of
https://github.com/Stability-AI/generative-models.git
synced 2025-12-19 22:34:22 +01:00
Improved sampling (#69)
* New research features * Add new model specs --------- Co-authored-by: Dominik Lorenz <53151171+qp-qp@users.noreply.github.com> * remove sd1.5 and change default refiner to 1.0 * remove asking second time for output * adapt model names * adjusted strength * Correctly pass prompt --------- Co-authored-by: Dominik Lorenz <53151171+qp-qp@users.noreply.github.com>
This commit is contained in:
@@ -1,14 +1,6 @@
|
|||||||
import numpy as np
|
|
||||||
from pytorch_lightning import seed_everything
|
from pytorch_lightning import seed_everything
|
||||||
|
|
||||||
from scripts.demo.streamlit_helpers import *
|
from scripts.demo.streamlit_helpers import *
|
||||||
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
|
|
||||||
from sgm.inference.helpers import (
|
|
||||||
do_img2img,
|
|
||||||
do_sample,
|
|
||||||
get_unique_embedder_keys_from_conditioner,
|
|
||||||
perform_save_locally,
|
|
||||||
)
|
|
||||||
|
|
||||||
SAVE_PATH = "outputs/demo/txt2img/"
|
SAVE_PATH = "outputs/demo/txt2img/"
|
||||||
|
|
||||||
@@ -42,7 +34,16 @@ SD_XL_BASE_RATIOS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
VERSION2SPECS = {
|
VERSION2SPECS = {
|
||||||
"SD-XL base": {
|
"SDXL-base-1.0": {
|
||||||
|
"H": 1024,
|
||||||
|
"W": 1024,
|
||||||
|
"C": 4,
|
||||||
|
"f": 8,
|
||||||
|
"is_legacy": False,
|
||||||
|
"config": "configs/inference/sd_xl_base.yaml",
|
||||||
|
"ckpt": "checkpoints/sd_xl_base_1.0.safetensors",
|
||||||
|
},
|
||||||
|
"SDXL-base-0.9": {
|
||||||
"H": 1024,
|
"H": 1024,
|
||||||
"W": 1024,
|
"W": 1024,
|
||||||
"C": 4,
|
"C": 4,
|
||||||
@@ -50,9 +51,8 @@ VERSION2SPECS = {
|
|||||||
"is_legacy": False,
|
"is_legacy": False,
|
||||||
"config": "configs/inference/sd_xl_base.yaml",
|
"config": "configs/inference/sd_xl_base.yaml",
|
||||||
"ckpt": "checkpoints/sd_xl_base_0.9.safetensors",
|
"ckpt": "checkpoints/sd_xl_base_0.9.safetensors",
|
||||||
"is_guided": True,
|
|
||||||
},
|
},
|
||||||
"sd-2.1": {
|
"SD-2.1": {
|
||||||
"H": 512,
|
"H": 512,
|
||||||
"W": 512,
|
"W": 512,
|
||||||
"C": 4,
|
"C": 4,
|
||||||
@@ -60,9 +60,8 @@ VERSION2SPECS = {
|
|||||||
"is_legacy": True,
|
"is_legacy": True,
|
||||||
"config": "configs/inference/sd_2_1.yaml",
|
"config": "configs/inference/sd_2_1.yaml",
|
||||||
"ckpt": "checkpoints/v2-1_512-ema-pruned.safetensors",
|
"ckpt": "checkpoints/v2-1_512-ema-pruned.safetensors",
|
||||||
"is_guided": True,
|
|
||||||
},
|
},
|
||||||
"sd-2.1-768": {
|
"SD-2.1-768": {
|
||||||
"H": 768,
|
"H": 768,
|
||||||
"W": 768,
|
"W": 768,
|
||||||
"C": 4,
|
"C": 4,
|
||||||
@@ -71,7 +70,7 @@ VERSION2SPECS = {
|
|||||||
"config": "configs/inference/sd_2_1_768.yaml",
|
"config": "configs/inference/sd_2_1_768.yaml",
|
||||||
"ckpt": "checkpoints/v2-1_768-ema-pruned.safetensors",
|
"ckpt": "checkpoints/v2-1_768-ema-pruned.safetensors",
|
||||||
},
|
},
|
||||||
"SDXL-Refiner": {
|
"SDXL-refiner-0.9": {
|
||||||
"H": 1024,
|
"H": 1024,
|
||||||
"W": 1024,
|
"W": 1024,
|
||||||
"C": 4,
|
"C": 4,
|
||||||
@@ -79,7 +78,15 @@ VERSION2SPECS = {
|
|||||||
"is_legacy": True,
|
"is_legacy": True,
|
||||||
"config": "configs/inference/sd_xl_refiner.yaml",
|
"config": "configs/inference/sd_xl_refiner.yaml",
|
||||||
"ckpt": "checkpoints/sd_xl_refiner_0.9.safetensors",
|
"ckpt": "checkpoints/sd_xl_refiner_0.9.safetensors",
|
||||||
"is_guided": True,
|
},
|
||||||
|
"SDXL-refiner-1.0": {
|
||||||
|
"H": 1024,
|
||||||
|
"W": 1024,
|
||||||
|
"C": 4,
|
||||||
|
"f": 8,
|
||||||
|
"is_legacy": True,
|
||||||
|
"config": "configs/inference/sd_xl_refiner.yaml",
|
||||||
|
"ckpt": "checkpoints/sd_xl_refiner_1.0.safetensors",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -103,18 +110,19 @@ def load_img(display=True, key=None, device="cuda"):
|
|||||||
|
|
||||||
|
|
||||||
def run_txt2img(
|
def run_txt2img(
|
||||||
state, version, version_dict, is_legacy=False, return_latents=False, filter=None
|
state,
|
||||||
|
version,
|
||||||
|
version_dict,
|
||||||
|
is_legacy=False,
|
||||||
|
return_latents=False,
|
||||||
|
filter=None,
|
||||||
|
stage2strength=None,
|
||||||
):
|
):
|
||||||
if version == "SD-XL base":
|
if version.startswith("SDXL-base"):
|
||||||
ratio = st.sidebar.selectbox("Ratio:", list(SD_XL_BASE_RATIOS.keys()), 10)
|
W, H = st.selectbox("Resolution:", list(SD_XL_BASE_RATIOS.values()), 10)
|
||||||
W, H = SD_XL_BASE_RATIOS[ratio]
|
|
||||||
else:
|
else:
|
||||||
H = st.sidebar.number_input(
|
H = st.number_input("H", value=version_dict["H"], min_value=64, max_value=2048)
|
||||||
"H", value=version_dict["H"], min_value=64, max_value=2048
|
W = st.number_input("W", value=version_dict["W"], min_value=64, max_value=2048)
|
||||||
)
|
|
||||||
W = st.sidebar.number_input(
|
|
||||||
"W", value=version_dict["W"], min_value=64, max_value=2048
|
|
||||||
)
|
|
||||||
C = version_dict["C"]
|
C = version_dict["C"]
|
||||||
F = version_dict["f"]
|
F = version_dict["f"]
|
||||||
|
|
||||||
@@ -130,16 +138,11 @@ def run_txt2img(
|
|||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
)
|
)
|
||||||
num_rows, num_cols, sampler = init_sampling(
|
sampler, num_rows, num_cols = init_sampling(stage2strength=stage2strength)
|
||||||
use_identity_guider=not version_dict["is_guided"]
|
|
||||||
)
|
|
||||||
|
|
||||||
num_samples = num_rows * num_cols
|
num_samples = num_rows * num_cols
|
||||||
|
|
||||||
if st.button("Sample"):
|
if st.button("Sample"):
|
||||||
st.write(f"**Model I:** {version}")
|
st.write(f"**Model I:** {version}")
|
||||||
outputs = st.empty()
|
|
||||||
st.text("Sampling")
|
|
||||||
out = do_sample(
|
out = do_sample(
|
||||||
state["model"],
|
state["model"],
|
||||||
sampler,
|
sampler,
|
||||||
@@ -153,13 +156,16 @@ def run_txt2img(
|
|||||||
return_latents=return_latents,
|
return_latents=return_latents,
|
||||||
filter=filter,
|
filter=filter,
|
||||||
)
|
)
|
||||||
show_samples(out, outputs)
|
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def run_img2img(
|
def run_img2img(
|
||||||
state, version_dict, is_legacy=False, return_latents=False, filter=None
|
state,
|
||||||
|
version_dict,
|
||||||
|
is_legacy=False,
|
||||||
|
return_latents=False,
|
||||||
|
filter=None,
|
||||||
|
stage2strength=None,
|
||||||
):
|
):
|
||||||
img = load_img()
|
img = load_img()
|
||||||
if img is None:
|
if img is None:
|
||||||
@@ -175,19 +181,19 @@ def run_img2img(
|
|||||||
value_dict = init_embedder_options(
|
value_dict = init_embedder_options(
|
||||||
get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
|
get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
|
||||||
init_dict,
|
init_dict,
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
)
|
)
|
||||||
strength = st.number_input(
|
strength = st.number_input(
|
||||||
"**Img2Img Strength**", value=0.5, min_value=0.0, max_value=1.0
|
"**Img2Img Strength**", value=0.75, min_value=0.0, max_value=1.0
|
||||||
)
|
)
|
||||||
num_rows, num_cols, sampler = init_sampling(
|
sampler, num_rows, num_cols = init_sampling(
|
||||||
img2img_strength=strength,
|
img2img_strength=strength,
|
||||||
use_identity_guider=not version_dict["is_guided"],
|
stage2strength=stage2strength,
|
||||||
)
|
)
|
||||||
num_samples = num_rows * num_cols
|
num_samples = num_rows * num_cols
|
||||||
|
|
||||||
if st.button("Sample"):
|
if st.button("Sample"):
|
||||||
outputs = st.empty()
|
|
||||||
st.text("Sampling")
|
|
||||||
out = do_img2img(
|
out = do_img2img(
|
||||||
repeat(img, "1 ... -> n ...", n=num_samples),
|
repeat(img, "1 ... -> n ...", n=num_samples),
|
||||||
state["model"],
|
state["model"],
|
||||||
@@ -198,7 +204,6 @@ def run_img2img(
|
|||||||
return_latents=return_latents,
|
return_latents=return_latents,
|
||||||
filter=filter,
|
filter=filter,
|
||||||
)
|
)
|
||||||
show_samples(out, outputs)
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
@@ -210,6 +215,7 @@ def apply_refiner(
|
|||||||
prompt,
|
prompt,
|
||||||
negative_prompt,
|
negative_prompt,
|
||||||
filter=None,
|
filter=None,
|
||||||
|
finish_denoising=False,
|
||||||
):
|
):
|
||||||
init_dict = {
|
init_dict = {
|
||||||
"orig_width": input.shape[3] * 8,
|
"orig_width": input.shape[3] * 8,
|
||||||
@@ -237,6 +243,7 @@ def apply_refiner(
|
|||||||
num_samples,
|
num_samples,
|
||||||
skip_encode=True,
|
skip_encode=True,
|
||||||
filter=filter,
|
filter=filter,
|
||||||
|
add_noise=not finish_denoising,
|
||||||
)
|
)
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
@@ -249,20 +256,22 @@ if __name__ == "__main__":
|
|||||||
mode = st.radio("Mode", ("txt2img", "img2img"), 0)
|
mode = st.radio("Mode", ("txt2img", "img2img"), 0)
|
||||||
st.write("__________________________")
|
st.write("__________________________")
|
||||||
|
|
||||||
if version == "SD-XL base":
|
set_lowvram_mode(st.checkbox("Low vram mode", True))
|
||||||
add_pipeline = st.checkbox("Load SDXL-Refiner?", False)
|
|
||||||
|
if version.startswith("SDXL-base"):
|
||||||
|
add_pipeline = st.checkbox("Load SDXL-refiner?", False)
|
||||||
st.write("__________________________")
|
st.write("__________________________")
|
||||||
else:
|
else:
|
||||||
add_pipeline = False
|
add_pipeline = False
|
||||||
|
|
||||||
filter = DeepFloydDataFiltering(verbose=False)
|
|
||||||
|
|
||||||
seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
|
seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
|
||||||
seed_everything(seed)
|
seed_everything(seed)
|
||||||
|
|
||||||
save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))
|
save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))
|
||||||
|
|
||||||
state = init_st(version_dict)
|
state = init_st(version_dict, load_filter=True)
|
||||||
|
if state["msg"]:
|
||||||
|
st.info(state["msg"])
|
||||||
model = state["model"]
|
model = state["model"]
|
||||||
|
|
||||||
is_legacy = version_dict["is_legacy"]
|
is_legacy = version_dict["is_legacy"]
|
||||||
@@ -276,29 +285,34 @@ if __name__ == "__main__":
|
|||||||
else:
|
else:
|
||||||
negative_prompt = "" # which is unused
|
negative_prompt = "" # which is unused
|
||||||
|
|
||||||
|
stage2strength = None
|
||||||
|
finish_denoising = False
|
||||||
|
|
||||||
if add_pipeline:
|
if add_pipeline:
|
||||||
st.write("__________________________")
|
st.write("__________________________")
|
||||||
|
version2 = st.selectbox("Refiner:", ["SDXL-refiner-1.0", "SDXL-refiner-0.9"])
|
||||||
version2 = "SDXL-Refiner"
|
|
||||||
st.warning(
|
st.warning(
|
||||||
f"Running with {version2} as the second stage model. Make sure to provide (V)RAM :) "
|
f"Running with {version2} as the second stage model. Make sure to provide (V)RAM :) "
|
||||||
)
|
)
|
||||||
st.write("**Refiner Options:**")
|
st.write("**Refiner Options:**")
|
||||||
|
|
||||||
version_dict2 = VERSION2SPECS[version2]
|
version_dict2 = VERSION2SPECS[version2]
|
||||||
state2 = init_st(version_dict2)
|
state2 = init_st(version_dict2, load_filter=False)
|
||||||
|
st.info(state2["msg"])
|
||||||
|
|
||||||
stage2strength = st.number_input(
|
stage2strength = st.number_input(
|
||||||
"**Refinement strength**", value=0.3, min_value=0.0, max_value=1.0
|
"**Refinement strength**", value=0.15, min_value=0.0, max_value=1.0
|
||||||
)
|
)
|
||||||
|
|
||||||
sampler2 = init_sampling(
|
sampler2, *_ = init_sampling(
|
||||||
key=2,
|
key=2,
|
||||||
img2img_strength=stage2strength,
|
img2img_strength=stage2strength,
|
||||||
use_identity_guider=not version_dict2["is_guided"],
|
specify_num_samples=False,
|
||||||
get_num_samples=False,
|
|
||||||
)
|
)
|
||||||
st.write("__________________________")
|
st.write("__________________________")
|
||||||
|
finish_denoising = st.checkbox("Finish denoising with refiner.", True)
|
||||||
|
if not finish_denoising:
|
||||||
|
stage2strength = None
|
||||||
|
|
||||||
if mode == "txt2img":
|
if mode == "txt2img":
|
||||||
out = run_txt2img(
|
out = run_txt2img(
|
||||||
@@ -307,7 +321,8 @@ if __name__ == "__main__":
|
|||||||
version_dict,
|
version_dict,
|
||||||
is_legacy=is_legacy,
|
is_legacy=is_legacy,
|
||||||
return_latents=add_pipeline,
|
return_latents=add_pipeline,
|
||||||
filter=filter,
|
filter=state.get("filter"),
|
||||||
|
stage2strength=stage2strength,
|
||||||
)
|
)
|
||||||
elif mode == "img2img":
|
elif mode == "img2img":
|
||||||
out = run_img2img(
|
out = run_img2img(
|
||||||
@@ -315,7 +330,8 @@ if __name__ == "__main__":
|
|||||||
version_dict,
|
version_dict,
|
||||||
is_legacy=is_legacy,
|
is_legacy=is_legacy,
|
||||||
return_latents=add_pipeline,
|
return_latents=add_pipeline,
|
||||||
filter=filter,
|
filter=state.get("filter"),
|
||||||
|
stage2strength=stage2strength,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"unknown mode {mode}")
|
raise ValueError(f"unknown mode {mode}")
|
||||||
@@ -326,7 +342,6 @@ if __name__ == "__main__":
|
|||||||
samples_z = None
|
samples_z = None
|
||||||
|
|
||||||
if add_pipeline and samples_z is not None:
|
if add_pipeline and samples_z is not None:
|
||||||
outputs = st.empty()
|
|
||||||
st.write("**Running Refinement Stage**")
|
st.write("**Running Refinement Stage**")
|
||||||
samples = apply_refiner(
|
samples = apply_refiner(
|
||||||
samples_z,
|
samples_z,
|
||||||
@@ -335,9 +350,9 @@ if __name__ == "__main__":
|
|||||||
samples_z.shape[0],
|
samples_z.shape[0],
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
negative_prompt=negative_prompt if is_legacy else "",
|
negative_prompt=negative_prompt if is_legacy else "",
|
||||||
filter=filter,
|
filter=state.get("filter"),
|
||||||
|
finish_denoising=finish_denoising,
|
||||||
)
|
)
|
||||||
show_samples(samples, outputs)
|
|
||||||
|
|
||||||
if save_locally and samples is not None:
|
if save_locally and samples is not None:
|
||||||
perform_save_locally(save_path, samples)
|
perform_save_locally(save_path, samples)
|
||||||
|
|||||||
@@ -1,12 +1,20 @@
|
|||||||
|
import math
|
||||||
import os
|
import os
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from omegaconf import OmegaConf
|
from imwatermark import WatermarkEncoder
|
||||||
|
from omegaconf import ListConfig, OmegaConf
|
||||||
|
from PIL import Image
|
||||||
|
from safetensors.torch import load_file as load_safetensors
|
||||||
|
from torch import autocast
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
|
from torchvision.utils import make_grid
|
||||||
|
|
||||||
|
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
|
||||||
from sgm.modules.diffusionmodules.sampling import (
|
from sgm.modules.diffusionmodules.sampling import (
|
||||||
DPMPP2MSampler,
|
DPMPP2MSampler,
|
||||||
DPMPP2SAncestralSampler,
|
DPMPP2SAncestralSampler,
|
||||||
@@ -15,29 +23,140 @@ from sgm.modules.diffusionmodules.sampling import (
|
|||||||
HeunEDMSampler,
|
HeunEDMSampler,
|
||||||
LinearMultistepSampler,
|
LinearMultistepSampler,
|
||||||
)
|
)
|
||||||
from sgm.inference.helpers import Img2ImgDiscretizationWrapper, embed_watermark
|
from sgm.util import append_dims, instantiate_from_config
|
||||||
from sgm.util import load_model_from_config
|
|
||||||
|
|
||||||
|
class WatermarkEmbedder:
|
||||||
|
def __init__(self, watermark):
|
||||||
|
self.watermark = watermark
|
||||||
|
self.num_bits = len(WATERMARK_BITS)
|
||||||
|
self.encoder = WatermarkEncoder()
|
||||||
|
self.encoder.set_watermark("bits", self.watermark)
|
||||||
|
|
||||||
|
def __call__(self, image: torch.Tensor):
|
||||||
|
"""
|
||||||
|
Adds a predefined watermark to the input image
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: ([N,] B, C, H, W) in range [0, 1]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
same as input but watermarked
|
||||||
|
"""
|
||||||
|
# watermarking libary expects input as cv2 BGR format
|
||||||
|
squeeze = len(image.shape) == 4
|
||||||
|
if squeeze:
|
||||||
|
image = image[None, ...]
|
||||||
|
n = image.shape[0]
|
||||||
|
image_np = rearrange(
|
||||||
|
(255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
|
||||||
|
).numpy()[:, :, :, ::-1]
|
||||||
|
# torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
|
||||||
|
for k in range(image_np.shape[0]):
|
||||||
|
image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
|
||||||
|
image = torch.from_numpy(
|
||||||
|
rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)
|
||||||
|
).to(image.device)
|
||||||
|
image = torch.clamp(image / 255, min=0.0, max=1.0)
|
||||||
|
if squeeze:
|
||||||
|
image = image[0]
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
# A fixed 48-bit message that was choosen at random
|
||||||
|
# WATERMARK_MESSAGE = 0xB3EC907BB19E
|
||||||
|
WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
|
||||||
|
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
|
||||||
|
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
|
||||||
|
embed_watemark = WatermarkEmbedder(WATERMARK_BITS)
|
||||||
|
|
||||||
|
|
||||||
@st.cache_resource()
|
@st.cache_resource()
|
||||||
def init_st(version_dict, load_ckpt=True):
|
def init_st(version_dict, load_ckpt=True, load_filter=True):
|
||||||
state = dict()
|
state = dict()
|
||||||
if not "model" in state:
|
if not "model" in state:
|
||||||
config = version_dict["config"]
|
config = version_dict["config"]
|
||||||
ckpt = version_dict["ckpt"]
|
ckpt = version_dict["ckpt"]
|
||||||
|
|
||||||
config = OmegaConf.load(config)
|
config = OmegaConf.load(config)
|
||||||
model = load_model_from_config(config, ckpt if load_ckpt else None)
|
model, msg = load_model_from_config(config, ckpt if load_ckpt else None)
|
||||||
model = model.to("cuda")
|
|
||||||
model.conditioner.half()
|
|
||||||
model.model.half()
|
|
||||||
|
|
||||||
|
state["msg"] = msg
|
||||||
state["model"] = model
|
state["model"] = model
|
||||||
state["ckpt"] = ckpt if load_ckpt else None
|
state["ckpt"] = ckpt if load_ckpt else None
|
||||||
state["config"] = config
|
state["config"] = config
|
||||||
|
if load_filter:
|
||||||
|
state["filter"] = DeepFloydDataFiltering(verbose=False)
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(model):
|
||||||
|
model.cuda()
|
||||||
|
|
||||||
|
|
||||||
|
lowvram_mode = False
|
||||||
|
|
||||||
|
|
||||||
|
def set_lowvram_mode(mode):
|
||||||
|
global lowvram_mode
|
||||||
|
lowvram_mode = mode
|
||||||
|
|
||||||
|
|
||||||
|
def initial_model_load(model):
|
||||||
|
global lowvram_mode
|
||||||
|
if lowvram_mode:
|
||||||
|
model.model.half()
|
||||||
|
else:
|
||||||
|
model.cuda()
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def unload_model(model):
|
||||||
|
global lowvram_mode
|
||||||
|
if lowvram_mode:
|
||||||
|
model.cpu()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_from_config(config, ckpt=None, verbose=True):
|
||||||
|
model = instantiate_from_config(config.model)
|
||||||
|
|
||||||
|
if ckpt is not None:
|
||||||
|
print(f"Loading model from {ckpt}")
|
||||||
|
if ckpt.endswith("ckpt"):
|
||||||
|
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||||
|
if "global_step" in pl_sd:
|
||||||
|
global_step = pl_sd["global_step"]
|
||||||
|
st.info(f"loaded ckpt from global step {global_step}")
|
||||||
|
print(f"Global Step: {pl_sd['global_step']}")
|
||||||
|
sd = pl_sd["state_dict"]
|
||||||
|
elif ckpt.endswith("safetensors"):
|
||||||
|
sd = load_safetensors(ckpt)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
msg = None
|
||||||
|
|
||||||
|
m, u = model.load_state_dict(sd, strict=False)
|
||||||
|
|
||||||
|
if len(m) > 0 and verbose:
|
||||||
|
print("missing keys:")
|
||||||
|
print(m)
|
||||||
|
if len(u) > 0 and verbose:
|
||||||
|
print("unexpected keys:")
|
||||||
|
print(u)
|
||||||
|
else:
|
||||||
|
msg = None
|
||||||
|
|
||||||
|
model = initial_model_load(model)
|
||||||
|
model.eval()
|
||||||
|
return model, msg
|
||||||
|
|
||||||
|
|
||||||
|
def get_unique_embedder_keys_from_conditioner(conditioner):
|
||||||
|
return list(set([x.input_key for x in conditioner.embedders]))
|
||||||
|
|
||||||
|
|
||||||
def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
|
def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
|
||||||
# Hardcoded demo settings; might undergo some changes in the future
|
# Hardcoded demo settings; might undergo some changes in the future
|
||||||
|
|
||||||
@@ -81,23 +200,24 @@ def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
|
|||||||
value_dict["negative_aesthetic_score"] = 2.5
|
value_dict["negative_aesthetic_score"] = 2.5
|
||||||
|
|
||||||
if key == "target_size_as_tuple":
|
if key == "target_size_as_tuple":
|
||||||
target_width = st.number_input(
|
value_dict["target_width"] = init_dict["target_width"]
|
||||||
"target_width",
|
value_dict["target_height"] = init_dict["target_height"]
|
||||||
value=init_dict["target_width"],
|
|
||||||
min_value=16,
|
|
||||||
)
|
|
||||||
target_height = st.number_input(
|
|
||||||
"target_height",
|
|
||||||
value=init_dict["target_height"],
|
|
||||||
min_value=16,
|
|
||||||
)
|
|
||||||
|
|
||||||
value_dict["target_width"] = target_width
|
|
||||||
value_dict["target_height"] = target_height
|
|
||||||
|
|
||||||
return value_dict
|
return value_dict
|
||||||
|
|
||||||
|
|
||||||
|
def perform_save_locally(save_path, samples):
|
||||||
|
os.makedirs(os.path.join(save_path), exist_ok=True)
|
||||||
|
base_count = len(os.listdir(os.path.join(save_path)))
|
||||||
|
samples = embed_watemark(samples)
|
||||||
|
for sample in samples:
|
||||||
|
sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
|
||||||
|
Image.fromarray(sample.astype(np.uint8)).save(
|
||||||
|
os.path.join(save_path, f"{base_count:09}.png")
|
||||||
|
)
|
||||||
|
base_count += 1
|
||||||
|
|
||||||
|
|
||||||
def init_save_locally(_dir, init_value: bool = False):
|
def init_save_locally(_dir, init_value: bool = False):
|
||||||
save_locally = st.sidebar.checkbox("Save images locally", value=init_value)
|
save_locally = st.sidebar.checkbox("Save images locally", value=init_value)
|
||||||
if save_locally:
|
if save_locally:
|
||||||
@@ -108,12 +228,58 @@ def init_save_locally(_dir, init_value: bool = False):
|
|||||||
return save_locally, save_path
|
return save_locally, save_path
|
||||||
|
|
||||||
|
|
||||||
def show_samples(samples, outputs):
|
class Img2ImgDiscretizationWrapper:
|
||||||
if isinstance(samples, tuple):
|
"""
|
||||||
samples, _ = samples
|
wraps a discretizer, and prunes the sigmas
|
||||||
grid = embed_watermark(torch.stack([samples]))
|
params:
|
||||||
grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
|
strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
|
||||||
outputs.image(grid.cpu().numpy())
|
"""
|
||||||
|
|
||||||
|
def __init__(self, discretization, strength: float = 1.0):
|
||||||
|
self.discretization = discretization
|
||||||
|
self.strength = strength
|
||||||
|
assert 0.0 <= self.strength <= 1.0
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
# sigmas start large first, and decrease then
|
||||||
|
sigmas = self.discretization(*args, **kwargs)
|
||||||
|
print(f"sigmas after discretization, before pruning img2img: ", sigmas)
|
||||||
|
sigmas = torch.flip(sigmas, (0,))
|
||||||
|
sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
|
||||||
|
print("prune index:", max(int(self.strength * len(sigmas)), 1))
|
||||||
|
sigmas = torch.flip(sigmas, (0,))
|
||||||
|
print(f"sigmas after pruning: ", sigmas)
|
||||||
|
return sigmas
|
||||||
|
|
||||||
|
|
||||||
|
class Txt2NoisyDiscretizationWrapper:
|
||||||
|
"""
|
||||||
|
wraps a discretizer, and prunes the sigmas
|
||||||
|
params:
|
||||||
|
strength: float between 0.0 and 1.0. 0.0 means full sampling (all sigmas are returned)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, discretization, strength: float = 0.0, original_steps=None):
|
||||||
|
self.discretization = discretization
|
||||||
|
self.strength = strength
|
||||||
|
self.original_steps = original_steps
|
||||||
|
assert 0.0 <= self.strength <= 1.0
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
# sigmas start large first, and decrease then
|
||||||
|
sigmas = self.discretization(*args, **kwargs)
|
||||||
|
print(f"sigmas after discretization, before pruning img2img: ", sigmas)
|
||||||
|
sigmas = torch.flip(sigmas, (0,))
|
||||||
|
if self.original_steps is None:
|
||||||
|
steps = len(sigmas)
|
||||||
|
else:
|
||||||
|
steps = self.original_steps + 1
|
||||||
|
prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0)
|
||||||
|
sigmas = sigmas[prune_index:]
|
||||||
|
print("prune index:", prune_index)
|
||||||
|
sigmas = torch.flip(sigmas, (0,))
|
||||||
|
print(f"sigmas after pruning: ", sigmas)
|
||||||
|
return sigmas
|
||||||
|
|
||||||
|
|
||||||
def get_guider(key):
|
def get_guider(key):
|
||||||
@@ -158,16 +324,19 @@ def get_guider(key):
|
|||||||
|
|
||||||
|
|
||||||
def init_sampling(
|
def init_sampling(
|
||||||
key=1, img2img_strength=1.0, use_identity_guider=False, get_num_samples=True
|
key=1,
|
||||||
|
img2img_strength=1.0,
|
||||||
|
specify_num_samples=True,
|
||||||
|
stage2strength=None,
|
||||||
):
|
):
|
||||||
if get_num_samples:
|
num_rows, num_cols = 1, 1
|
||||||
num_rows = 1
|
if specify_num_samples:
|
||||||
num_cols = st.number_input(
|
num_cols = st.number_input(
|
||||||
f"num cols #{key}", value=2, min_value=1, max_value=10
|
f"num cols #{key}", value=2, min_value=1, max_value=10
|
||||||
)
|
)
|
||||||
|
|
||||||
steps = st.sidebar.number_input(
|
steps = st.sidebar.number_input(
|
||||||
f"steps #{key}", value=50, min_value=1, max_value=1000
|
f"steps #{key}", value=40, min_value=1, max_value=1000
|
||||||
)
|
)
|
||||||
sampler = st.sidebar.selectbox(
|
sampler = st.sidebar.selectbox(
|
||||||
f"Sampler #{key}",
|
f"Sampler #{key}",
|
||||||
@@ -201,9 +370,11 @@ def init_sampling(
|
|||||||
sampler.discretization = Img2ImgDiscretizationWrapper(
|
sampler.discretization = Img2ImgDiscretizationWrapper(
|
||||||
sampler.discretization, strength=img2img_strength
|
sampler.discretization, strength=img2img_strength
|
||||||
)
|
)
|
||||||
if get_num_samples:
|
if stage2strength is not None:
|
||||||
return num_rows, num_cols, sampler
|
sampler.discretization = Txt2NoisyDiscretizationWrapper(
|
||||||
return sampler
|
sampler.discretization, strength=stage2strength, original_steps=steps
|
||||||
|
)
|
||||||
|
return sampler, num_rows, num_cols
|
||||||
|
|
||||||
|
|
||||||
def get_discretization(discretization, key=1):
|
def get_discretization(discretization, key=1):
|
||||||
@@ -336,3 +507,238 @@ def get_init_img(batch_size=1, key=None):
|
|||||||
init_image = load_img(key=key).cuda()
|
init_image = load_img(key=key).cuda()
|
||||||
init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
|
init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
|
||||||
return init_image
|
return init_image
|
||||||
|
|
||||||
|
|
||||||
|
def do_sample(
|
||||||
|
model,
|
||||||
|
sampler,
|
||||||
|
value_dict,
|
||||||
|
num_samples,
|
||||||
|
H,
|
||||||
|
W,
|
||||||
|
C,
|
||||||
|
F,
|
||||||
|
force_uc_zero_embeddings: List = None,
|
||||||
|
batch2model_input: List = None,
|
||||||
|
return_latents=False,
|
||||||
|
filter=None,
|
||||||
|
):
|
||||||
|
if force_uc_zero_embeddings is None:
|
||||||
|
force_uc_zero_embeddings = []
|
||||||
|
if batch2model_input is None:
|
||||||
|
batch2model_input = []
|
||||||
|
|
||||||
|
st.text("Sampling")
|
||||||
|
|
||||||
|
outputs = st.empty()
|
||||||
|
precision_scope = autocast
|
||||||
|
with torch.no_grad():
|
||||||
|
with precision_scope("cuda"):
|
||||||
|
with model.ema_scope():
|
||||||
|
num_samples = [num_samples]
|
||||||
|
load_model(model.conditioner)
|
||||||
|
batch, batch_uc = get_batch(
|
||||||
|
get_unique_embedder_keys_from_conditioner(model.conditioner),
|
||||||
|
value_dict,
|
||||||
|
num_samples,
|
||||||
|
)
|
||||||
|
for key in batch:
|
||||||
|
if isinstance(batch[key], torch.Tensor):
|
||||||
|
print(key, batch[key].shape)
|
||||||
|
elif isinstance(batch[key], list):
|
||||||
|
print(key, [len(l) for l in batch[key]])
|
||||||
|
else:
|
||||||
|
print(key, batch[key])
|
||||||
|
c, uc = model.conditioner.get_unconditional_conditioning(
|
||||||
|
batch,
|
||||||
|
batch_uc=batch_uc,
|
||||||
|
force_uc_zero_embeddings=force_uc_zero_embeddings,
|
||||||
|
)
|
||||||
|
unload_model(model.conditioner)
|
||||||
|
|
||||||
|
for k in c:
|
||||||
|
if not k == "crossattn":
|
||||||
|
c[k], uc[k] = map(
|
||||||
|
lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc)
|
||||||
|
)
|
||||||
|
|
||||||
|
additional_model_inputs = {}
|
||||||
|
for k in batch2model_input:
|
||||||
|
additional_model_inputs[k] = batch[k]
|
||||||
|
|
||||||
|
shape = (math.prod(num_samples), C, H // F, W // F)
|
||||||
|
randn = torch.randn(shape).to("cuda")
|
||||||
|
|
||||||
|
def denoiser(input, sigma, c):
|
||||||
|
return model.denoiser(
|
||||||
|
model.model, input, sigma, c, **additional_model_inputs
|
||||||
|
)
|
||||||
|
|
||||||
|
load_model(model.denoiser)
|
||||||
|
load_model(model.model)
|
||||||
|
samples_z = sampler(denoiser, randn, cond=c, uc=uc)
|
||||||
|
unload_model(model.model)
|
||||||
|
unload_model(model.denoiser)
|
||||||
|
|
||||||
|
load_model(model.first_stage_model)
|
||||||
|
samples_x = model.decode_first_stage(samples_z)
|
||||||
|
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
unload_model(model.first_stage_model)
|
||||||
|
|
||||||
|
if filter is not None:
|
||||||
|
samples = filter(samples)
|
||||||
|
|
||||||
|
grid = torch.stack([samples])
|
||||||
|
grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
|
||||||
|
outputs.image(grid.cpu().numpy())
|
||||||
|
|
||||||
|
if return_latents:
|
||||||
|
return samples, samples_z
|
||||||
|
return samples
|
||||||
|
|
||||||
|
|
||||||
|
def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
|
||||||
|
# Hardcoded demo setups; might undergo some changes in the future
|
||||||
|
|
||||||
|
batch = {}
|
||||||
|
batch_uc = {}
|
||||||
|
|
||||||
|
for key in keys:
|
||||||
|
if key == "txt":
|
||||||
|
batch["txt"] = (
|
||||||
|
np.repeat([value_dict["prompt"]], repeats=math.prod(N))
|
||||||
|
.reshape(N)
|
||||||
|
.tolist()
|
||||||
|
)
|
||||||
|
batch_uc["txt"] = (
|
||||||
|
np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
|
||||||
|
.reshape(N)
|
||||||
|
.tolist()
|
||||||
|
)
|
||||||
|
elif key == "original_size_as_tuple":
|
||||||
|
batch["original_size_as_tuple"] = (
|
||||||
|
torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
|
||||||
|
.to(device)
|
||||||
|
.repeat(*N, 1)
|
||||||
|
)
|
||||||
|
elif key == "crop_coords_top_left":
|
||||||
|
batch["crop_coords_top_left"] = (
|
||||||
|
torch.tensor(
|
||||||
|
[value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
|
||||||
|
)
|
||||||
|
.to(device)
|
||||||
|
.repeat(*N, 1)
|
||||||
|
)
|
||||||
|
elif key == "aesthetic_score":
|
||||||
|
batch["aesthetic_score"] = (
|
||||||
|
torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
|
||||||
|
)
|
||||||
|
batch_uc["aesthetic_score"] = (
|
||||||
|
torch.tensor([value_dict["negative_aesthetic_score"]])
|
||||||
|
.to(device)
|
||||||
|
.repeat(*N, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif key == "target_size_as_tuple":
|
||||||
|
batch["target_size_as_tuple"] = (
|
||||||
|
torch.tensor([value_dict["target_height"], value_dict["target_width"]])
|
||||||
|
.to(device)
|
||||||
|
.repeat(*N, 1)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
batch[key] = value_dict[key]
|
||||||
|
|
||||||
|
for key in batch.keys():
|
||||||
|
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
|
||||||
|
batch_uc[key] = torch.clone(batch[key])
|
||||||
|
return batch, batch_uc
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def do_img2img(
|
||||||
|
img,
|
||||||
|
model,
|
||||||
|
sampler,
|
||||||
|
value_dict,
|
||||||
|
num_samples,
|
||||||
|
force_uc_zero_embeddings=[],
|
||||||
|
additional_kwargs={},
|
||||||
|
offset_noise_level: int = 0.0,
|
||||||
|
return_latents=False,
|
||||||
|
skip_encode=False,
|
||||||
|
filter=None,
|
||||||
|
add_noise=True,
|
||||||
|
):
|
||||||
|
st.text("Sampling")
|
||||||
|
|
||||||
|
outputs = st.empty()
|
||||||
|
precision_scope = autocast
|
||||||
|
with torch.no_grad():
|
||||||
|
with precision_scope("cuda"):
|
||||||
|
with model.ema_scope():
|
||||||
|
load_model(model.conditioner)
|
||||||
|
batch, batch_uc = get_batch(
|
||||||
|
get_unique_embedder_keys_from_conditioner(model.conditioner),
|
||||||
|
value_dict,
|
||||||
|
[num_samples],
|
||||||
|
)
|
||||||
|
c, uc = model.conditioner.get_unconditional_conditioning(
|
||||||
|
batch,
|
||||||
|
batch_uc=batch_uc,
|
||||||
|
force_uc_zero_embeddings=force_uc_zero_embeddings,
|
||||||
|
)
|
||||||
|
unload_model(model.conditioner)
|
||||||
|
for k in c:
|
||||||
|
c[k], uc[k] = map(lambda y: y[k][:num_samples].to("cuda"), (c, uc))
|
||||||
|
|
||||||
|
for k in additional_kwargs:
|
||||||
|
c[k] = uc[k] = additional_kwargs[k]
|
||||||
|
if skip_encode:
|
||||||
|
z = img
|
||||||
|
else:
|
||||||
|
load_model(model.first_stage_model)
|
||||||
|
z = model.encode_first_stage(img)
|
||||||
|
unload_model(model.first_stage_model)
|
||||||
|
|
||||||
|
noise = torch.randn_like(z)
|
||||||
|
|
||||||
|
sigmas = sampler.discretization(sampler.num_steps).cuda()
|
||||||
|
sigma = sigmas[0]
|
||||||
|
|
||||||
|
st.info(f"all sigmas: {sigmas}")
|
||||||
|
st.info(f"noising sigma: {sigma}")
|
||||||
|
if offset_noise_level > 0.0:
|
||||||
|
noise = noise + offset_noise_level * append_dims(
|
||||||
|
torch.randn(z.shape[0], device=z.device), z.ndim
|
||||||
|
)
|
||||||
|
if add_noise:
|
||||||
|
noised_z = z + noise * append_dims(sigma, z.ndim).cuda()
|
||||||
|
noised_z = noised_z / torch.sqrt(
|
||||||
|
1.0 + sigmas[0] ** 2.0
|
||||||
|
) # Note: hardcoded to DDPM-like scaling. need to generalize later.
|
||||||
|
else:
|
||||||
|
noised_z = z / torch.sqrt(1.0 + sigmas[0] ** 2.0)
|
||||||
|
|
||||||
|
def denoiser(x, sigma, c):
|
||||||
|
return model.denoiser(model.model, x, sigma, c)
|
||||||
|
|
||||||
|
load_model(model.denoiser)
|
||||||
|
load_model(model.model)
|
||||||
|
samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
|
||||||
|
unload_model(model.model)
|
||||||
|
unload_model(model.denoiser)
|
||||||
|
|
||||||
|
load_model(model.first_stage_model)
|
||||||
|
samples_x = model.decode_first_stage(samples_z)
|
||||||
|
unload_model(model.first_stage_model)
|
||||||
|
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
|
||||||
|
if filter is not None:
|
||||||
|
samples = filter(samples)
|
||||||
|
|
||||||
|
grid = embed_watemark(torch.stack([samples]))
|
||||||
|
grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
|
||||||
|
outputs.image(grid.cpu().numpy())
|
||||||
|
if return_latents:
|
||||||
|
return samples, samples_z
|
||||||
|
return samples
|
||||||
|
|||||||
Reference in New Issue
Block a user