Mirror of https://github.com/Stability-AI/generative-models.git
Synced 2026-02-04 13:24:28 +01:00

Compare commits: devcontain… → helpers-fi… (58 commits)
Commits:

7ef5489cea, 477d8b9a77, e289621992, 2fc4680bf9, e32972b85b, 65c6ec1cec,
5fde7e73b8, fbe93fc53b, c0655731d5, f6704532a0, 98c4b7753b, d4307bef5d,
fe4632034b, d6f2b78994, cd81956241, 5c17043434, 2aebc8882d, 3816aaa639,
88395261d8, b3866d1218, a25662e969, 26b10f56f3, 3e7ada70c5, de7a627978,
9b18e6fa19, 47805f233c, e190ecc60b, fc498bfaef, 8011d54ca1, b51c36b0df,
d245e2002f, 725bea9f75, a009aa8a9f, f86ffac274, a726ce3eb7, c4b7baf896,
7e7fee3f0f, 49fe53c165, 6c18c8443a, ced97f0e84, 76ca428422, 8f8757b4ff,
f2fba1dfa2, 451c76ada1, 0c2c5c66a2, ea5f232d5d, f06c67c206, b216934b7e,
77d0e27747, 4aea6fa2a4, 84d3a7f6f5, 19fa4da3de, 4e2236f67d, baf79d2d79,
44943df4f2, 73287ec3a3, 853adb4022, 45feb6cb9c
.github/workflows/test-inference.yml (vendored, 2 changed lines)

```diff
@@ -15,7 +15,7 @@ jobs:
     steps:
       - uses: actions/checkout@v3
       - name: "Symlink checkpoints"
-        run: ln -s ${{vars.SGM_CHECKPOINTS_PATH}} checkpoints
+        run: ln -s $SGM_CHECKPOINTS checkpoints
       - name: "Setup python"
         uses: actions/setup-python@v4
         with:
```
pyproject.toml

```diff
@@ -44,5 +44,5 @@ dependencies = [
 test-inference = [
   "pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 --index-url https://download.pytorch.org/whl/cu118",
   "pip install -r requirements/pt2.txt",
-  "pytest -v tests/inference/test_inference.py {args}",
+  "pytest -v tests/inference {args}",
 ]
```
scripts/demo/sampling.py

```diff
@@ -1,6 +1,30 @@
 import os

 import numpy as np
 import streamlit as st
 import torch
 from einops import repeat
 from pytorch_lightning import seed_everything

-from scripts.demo.streamlit_helpers import *
+from sgm.inference.api import (
+    SamplingSpec,
+    SamplingParams,
+    ModelArchitecture,
+    SamplingPipeline,
+    model_specs,
+)
+from sgm.inference.helpers import (
+    get_unique_embedder_keys_from_conditioner,
+    perform_save_locally,
+)
+from scripts.demo.streamlit_helpers import (
+    get_interactive_image,
+    init_embedder_options,
+    init_sampling,
+    init_save_locally,
+    init_st,
+    show_samples,
+)

 SAVE_PATH = "outputs/demo/txt2img/"
```
```diff
@@ -33,63 +57,6 @@ SD_XL_BASE_RATIOS = {
     "3.0": (1728, 576),
 }

-VERSION2SPECS = {
-    "SDXL-base-1.0": {
-        "H": 1024,
-        "W": 1024,
-        "C": 4,
-        "f": 8,
-        "is_legacy": False,
-        "config": "configs/inference/sd_xl_base.yaml",
-        "ckpt": "checkpoints/sd_xl_base_1.0.safetensors",
-    },
-    "SDXL-base-0.9": {
-        "H": 1024,
-        "W": 1024,
-        "C": 4,
-        "f": 8,
-        "is_legacy": False,
-        "config": "configs/inference/sd_xl_base.yaml",
-        "ckpt": "checkpoints/sd_xl_base_0.9.safetensors",
-    },
-    "SD-2.1": {
-        "H": 512,
-        "W": 512,
-        "C": 4,
-        "f": 8,
-        "is_legacy": True,
-        "config": "configs/inference/sd_2_1.yaml",
-        "ckpt": "checkpoints/v2-1_512-ema-pruned.safetensors",
-    },
-    "SD-2.1-768": {
-        "H": 768,
-        "W": 768,
-        "C": 4,
-        "f": 8,
-        "is_legacy": True,
-        "config": "configs/inference/sd_2_1_768.yaml",
-        "ckpt": "checkpoints/v2-1_768-ema-pruned.safetensors",
-    },
-    "SDXL-refiner-0.9": {
-        "H": 1024,
-        "W": 1024,
-        "C": 4,
-        "f": 8,
-        "is_legacy": True,
-        "config": "configs/inference/sd_xl_refiner.yaml",
-        "ckpt": "checkpoints/sd_xl_refiner_0.9.safetensors",
-    },
-    "SDXL-refiner-1.0": {
-        "H": 1024,
-        "W": 1024,
-        "C": 4,
-        "f": 8,
-        "is_legacy": True,
-        "config": "configs/inference/sd_xl_refiner.yaml",
-        "ckpt": "checkpoints/sd_xl_refiner_1.0.safetensors",
-    },
-}
-
-
 def load_img(display=True, key=None, device="cuda"):
     image = get_interactive_image(key=key)
```
```diff
@@ -111,170 +78,181 @@ def load_img(display=True, key=None, device="cuda"):

 def run_txt2img(
     state,
-    version,
-    version_dict,
-    is_legacy=False,
+    model_id: ModelArchitecture,
+    prompt: str,
+    negative_prompt: str,
     return_latents=False,
-    filter=None,
     stage2strength=None,
 ):
-    if version.startswith("SDXL-base"):
-        W, H = st.selectbox("Resolution:", list(SD_XL_BASE_RATIOS.values()), 10)
+    model: SamplingPipeline = state["model"]
+    params: SamplingParams = state["params"]
+    if model_id in sdxl_base_model_list:
+        width, height = st.selectbox(
+            "Resolution:", list(SD_XL_BASE_RATIOS.values()), 10
+        )
     else:
-        H = st.number_input("H", value=version_dict["H"], min_value=64, max_value=2048)
-        W = st.number_input("W", value=version_dict["W"], min_value=64, max_value=2048)
-    C = version_dict["C"]
-    F = version_dict["f"]
+        height = int(
+            st.number_input("H", value=params.height, min_value=64, max_value=2048)
+        )
+        width = int(
+            st.number_input("W", value=params.width, min_value=64, max_value=2048)
+        )

-    init_dict = {
-        "orig_width": W,
-        "orig_height": H,
-        "target_width": W,
-        "target_height": H,
-    }
-    value_dict = init_embedder_options(
-        get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
-        init_dict,
+    params = init_embedder_options(
+        get_unique_embedder_keys_from_conditioner(model.model.conditioner),
+        params=params,
         prompt=prompt,
         negative_prompt=negative_prompt,
     )
-    sampler, num_rows, num_cols = init_sampling(stage2strength=stage2strength)
+    params, num_rows, num_cols = init_sampling(params=params)
     num_samples = num_rows * num_cols
+    params.height = height
+    params.width = width

     if st.button("Sample"):
-        st.write(f"**Model I:** {version}")
-        out = do_sample(
-            state["model"],
-            sampler,
-            value_dict,
-            num_samples,
-            H,
-            W,
-            C,
-            F,
-            force_uc_zero_embeddings=["txt"] if not is_legacy else [],
+        outputs = st.empty()
+        st.text("Sampling")
+        out = model.text_to_image(
+            params=params,
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            samples=int(num_samples),
             return_latents=return_latents,
-            filter=filter,
+            noise_strength=stage2strength,
+            filter=state["filter"],
         )

+        show_samples(out, outputs)
+
         return out


 def run_img2img(
     state,
-    version_dict,
-    is_legacy=False,
+    prompt: str,
+    negative_prompt: str,
     return_latents=False,
-    filter=None,
     stage2strength=None,
 ):
+    model: SamplingPipeline = state["model"]
+    params: SamplingParams = state["params"]
+
     img = load_img()
     if img is None:
         return None
-    H, W = img.shape[2], img.shape[3]
+    params.height, params.width = img.shape[2], img.shape[3]

-    init_dict = {
-        "orig_width": W,
-        "orig_height": H,
-        "target_width": W,
-        "target_height": H,
-    }
-    value_dict = init_embedder_options(
-        get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
-        init_dict,
+    params = init_embedder_options(
+        get_unique_embedder_keys_from_conditioner(model.model.conditioner),
+        params=params,
         prompt=prompt,
         negative_prompt=negative_prompt,
     )
-    strength = st.number_input(
+    params.img2img_strength = st.number_input(
         "**Img2Img Strength**", value=0.75, min_value=0.0, max_value=1.0
     )
-    sampler, num_rows, num_cols = init_sampling(
-        img2img_strength=strength,
-        stage2strength=stage2strength,
-    )
+    params, num_rows, num_cols = init_sampling(params=params)
     num_samples = num_rows * num_cols

     if st.button("Sample"):
-        out = do_img2img(
-            repeat(img, "1 ... -> n ...", n=num_samples),
-            state["model"],
-            sampler,
-            value_dict,
-            num_samples,
-            force_uc_zero_embeddings=["txt"] if not is_legacy else [],
+        outputs = st.empty()
+        st.text("Sampling")
+        out = model.image_to_image(
+            image=repeat(img, "1 ... -> n ...", n=num_samples),
+            params=params,
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            samples=int(num_samples),
             return_latents=return_latents,
-            filter=filter,
+            noise_strength=stage2strength,
+            filter=state["filter"],
         )

+        show_samples(out, outputs)
         return out


 def apply_refiner(
     input,
     state,
-    sampler,
-    num_samples,
-    prompt,
-    negative_prompt,
-    filter=None,
+    num_samples: int,
+    prompt: str,
+    negative_prompt: str,
     finish_denoising=False,
 ):
-    init_dict = {
-        "orig_width": input.shape[3] * 8,
-        "orig_height": input.shape[2] * 8,
-        "target_width": input.shape[3] * 8,
-        "target_height": input.shape[2] * 8,
-    }
+    model: SamplingPipeline = state["model"]
+    params: SamplingParams = state["params"]

-    value_dict = init_dict
-    value_dict["prompt"] = prompt
-    value_dict["negative_prompt"] = negative_prompt
-
-    value_dict["crop_coords_top"] = 0
-    value_dict["crop_coords_left"] = 0
-
-    value_dict["aesthetic_score"] = 6.0
-    value_dict["negative_aesthetic_score"] = 2.5
+    params.orig_width = input.shape[3] * 8
+    params.orig_height = input.shape[2] * 8
+    params.width = input.shape[3] * 8
+    params.height = input.shape[2] * 8

     st.warning(f"refiner input shape: {input.shape}")
-    samples = do_img2img(
-        input,
-        state["model"],
-        sampler,
-        value_dict,
-        num_samples,
-        skip_encode=True,
-        filter=filter,
+
+    samples = model.refiner(
+        image=input,
+        params=params,
+        prompt=prompt,
+        negative_prompt=negative_prompt,
+        samples=num_samples,
+        return_latents=False,
+        filter=state["filter"],
         add_noise=not finish_denoising,
     )

     return samples


+sdxl_base_model_list = [
+    ModelArchitecture.SDXL_V1_0_BASE,
+    ModelArchitecture.SDXL_V0_9_BASE,
+]
+
+sdxl_refiner_model_list = [
+    ModelArchitecture.SDXL_V1_0_REFINER,
+    ModelArchitecture.SDXL_V0_9_REFINER,
+]
+
 if __name__ == "__main__":
     st.title("Stable Diffusion")
-    version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
-    version_dict = VERSION2SPECS[version]
+    version = st.selectbox(
+        "Model Version",
+        [member.value for member in ModelArchitecture],
+        0,
+    )
+    version_enum = ModelArchitecture(version)
+    specs = model_specs[version_enum]
     mode = st.radio("Mode", ("txt2img", "img2img"), 0)
     st.write("__________________________")

-    set_lowvram_mode(st.checkbox("Low vram mode", True))
+    st.write("**Performance Options:**")
+    use_fp16 = st.checkbox("Use fp16 (Saves VRAM)", True)
+    enable_swap = st.checkbox("Swap models to CPU (Saves VRAM, uses RAM)", True)
+    st.write("__________________________")

-    if version.startswith("SDXL-base"):
+    if version_enum in sdxl_base_model_list:
         add_pipeline = st.checkbox("Load SDXL-refiner?", False)
         st.write("__________________________")
     else:
         add_pipeline = False

-    seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
+    seed = int(
+        st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
+    )
     seed_everything(seed)

-    save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))
-
-    state = init_st(version_dict, load_filter=True)
-    if state["msg"]:
-        st.info(state["msg"])
+    save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, str(version)))
+    state = init_st(
+        model_specs[version_enum],
+        load_filter=True,
+        use_fp16=use_fp16,
+        enable_swap=enable_swap,
+    )
     model = state["model"]

-    is_legacy = version_dict["is_legacy"]
+    is_legacy = specs.is_legacy

     prompt = st.text_input(
         "prompt",
```
```diff
@@ -290,47 +268,58 @@ if __name__ == "__main__":

     if add_pipeline:
         st.write("__________________________")
-        version2 = st.selectbox("Refiner:", ["SDXL-refiner-1.0", "SDXL-refiner-0.9"])
+        version2 = ModelArchitecture(
+            st.selectbox(
+                "Refiner:",
+                [member.value for member in sdxl_refiner_model_list],
+            )
+        )
         st.warning(
             f"Running with {version2} as the second stage model. Make sure to provide (V)RAM :) "
         )
+        st.write("**Refiner Options:**")

-        version_dict2 = VERSION2SPECS[version2]
-        state2 = init_st(version_dict2, load_filter=False)
-        st.info(state2["msg"])
+        specs2 = model_specs[version2]
+        state2 = init_st(
+            specs2, load_filter=False, use_fp16=use_fp16, enable_swap=enable_swap
+        )
+        params2 = state2["params"]

-        stage2strength = st.number_input(
+        params2.img2img_strength = st.number_input(
             "**Refinement strength**", value=0.15, min_value=0.0, max_value=1.0
         )

-        sampler2, *_ = init_sampling(
+        params2, *_ = init_sampling(
+            params=state2["params"],
             key=2,
-            img2img_strength=stage2strength,
             specify_num_samples=False,
         )
         st.write("__________________________")
         finish_denoising = st.checkbox("Finish denoising with refiner.", True)
-        if not finish_denoising:
+        if finish_denoising:
+            stage2strength = params2.img2img_strength
+        else:
             stage2strength = None
     else:
         state2 = None
+        params2 = None
         stage2strength = None

     if mode == "txt2img":
         out = run_txt2img(
-            state,
-            version,
-            version_dict,
-            is_legacy=is_legacy,
+            state=state,
+            model_id=version_enum,
             prompt=prompt,
             negative_prompt=negative_prompt,
             return_latents=add_pipeline,
-            filter=state.get("filter"),
             stage2strength=stage2strength,
         )
     elif mode == "img2img":
         out = run_img2img(
-            state,
-            version_dict,
-            is_legacy=is_legacy,
+            state=state,
             prompt=prompt,
             negative_prompt=negative_prompt,
             return_latents=add_pipeline,
-            filter=state.get("filter"),
             stage2strength=stage2strength,
         )
     else:
```
```diff
@@ -342,17 +331,17 @@ if __name__ == "__main__":
         samples_z = None

     if add_pipeline and samples_z is not None:
+        outputs = st.empty()
         st.write("**Running Refinement Stage**")
         samples = apply_refiner(
-            samples_z,
-            state2,
-            sampler2,
-            samples_z.shape[0],
+            input=samples_z,
+            state=state2,
+            num_samples=samples_z.shape[0],
             prompt=prompt,
             negative_prompt=negative_prompt if is_legacy else "",
-            filter=state.get("filter"),
             finish_denoising=finish_denoising,
         )
+        show_samples(samples, outputs)

     if save_locally and samples is not None:
         perform_save_locally(save_path, samples)
```
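The refactor above replaces the demo's hand-rolled `do_sample` plumbing with `SamplingPipeline`. A minimal sketch of the same call outside Streamlit, assuming checkpoints and configs sit in the repository's default locations; the prompt is illustrative:

```python
from sgm.inference.api import ModelArchitecture, SamplingParams, SamplingPipeline

# Load SDXL 1.0 base through the new API; model_path/config_path fall back
# to the package defaults when not given.
pipeline = SamplingPipeline(model_id=ModelArchitecture.SDXL_V1_0_BASE)

# Unspecified SamplingParams fields are backfilled from the spec's default_params.
samples = pipeline.text_to_image(
    prompt="a professional photograph of an astronaut riding a horse",
    negative_prompt="",
    samples=1,
    params=SamplingParams(steps=40),
)
```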
scripts/demo/streamlit_helpers.py

```diff
@@ -1,166 +1,68 @@
-import math
 import os
-from typing import List, Union

 import numpy as np
 import streamlit as st
 import torch
 from einops import rearrange, repeat
-from imwatermark import WatermarkEncoder
 from omegaconf import ListConfig, OmegaConf
 from PIL import Image
-from safetensors.torch import load_file as load_safetensors
-from torch import autocast
 from torchvision import transforms
-from torchvision.utils import make_grid
+from typing import Optional, Tuple, Dict, Any


 from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
-from sgm.modules.diffusionmodules.sampling import (
-    DPMPP2MSampler,
-    DPMPP2SAncestralSampler,
-    EulerAncestralSampler,
-    EulerEDMSampler,
-    HeunEDMSampler,
-    LinearMultistepSampler,
-)
+from sgm.inference.api import (
+    Discretization,
+    Guider,
+    Sampler,
+    SamplingParams,
+    SamplingSpec,
+    SamplingPipeline,
+    Thresholder,
+)
-from sgm.util import append_dims, instantiate_from_config
-
-
-class WatermarkEmbedder:
-    def __init__(self, watermark):
-        self.watermark = watermark
-        self.num_bits = len(WATERMARK_BITS)
-        self.encoder = WatermarkEncoder()
-        self.encoder.set_watermark("bits", self.watermark)
-
-    def __call__(self, image: torch.Tensor):
-        """
-        Adds a predefined watermark to the input image
-
-        Args:
-            image: ([N,] B, C, H, W) in range [0, 1]
-
-        Returns:
-            same as input but watermarked
-        """
-        # watermarking libary expects input as cv2 BGR format
-        squeeze = len(image.shape) == 4
-        if squeeze:
-            image = image[None, ...]
-        n = image.shape[0]
-        image_np = rearrange(
-            (255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
-        ).numpy()[:, :, :, ::-1]
-        # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
-        for k in range(image_np.shape[0]):
-            image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
-        image = torch.from_numpy(
-            rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)
-        ).to(image.device)
-        image = torch.clamp(image / 255, min=0.0, max=1.0)
-        if squeeze:
-            image = image[0]
-        return image
-
-
-# A fixed 48-bit message that was choosen at random
-# WATERMARK_MESSAGE = 0xB3EC907BB19E
-WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
-# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
-WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
-embed_watemark = WatermarkEmbedder(WATERMARK_BITS)
+from sgm.inference.helpers import embed_watermark, CudaModelManager


 @st.cache_resource()
-def init_st(version_dict, load_ckpt=True, load_filter=True):
-    state = dict()
-    if not "model" in state:
-        config = version_dict["config"]
-        ckpt = version_dict["ckpt"]
+def init_st(
+    spec: SamplingSpec,
+    load_ckpt=True,
+    load_filter=True,
+    use_fp16=True,
+    enable_swap=True,
+) -> Dict[str, Any]:
+    state: Dict[str, Any] = dict()
+    config = spec.config
+    ckpt = spec.ckpt

-        config = OmegaConf.load(config)
-        model, msg = load_model_from_config(config, ckpt if load_ckpt else None)
+    if enable_swap:
+        pipeline = SamplingPipeline(
+            model_spec=spec,
+            use_fp16=use_fp16,
+            device=CudaModelManager(device="cuda", swap_device="cpu"),
+        )
+    else:
+        pipeline = SamplingPipeline(model_spec=spec, use_fp16=use_fp16)

-        state["msg"] = msg
-        state["model"] = model
-        state["ckpt"] = ckpt if load_ckpt else None
-        state["config"] = config
-        if load_filter:
-            state["filter"] = DeepFloydDataFiltering(verbose=False)
+    state["spec"] = spec
+    state["model"] = pipeline
+    state["ckpt"] = ckpt if load_ckpt else None
+    state["config"] = config
+    state["params"] = spec.default_params
+    if load_filter:
+        state["filter"] = DeepFloydDataFiltering(verbose=False)
+    else:
+        state["filter"] = None
    return state


-def load_model(model):
-    model.cuda()
-
-
-lowvram_mode = False
-
-
-def set_lowvram_mode(mode):
-    global lowvram_mode
-    lowvram_mode = mode
-
-
-def initial_model_load(model):
-    global lowvram_mode
-    if lowvram_mode:
-        model.model.half()
-    else:
-        model.cuda()
-    return model
-
-
-def unload_model(model):
-    global lowvram_mode
-    if lowvram_mode:
-        model.cpu()
-        torch.cuda.empty_cache()
-
-
-def load_model_from_config(config, ckpt=None, verbose=True):
-    model = instantiate_from_config(config.model)
-
-    if ckpt is not None:
-        print(f"Loading model from {ckpt}")
-        if ckpt.endswith("ckpt"):
-            pl_sd = torch.load(ckpt, map_location="cpu")
-            if "global_step" in pl_sd:
-                global_step = pl_sd["global_step"]
-                st.info(f"loaded ckpt from global step {global_step}")
-                print(f"Global Step: {pl_sd['global_step']}")
-            sd = pl_sd["state_dict"]
-        elif ckpt.endswith("safetensors"):
-            sd = load_safetensors(ckpt)
-        else:
-            raise NotImplementedError
-
-        msg = None
-
-        m, u = model.load_state_dict(sd, strict=False)
-
-        if len(m) > 0 and verbose:
-            print("missing keys:")
-            print(m)
-        if len(u) > 0 and verbose:
-            print("unexpected keys:")
-            print(u)
-    else:
-        msg = None
-
-    model = initial_model_load(model)
-    model.eval()
-    return model, msg
-
-
-def get_unique_embedder_keys_from_conditioner(conditioner):
-    return list(set([x.input_key for x in conditioner.embedders]))
-
-
-def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
-    # Hardcoded demo settings; might undergo some changes in the future
-
-    value_dict = {}
+def init_embedder_options(
+    keys, params: SamplingParams, prompt=None, negative_prompt=None
+) -> SamplingParams:
     for key in keys:
         if key == "txt":
             if prompt is None:
```
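The `enable_swap` branch above is the piece that replaces the old `lowvram_mode` globals. The same idea outside the demo, as a sketch assuming a CUDA machine:

```python
from sgm.inference.api import ModelArchitecture, SamplingPipeline
from sgm.inference.helpers import CudaModelManager

# Park weights on the CPU between calls and move them to the GPU only while
# a submodule is actually in use, trading system RAM for VRAM.
pipeline = SamplingPipeline(
    model_id=ModelArchitecture.SDXL_V1_0_BASE,
    use_fp16=True,
    device=CudaModelManager(device="cuda", swap_device="cpu"),
)
```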
```diff
@@ -170,46 +72,38 @@ def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
             if negative_prompt is None:
                 negative_prompt = st.text_input("Negative prompt", "")

-            value_dict["prompt"] = prompt
-            value_dict["negative_prompt"] = negative_prompt
-
         if key == "original_size_as_tuple":
             orig_width = st.number_input(
                 "orig_width",
-                value=init_dict["orig_width"],
+                value=params.orig_width,
                 min_value=16,
             )
             orig_height = st.number_input(
                 "orig_height",
-                value=init_dict["orig_height"],
+                value=params.orig_height,
                 min_value=16,
             )

-            value_dict["orig_width"] = orig_width
-            value_dict["orig_height"] = orig_height
+            params.orig_width = int(orig_width)
+            params.orig_height = int(orig_height)

         if key == "crop_coords_top_left":
-            crop_coord_top = st.number_input("crop_coords_top", value=0, min_value=0)
-            crop_coord_left = st.number_input("crop_coords_left", value=0, min_value=0)
+            crop_coord_top = st.number_input(
+                "crop_coords_top", value=params.crop_coords_top, min_value=0
+            )
+            crop_coord_left = st.number_input(
+                "crop_coords_left", value=params.crop_coords_left, min_value=0
+            )

-            value_dict["crop_coords_top"] = crop_coord_top
-            value_dict["crop_coords_left"] = crop_coord_left
-
-        if key == "aesthetic_score":
-            value_dict["aesthetic_score"] = 6.0
-            value_dict["negative_aesthetic_score"] = 2.5
-
-        if key == "target_size_as_tuple":
-            value_dict["target_width"] = init_dict["target_width"]
-            value_dict["target_height"] = init_dict["target_height"]
-
-    return value_dict
+            params.crop_coords_top = int(crop_coord_top)
+            params.crop_coords_left = int(crop_coord_left)
+    return params


 def perform_save_locally(save_path, samples):
     os.makedirs(os.path.join(save_path), exist_ok=True)
     base_count = len(os.listdir(os.path.join(save_path)))
-    samples = embed_watemark(samples)
+    samples = embed_watermark(samples)
     for sample in samples:
         sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
         Image.fromarray(sample.astype(np.uint8)).save(
```
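`perform_save_locally` (and the new `show_samples`) stamp every output with the fixed 48-bit invisible-watermark message defined earlier in this file. A hedged sketch of reading it back — `WatermarkDecoder` is the counterpart class in the same `imwatermark` package, the path is illustrative, and the BGR flip follows the encoder's own comment:

```python
import numpy as np
from imwatermark import WatermarkDecoder
from PIL import Image

# Recover the 48 embedded bits from a saved sample (path illustrative);
# the library expects the image in cv2-style BGR channel order.
rgb = np.array(Image.open("outputs/demo/txt2img/000000.png"))
bgr = rgb[:, :, ::-1]
decoder = WatermarkDecoder("bits", 48)
bits = decoder.decode(bgr, "dwtDct")
```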
```diff
@@ -228,78 +122,26 @@ def init_save_locally(_dir, init_value: bool = False):
     return save_locally, save_path


-class Img2ImgDiscretizationWrapper:
-    """
-    wraps a discretizer, and prunes the sigmas
-    params:
-        strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
-    """
-
-    def __init__(self, discretization, strength: float = 1.0):
-        self.discretization = discretization
-        self.strength = strength
-        assert 0.0 <= self.strength <= 1.0
-
-    def __call__(self, *args, **kwargs):
-        # sigmas start large first, and decrease then
-        sigmas = self.discretization(*args, **kwargs)
-        print(f"sigmas after discretization, before pruning img2img: ", sigmas)
-        sigmas = torch.flip(sigmas, (0,))
-        sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
-        print("prune index:", max(int(self.strength * len(sigmas)), 1))
-        sigmas = torch.flip(sigmas, (0,))
-        print(f"sigmas after pruning: ", sigmas)
-        return sigmas
+def show_samples(samples, outputs):
+    if isinstance(samples, tuple):
+        samples, _ = samples
+    grid = embed_watermark(torch.stack([samples]))
+    grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
+    outputs.image(grid.cpu().numpy())


-class Txt2NoisyDiscretizationWrapper:
-    """
-    wraps a discretizer, and prunes the sigmas
-    params:
-        strength: float between 0.0 and 1.0. 0.0 means full sampling (all sigmas are returned)
-    """
-
-    def __init__(self, discretization, strength: float = 0.0, original_steps=None):
-        self.discretization = discretization
-        self.strength = strength
-        self.original_steps = original_steps
-        assert 0.0 <= self.strength <= 1.0
-
-    def __call__(self, *args, **kwargs):
-        # sigmas start large first, and decrease then
-        sigmas = self.discretization(*args, **kwargs)
-        print(f"sigmas after discretization, before pruning img2img: ", sigmas)
-        sigmas = torch.flip(sigmas, (0,))
-        if self.original_steps is None:
-            steps = len(sigmas)
-        else:
-            steps = self.original_steps + 1
-        prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0)
-        sigmas = sigmas[prune_index:]
-        print("prune index:", prune_index)
-        sigmas = torch.flip(sigmas, (0,))
-        print(f"sigmas after pruning: ", sigmas)
-        return sigmas
-
-
-def get_guider(key):
-    guider = st.sidebar.selectbox(
-        f"Discretization #{key}",
-        [
-            "VanillaCFG",
-            "IdentityGuider",
-        ],
-    )
+def get_guider(params: SamplingParams, key=1) -> SamplingParams:
+    params.guider = Guider(
+        st.sidebar.selectbox(
+            f"Discretization #{key}", [member.value for member in Guider]
+        )
+    )

-    if guider == "IdentityGuider":
-        guider_config = {
-            "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
-        }
-    elif guider == "VanillaCFG":
+    if params.guider == Guider.VANILLA:
         scale = st.number_input(
-            f"cfg-scale #{key}", value=5.0, min_value=0.0, max_value=100.0
+            f"cfg-scale #{key}", value=params.scale, min_value=0.0, max_value=100.0
         )

+        params.scale = scale
         thresholder = st.sidebar.selectbox(
             f"Thresholder #{key}",
             [
```
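Both wrappers removed here move into `sgm.inference.helpers` and are applied through the API's `wrap_discretization`; they prune the sigma schedule rather than re-discretizing it. A small worked example of the img2img flip/slice/flip logic shown above:

```python
import torch

# Toy decreasing schedule of 8 sigmas, as a discretizer would return it.
sigmas = torch.tensor([14.6, 9.1, 5.4, 3.0, 1.6, 0.8, 0.3, 0.03])

strength = 0.75
flipped = torch.flip(sigmas, (0,))                      # ascending order
kept = flipped[: max(int(strength * len(flipped)), 1)]  # the 6 smallest sigmas
pruned = torch.flip(kept, (0,))
# pruned == tensor([5.4, 3.0, 1.6, 0.8, 0.3, 0.03]): sampling starts from
# mid-level noise instead of pure noise, which is what img2img strength means.
```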
```diff
@@ -308,182 +150,97 @@ def get_guider(key):
         )

         if thresholder == "None":
-            dyn_thresh_config = {
-                "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
-            }
+            params.thresholder = Thresholder.NONE
         else:
             raise NotImplementedError

-        guider_config = {
-            "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
-            "params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
-        }
-    else:
-        raise NotImplementedError
-    return guider_config
+    return params


 def init_sampling(
+    params: SamplingParams,
     key=1,
-    img2img_strength=1.0,
     specify_num_samples=True,
-    stage2strength=None,
-):
+) -> Tuple[SamplingParams, int, int]:
     num_rows, num_cols = 1, 1
     if specify_num_samples:
         num_cols = st.number_input(
             f"num cols #{key}", value=2, min_value=1, max_value=10
         )

-    steps = st.sidebar.number_input(
-        f"steps #{key}", value=40, min_value=1, max_value=1000
-    )
-    sampler = st.sidebar.selectbox(
-        f"Sampler #{key}",
-        [
-            "EulerEDMSampler",
-            "HeunEDMSampler",
-            "EulerAncestralSampler",
-            "DPMPP2SAncestralSampler",
-            "DPMPP2MSampler",
-            "LinearMultistepSampler",
-        ],
-        0,
-    )
-    discretization = st.sidebar.selectbox(
-        f"Discretization #{key}",
-        [
-            "LegacyDDPMDiscretization",
-            "EDMDiscretization",
-        ],
-    )
+    params.steps = int(
+        st.sidebar.number_input(
+            f"steps #{key}", value=params.steps, min_value=1, max_value=1000
+        )
+    )

-    discretization_config = get_discretization(discretization, key=key)
-
-    guider_config = get_guider(key=key)
-
-    sampler = get_sampler(sampler, steps, discretization_config, guider_config, key=key)
-    if img2img_strength < 1.0:
-        st.warning(
-            f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper"
-        )
-        sampler.discretization = Img2ImgDiscretizationWrapper(
-            sampler.discretization, strength=img2img_strength
-        )
-    if stage2strength is not None:
-        sampler.discretization = Txt2NoisyDiscretizationWrapper(
-            sampler.discretization, strength=stage2strength, original_steps=steps
-        )
-
-    return sampler, num_rows, num_cols
+    params.sampler = Sampler(
+        st.sidebar.selectbox(
+            f"Sampler #{key}",
+            [member.value for member in Sampler],
+            0,
+        )
+    )
+    params.discretization = Discretization(
+        st.sidebar.selectbox(
+            f"Discretization #{key}",
+            [member.value for member in Discretization],
+        )
+    )
+
+    params = get_discretization(params=params, key=key)
+    params = get_guider(params=params, key=key)
+    params = get_sampler(params=params, key=key)
+
+    return params, num_rows, num_cols


-def get_discretization(discretization, key=1):
-    if discretization == "LegacyDDPMDiscretization":
-        discretization_config = {
-            "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
-        }
-    elif discretization == "EDMDiscretization":
-        sigma_min = st.number_input(f"sigma_min #{key}", value=0.03)  # 0.0292
-        sigma_max = st.number_input(f"sigma_max #{key}", value=14.61)  # 14.6146
-        rho = st.number_input(f"rho #{key}", value=3.0)
-        discretization_config = {
-            "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
-            "params": {
-                "sigma_min": sigma_min,
-                "sigma_max": sigma_max,
-                "rho": rho,
-            },
-        }
-
-    return discretization_config
+def get_discretization(params: SamplingParams, key=1) -> SamplingParams:
+    if params.discretization == Discretization.EDM:
+        params.sigma_min = st.number_input(f"sigma_min #{key}", value=params.sigma_min)
+        params.sigma_max = st.number_input(f"sigma_max #{key}", value=params.sigma_max)
+        params.rho = st.number_input(f"rho #{key}", value=params.rho)
+    return params


-def get_sampler(sampler_name, steps, discretization_config, guider_config, key=1):
-    if sampler_name == "EulerEDMSampler" or sampler_name == "HeunEDMSampler":
-        s_churn = st.sidebar.number_input(f"s_churn #{key}", value=0.0, min_value=0.0)
-        s_tmin = st.sidebar.number_input(f"s_tmin #{key}", value=0.0, min_value=0.0)
-        s_tmax = st.sidebar.number_input(f"s_tmax #{key}", value=999.0, min_value=0.0)
-        s_noise = st.sidebar.number_input(f"s_noise #{key}", value=1.0, min_value=0.0)
-
-        if sampler_name == "EulerEDMSampler":
-            sampler = EulerEDMSampler(
-                num_steps=steps,
-                discretization_config=discretization_config,
-                guider_config=guider_config,
-                s_churn=s_churn,
-                s_tmin=s_tmin,
-                s_tmax=s_tmax,
-                s_noise=s_noise,
-                verbose=True,
-            )
-        elif sampler_name == "HeunEDMSampler":
-            sampler = HeunEDMSampler(
-                num_steps=steps,
-                discretization_config=discretization_config,
-                guider_config=guider_config,
-                s_churn=s_churn,
-                s_tmin=s_tmin,
-                s_tmax=s_tmax,
-                s_noise=s_noise,
-                verbose=True,
-            )
-    elif (
-        sampler_name == "EulerAncestralSampler"
-        or sampler_name == "DPMPP2SAncestralSampler"
-    ):
-        s_noise = st.sidebar.number_input("s_noise", value=1.0, min_value=0.0)
-        eta = st.sidebar.number_input("eta", value=1.0, min_value=0.0)
-
-        if sampler_name == "EulerAncestralSampler":
-            sampler = EulerAncestralSampler(
-                num_steps=steps,
-                discretization_config=discretization_config,
-                guider_config=guider_config,
-                eta=eta,
-                s_noise=s_noise,
-                verbose=True,
-            )
-        elif sampler_name == "DPMPP2SAncestralSampler":
-            sampler = DPMPP2SAncestralSampler(
-                num_steps=steps,
-                discretization_config=discretization_config,
-                guider_config=guider_config,
-                eta=eta,
-                s_noise=s_noise,
-                verbose=True,
-            )
-    elif sampler_name == "DPMPP2MSampler":
-        sampler = DPMPP2MSampler(
-            num_steps=steps,
-            discretization_config=discretization_config,
-            guider_config=guider_config,
-            verbose=True,
-        )
-    elif sampler_name == "LinearMultistepSampler":
-        order = st.sidebar.number_input("order", value=4, min_value=1)
-        sampler = LinearMultistepSampler(
-            num_steps=steps,
-            discretization_config=discretization_config,
-            guider_config=guider_config,
-            order=order,
-            verbose=True,
-        )
-    else:
-        raise ValueError(f"unknown sampler {sampler_name}!")
-
-    return sampler
+def get_sampler(params: SamplingParams, key=1) -> SamplingParams:
+    if params.sampler in (Sampler.EULER_EDM, Sampler.HEUN_EDM):
+        params.s_churn = st.sidebar.number_input(
+            f"s_churn #{key}", value=params.s_churn, min_value=0.0
+        )
+        params.s_tmin = st.sidebar.number_input(
+            f"s_tmin #{key}", value=params.s_tmin, min_value=0.0
+        )
+        params.s_tmax = st.sidebar.number_input(
+            f"s_tmax #{key}", value=params.s_tmax, min_value=0.0
+        )
+        params.s_noise = st.sidebar.number_input(
+            f"s_noise #{key}", value=params.s_noise, min_value=0.0
+        )
+    elif params.sampler in (Sampler.EULER_ANCESTRAL, Sampler.DPMPP2S_ANCESTRAL):
+        params.s_noise = st.sidebar.number_input(
+            "s_noise", value=params.s_noise, min_value=0.0
+        )
+        params.eta = st.sidebar.number_input("eta", value=params.eta, min_value=0.0)
+    elif params.sampler == Sampler.LINEAR_MULTISTEP:
+        params.order = int(
+            st.sidebar.number_input("order", value=params.order, min_value=1)
+        )
+    return params


-def get_interactive_image(key=None) -> Image.Image:
+def get_interactive_image(key=None) -> Optional[Image.Image]:
     image = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
     if image is not None:
         image = Image.open(image)
         if not image.mode == "RGB":
             image = image.convert("RGB")
-    return image
+        return image
+    return None


-def load_img(display=True, key=None):
+def load_img(display=True, key=None) -> Optional[torch.Tensor]:
     image = get_interactive_image(key=key)
+    if image is None:
+        return None
```
```diff
@@ -507,238 +264,3 @@ def get_init_img(batch_size=1, key=None):
     init_image = load_img(key=key).cuda()
     init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
     return init_image
-
-
-def do_sample(
-    model,
-    sampler,
-    value_dict,
-    num_samples,
-    H,
-    W,
-    C,
-    F,
-    force_uc_zero_embeddings: List = None,
-    batch2model_input: List = None,
-    return_latents=False,
-    filter=None,
-):
-    if force_uc_zero_embeddings is None:
-        force_uc_zero_embeddings = []
-    if batch2model_input is None:
-        batch2model_input = []
-
-    st.text("Sampling")
-
-    outputs = st.empty()
-    precision_scope = autocast
-    with torch.no_grad():
-        with precision_scope("cuda"):
-            with model.ema_scope():
-                num_samples = [num_samples]
-                load_model(model.conditioner)
-                batch, batch_uc = get_batch(
-                    get_unique_embedder_keys_from_conditioner(model.conditioner),
-                    value_dict,
-                    num_samples,
-                )
-                for key in batch:
-                    if isinstance(batch[key], torch.Tensor):
-                        print(key, batch[key].shape)
-                    elif isinstance(batch[key], list):
-                        print(key, [len(l) for l in batch[key]])
-                    else:
-                        print(key, batch[key])
-                c, uc = model.conditioner.get_unconditional_conditioning(
-                    batch,
-                    batch_uc=batch_uc,
-                    force_uc_zero_embeddings=force_uc_zero_embeddings,
-                )
-                unload_model(model.conditioner)
-
-                for k in c:
-                    if not k == "crossattn":
-                        c[k], uc[k] = map(
-                            lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc)
-                        )
-
-                additional_model_inputs = {}
-                for k in batch2model_input:
-                    additional_model_inputs[k] = batch[k]
-
-                shape = (math.prod(num_samples), C, H // F, W // F)
-                randn = torch.randn(shape).to("cuda")
-
-                def denoiser(input, sigma, c):
-                    return model.denoiser(
-                        model.model, input, sigma, c, **additional_model_inputs
-                    )
-
-                load_model(model.denoiser)
-                load_model(model.model)
-                samples_z = sampler(denoiser, randn, cond=c, uc=uc)
-                unload_model(model.model)
-                unload_model(model.denoiser)
-
-                load_model(model.first_stage_model)
-                samples_x = model.decode_first_stage(samples_z)
-                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
-                unload_model(model.first_stage_model)
-
-                if filter is not None:
-                    samples = filter(samples)
-
-                grid = torch.stack([samples])
-                grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
-                outputs.image(grid.cpu().numpy())
-
-                if return_latents:
-                    return samples, samples_z
-                return samples
-
-
-def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
-    # Hardcoded demo setups; might undergo some changes in the future
-
-    batch = {}
-    batch_uc = {}
-
-    for key in keys:
-        if key == "txt":
-            batch["txt"] = (
-                np.repeat([value_dict["prompt"]], repeats=math.prod(N))
-                .reshape(N)
-                .tolist()
-            )
-            batch_uc["txt"] = (
-                np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
-                .reshape(N)
-                .tolist()
-            )
-        elif key == "original_size_as_tuple":
-            batch["original_size_as_tuple"] = (
-                torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
-                .to(device)
-                .repeat(*N, 1)
-            )
-        elif key == "crop_coords_top_left":
-            batch["crop_coords_top_left"] = (
-                torch.tensor(
-                    [value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
-                )
-                .to(device)
-                .repeat(*N, 1)
-            )
-        elif key == "aesthetic_score":
-            batch["aesthetic_score"] = (
-                torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
-            )
-            batch_uc["aesthetic_score"] = (
-                torch.tensor([value_dict["negative_aesthetic_score"]])
-                .to(device)
-                .repeat(*N, 1)
-            )
-
-        elif key == "target_size_as_tuple":
-            batch["target_size_as_tuple"] = (
-                torch.tensor([value_dict["target_height"], value_dict["target_width"]])
-                .to(device)
-                .repeat(*N, 1)
-            )
-        else:
-            batch[key] = value_dict[key]
-
-    for key in batch.keys():
-        if key not in batch_uc and isinstance(batch[key], torch.Tensor):
-            batch_uc[key] = torch.clone(batch[key])
-    return batch, batch_uc
-
-
-@torch.no_grad()
-def do_img2img(
-    img,
-    model,
-    sampler,
-    value_dict,
-    num_samples,
-    force_uc_zero_embeddings=[],
-    additional_kwargs={},
-    offset_noise_level: int = 0.0,
-    return_latents=False,
-    skip_encode=False,
-    filter=None,
-    add_noise=True,
-):
-    st.text("Sampling")
-
-    outputs = st.empty()
-    precision_scope = autocast
-    with torch.no_grad():
-        with precision_scope("cuda"):
-            with model.ema_scope():
-                load_model(model.conditioner)
-                batch, batch_uc = get_batch(
-                    get_unique_embedder_keys_from_conditioner(model.conditioner),
-                    value_dict,
-                    [num_samples],
-                )
-                c, uc = model.conditioner.get_unconditional_conditioning(
-                    batch,
-                    batch_uc=batch_uc,
-                    force_uc_zero_embeddings=force_uc_zero_embeddings,
-                )
-                unload_model(model.conditioner)
-                for k in c:
-                    c[k], uc[k] = map(lambda y: y[k][:num_samples].to("cuda"), (c, uc))
-
-                for k in additional_kwargs:
-                    c[k] = uc[k] = additional_kwargs[k]
-                if skip_encode:
-                    z = img
-                else:
-                    load_model(model.first_stage_model)
-                    z = model.encode_first_stage(img)
-                    unload_model(model.first_stage_model)
-
-                noise = torch.randn_like(z)
-
-                sigmas = sampler.discretization(sampler.num_steps).cuda()
-                sigma = sigmas[0]
-
-                st.info(f"all sigmas: {sigmas}")
-                st.info(f"noising sigma: {sigma}")
-                if offset_noise_level > 0.0:
-                    noise = noise + offset_noise_level * append_dims(
-                        torch.randn(z.shape[0], device=z.device), z.ndim
-                    )
-                if add_noise:
-                    noised_z = z + noise * append_dims(sigma, z.ndim).cuda()
-                    noised_z = noised_z / torch.sqrt(
-                        1.0 + sigmas[0] ** 2.0
-                    )  # Note: hardcoded to DDPM-like scaling. need to generalize later.
-                else:
-                    noised_z = z / torch.sqrt(1.0 + sigmas[0] ** 2.0)
-
-                def denoiser(x, sigma, c):
-                    return model.denoiser(model.model, x, sigma, c)
-
-                load_model(model.denoiser)
-                load_model(model.model)
-                samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
-                unload_model(model.model)
-                unload_model(model.denoiser)
-
-                load_model(model.first_stage_model)
-                samples_x = model.decode_first_stage(samples_z)
-                unload_model(model.first_stage_model)
-                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
-
-                if filter is not None:
-                    samples = filter(samples)
-
-                grid = embed_watemark(torch.stack([samples]))
-                grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
-                outputs.image(grid.cpu().numpy())
-                if return_latents:
-                    return samples, samples_z
-                return samples
```
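The removed `do_img2img` (it lives on in `sgm.inference.helpers`) noises the encoded latent with the first sigma of the possibly-pruned schedule, then applies the DDPM-like rescaling its comment mentions. The arithmetic, isolated as a sketch with toy tensors:

```python
import torch

# DDPM-like input scaling used by do_img2img, as in the code above:
#   z_noised = (z + sigma_0 * eps) / sqrt(1 + sigma_0**2)
# With sigma_0 taken from a pruned schedule (img2img strength < 1), less
# noise is added and more of the input image survives into sampling.
z = torch.randn(1, 4, 128, 128)   # toy latent
eps = torch.randn_like(z)
sigma_0 = torch.tensor(5.4)
z_noised = (z + sigma_0 * eps) / torch.sqrt(1.0 + sigma_0**2)
```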
sgm/__init__.py

```diff
@@ -1,4 +1,4 @@
 from .models import AutoencodingEngine, DiffusionEngine
 from .util import get_configs_path, instantiate_from_config

-__version__ = "0.1.0"
+__version__ = "0.1.1"
```
sgm/inference/api.py

```diff
@@ -1,11 +1,14 @@
 from dataclasses import dataclass, asdict
 from enum import Enum
 from omegaconf import OmegaConf
-import pathlib
+import os
 from sgm.inference.helpers import (
     do_sample,
     do_img2img,
+    DeviceModelManager,
+    get_model_manager,
     Img2ImgDiscretizationWrapper,
     Txt2NoisyDiscretizationWrapper,
 )
 from sgm.modules.diffusionmodules.sampling import (
     EulerEDMSampler,
@@ -15,17 +18,18 @@ from sgm.modules.diffusionmodules.sampling import (
     DPMPP2MSampler,
     LinearMultistepSampler,
 )
-from sgm.util import load_model_from_config
-from typing import Optional
+from sgm.util import load_model_from_config, get_configs_path, get_checkpoints_path
+import torch
+from typing import Optional, Dict, Any, Union


 class ModelArchitecture(str, Enum):
-    SD_2_1 = "stable-diffusion-v2-1"
-    SD_2_1_768 = "stable-diffusion-v2-1-768"
+    SDXL_V1_0_BASE = "stable-diffusion-xl-v1-base"
+    SDXL_V1_0_REFINER = "stable-diffusion-xl-v1-refiner"
     SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base"
     SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner"
-    SDXL_V1_BASE = "stable-diffusion-xl-v1-base"
-    SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner"
+    SD_2_1 = "stable-diffusion-v2-1"
+    SD_2_1_768 = "stable-diffusion-v2-1-768"


 class Sampler(str, Enum):
```
```diff
@@ -53,16 +57,20 @@ class Thresholder(str, Enum):

 @dataclass
 class SamplingParams:
-    width: int = 1024
-    height: int = 1024
-    steps: int = 50
-    sampler: Sampler = Sampler.DPMPP2M
+    """
+    Parameters for sampling.
+    """
+
+    width: Optional[int] = None
+    height: Optional[int] = None
+    steps: Optional[int] = None
+    sampler: Sampler = Sampler.EULER_EDM
     discretization: Discretization = Discretization.LEGACY_DDPM
     guider: Guider = Guider.VANILLA
     thresholder: Thresholder = Thresholder.NONE
-    scale: float = 6.0
-    aesthetic_score: float = 5.0
-    negative_aesthetic_score: float = 5.0
+    scale: float = 5.0
+    aesthetic_score: float = 6.0
+    negative_aesthetic_score: float = 2.5
     img2img_strength: float = 1.0
     orig_width: int = 1024
     orig_height: int = 1024
```
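With `width`, `height`, and `steps` now `Optional`, a caller can pin a single field and let `text_to_image` backfill the rest from the model's `default_params`. A sketch; values are examples only:

```python
from sgm.inference.api import Sampler, SamplingParams

# Only steps and sampler are pinned here; width/height stay None and are
# filled in from the selected model's SamplingSpec.default_params.
params = SamplingParams(steps=30, sampler=Sampler.DPMPP2M)
```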
```diff
@@ -89,8 +97,10 @@ class SamplingSpec:
     config: str
     ckpt: str
     is_guided: bool
+    default_params: SamplingParams


+# The defaults here are derived from user preference testing.
 model_specs = {
     ModelArchitecture.SD_2_1: SamplingSpec(
         height=512,
@@ -101,6 +111,12 @@ model_specs = {
         config="sd_2_1.yaml",
         ckpt="v2-1_512-ema-pruned.safetensors",
         is_guided=True,
+        default_params=SamplingParams(
+            width=512,
+            height=512,
+            steps=40,
+            scale=7.0,
+        ),
     ),
     ModelArchitecture.SD_2_1_768: SamplingSpec(
         height=768,
@@ -111,6 +127,12 @@ model_specs = {
         config="sd_2_1_768.yaml",
         ckpt="v2-1_768-ema-pruned.safetensors",
         is_guided=True,
+        default_params=SamplingParams(
+            width=768,
+            height=768,
+            steps=40,
+            scale=7.0,
+        ),
     ),
     ModelArchitecture.SDXL_V0_9_BASE: SamplingSpec(
         height=1024,
@@ -121,6 +143,7 @@ model_specs = {
         config="sd_xl_base.yaml",
         ckpt="sd_xl_base_0.9.safetensors",
         is_guided=True,
+        default_params=SamplingParams(width=1024, height=1024, steps=40, scale=5.0),
     ),
     ModelArchitecture.SDXL_V0_9_REFINER: SamplingSpec(
         height=1024,
@@ -131,8 +154,11 @@ model_specs = {
         config="sd_xl_refiner.yaml",
         ckpt="sd_xl_refiner_0.9.safetensors",
         is_guided=True,
+        default_params=SamplingParams(
+            width=1024, height=1024, steps=40, scale=5.0, img2img_strength=0.15
+        ),
     ),
-    ModelArchitecture.SDXL_V1_BASE: SamplingSpec(
+    ModelArchitecture.SDXL_V1_0_BASE: SamplingSpec(
         height=1024,
         width=1024,
         channels=4,
@@ -141,8 +167,9 @@ model_specs = {
         config="sd_xl_base.yaml",
         ckpt="sd_xl_base_1.0.safetensors",
         is_guided=True,
+        default_params=SamplingParams(width=1024, height=1024, steps=40, scale=5.0),
     ),
-    ModelArchitecture.SDXL_V1_REFINER: SamplingSpec(
+    ModelArchitecture.SDXL_V1_0_REFINER: SamplingSpec(
         height=1024,
         width=1024,
         channels=4,
```
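Each spec now carries its own tuned defaults, which is what lets the demo drop its `VERSION2SPECS` table. A lookup sketch:

```python
from sgm.inference.api import ModelArchitecture, model_specs

spec = model_specs[ModelArchitecture.SD_2_1]
print(spec.config, spec.ckpt)      # sd_2_1.yaml v2-1_512-ema-pruned.safetensors
print(spec.default_params.scale)   # 7.0, per the tuned defaults above
```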
```diff
@@ -151,34 +178,97 @@ model_specs = {
         config="sd_xl_refiner.yaml",
         ckpt="sd_xl_refiner_1.0.safetensors",
         is_guided=True,
+        default_params=SamplingParams(
+            width=1024, height=1024, steps=40, scale=5.0, img2img_strength=0.15
+        ),
     ),
 }


+def wrap_discretization(
+    discretization, image_strength=None, noise_strength=None, steps=None
+):
+    if isinstance(discretization, Img2ImgDiscretizationWrapper) or isinstance(
+        discretization, Txt2NoisyDiscretizationWrapper
+    ):
+        return discretization  # Already wrapped
+    if image_strength is not None and image_strength < 1.0 and image_strength > 0.0:
+        discretization = Img2ImgDiscretizationWrapper(
+            discretization, strength=image_strength
+        )
+
+    if (
+        noise_strength is not None
+        and noise_strength < 1.0
+        and noise_strength > 0.0
+        and steps is not None
+    ):
+        discretization = Txt2NoisyDiscretizationWrapper(
+            discretization,
+            strength=noise_strength,
+            original_steps=steps,
+        )
+    return discretization
+
+
 class SamplingPipeline:
     def __init__(
         self,
-        model_id: ModelArchitecture,
-        model_path="checkpoints",
-        config_path="configs/inference",
-        device="cuda",
-        use_fp16=True,
+        model_id: Optional[ModelArchitecture] = None,
+        model_spec: Optional[SamplingSpec] = None,
+        model_path: Optional[str] = None,
+        config_path: Optional[str] = None,
+        use_fp16: bool = True,
+        device: Optional[Union[DeviceModelManager, str, torch.device]] = None,
     ) -> None:
-        if model_id not in model_specs:
-            raise ValueError(f"Model {model_id} not supported")
-        self.model_id = model_id
-        self.specs = model_specs[self.model_id]
-        self.config = str(pathlib.Path(config_path, self.specs.config))
-        self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt))
-        self.device = device
-        self.model = self._load_model(device=device, use_fp16=use_fp16)
+        """
+        Sampling pipeline for generating images from a model.

-    def _load_model(self, device="cuda", use_fp16=True):
+        @param model_id: Model architecture to use. If not specified, model_spec must be specified.
+        @param model_spec: Model specification to use. If not specified, model_id must be specified.
+        @param model_path: Path to model checkpoints folder.
+        @param config_path: Path to model config folder.
+        @param use_fp16: Whether to use fp16 for sampling.
+        @param device: Device manager to use with this pipeline. If a string or torch.device is passed, a device manager will be created based on device type if possible.
+        """
+
+        self.model_id = model_id
+        if model_spec is not None:
+            self.specs = model_spec
+        elif model_id is not None:
+            if model_id not in model_specs:
+                raise ValueError(f"Model {model_id} not supported")
+            self.specs = model_specs[model_id]
+        else:
+            raise ValueError("Either model_id or model_spec should be provided")
+
+        if model_path is None:
+            model_path = get_checkpoints_path()
+        if config_path is None:
+            config_path = get_configs_path()
+        self.config = os.path.join(config_path, "inference", self.specs.config)
+        self.ckpt = os.path.join(model_path, self.specs.ckpt)
+        if not os.path.exists(self.config):
+            raise ValueError(
+                f"Config {self.config} not found, check model spec or config_path"
+            )
+        if not os.path.exists(self.ckpt):
+            raise ValueError(
+                f"Checkpoint {self.ckpt} not found, check model spec or config_path"
+            )
+
+        self.device_manager = get_model_manager(device)
+
+        self.model = self._load_model(
+            device_manager=self.device_manager, use_fp16=use_fp16
+        )
+
+    def _load_model(self, device_manager: DeviceModelManager, use_fp16=True):
         config = OmegaConf.load(self.config)
         model = load_model_from_config(config, self.ckpt)
         if model is None:
             raise ValueError(f"Model {self.model_id} could not be loaded")
-        model.to(device)
+        device_manager.load(model)
         if use_fp16:
             model.conditioner.half()
             model.model.half()
```
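Chaining base and refiner through the new constructor mirrors what the demo's `add_pipeline` path does. A hedged two-stage sketch; the `noise_strength` value is illustrative:

```python
from sgm.inference.api import ModelArchitecture, SamplingPipeline

base = SamplingPipeline(model_id=ModelArchitecture.SDXL_V1_0_BASE)
refiner = SamplingPipeline(model_id=ModelArchitecture.SDXL_V1_0_REFINER)

# return_latents=True yields (decoded samples, latents); noise_strength wires
# in the Txt2NoisyDiscretizationWrapper via wrap_discretization.
samples, latents = base.text_to_image(
    prompt="a red panda in a spacesuit",
    samples=1,
    return_latents=True,
    noise_strength=0.15,
)
refined = refiner.refiner(image=latents, prompt="a red panda in a spacesuit")
```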
@@ -186,13 +276,34 @@ class SamplingPipeline:

    def text_to_image(
        self,
        params: SamplingParams,
        prompt: str,
        params: Optional[SamplingParams] = None,
        negative_prompt: str = "",
        samples: int = 1,
        return_latents: bool = False,
        noise_strength: Optional[float] = None,
        filter=None,
    ):
        if params is None:
            params = self.specs.default_params
        else:
            # Set defaults if optional params are not specified
            if params.width is None:
                params.width = self.specs.default_params.width
            if params.height is None:
                params.height = self.specs.default_params.height
            if params.steps is None:
                params.steps = self.specs.default_params.steps

        sampler = get_sampler_config(params)

        sampler.discretization = wrap_discretization(
            sampler.discretization,
            image_strength=None,
            noise_strength=noise_strength,
            steps=params.steps,
        )

        value_dict = asdict(params)
        value_dict["prompt"] = prompt
        value_dict["negative_prompt"] = negative_prompt
@@ -209,31 +320,40 @@ class SamplingPipeline:
            self.specs.factor,
            force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
            return_latents=return_latents,
            filter=None,
            filter=filter,
            device=self.device_manager,
        )

    def image_to_image(
        self,
        params: SamplingParams,
        image,
        image: torch.Tensor,
        prompt: str,
        params: Optional[SamplingParams] = None,
        negative_prompt: str = "",
        samples: int = 1,
        return_latents: bool = False,
        noise_strength: Optional[float] = None,
        filter=None,
    ):
        if params is None:
            params = self.specs.default_params
        sampler = get_sampler_config(params)

        if params.img2img_strength < 1.0:
            sampler.discretization = Img2ImgDiscretizationWrapper(
                sampler.discretization,
                strength=params.img2img_strength,
            )
        sampler.discretization = wrap_discretization(
            sampler.discretization,
            image_strength=params.img2img_strength,
            noise_strength=noise_strength,
            steps=params.steps,
        )

        height, width = image.shape[2], image.shape[3]
        value_dict = asdict(params)
        value_dict["prompt"] = prompt
        value_dict["negative_prompt"] = negative_prompt
        value_dict["target_width"] = width
        value_dict["target_height"] = height
        value_dict["orig_width"] = width
        value_dict["orig_height"] = height
        return do_img2img(
            image,
            self.model,
@@ -242,18 +362,24 @@ class SamplingPipeline:
            samples,
            force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
            return_latents=return_latents,
            filter=None,
            filter=filter,
            device=self.device_manager,
        )

    def refiner(
        self,
        params: SamplingParams,
        image,
        image: torch.Tensor,
        prompt: str,
        negative_prompt: Optional[str] = None,
        negative_prompt: str = "",
        params: Optional[SamplingParams] = None,
        samples: int = 1,
        return_latents: bool = False,
        filter: Any = None,
        add_noise: bool = False,
    ):
        if params is None:
            params = self.specs.default_params

        sampler = get_sampler_config(params)
        value_dict = {
            "orig_width": image.shape[3] * 8,
@@ -268,6 +394,10 @@ class SamplingPipeline:
            "negative_aesthetic_score": 2.5,
        }

        sampler.discretization = wrap_discretization(
            sampler.discretization, image_strength=params.img2img_strength
        )

        return do_img2img(
            image,
            self.model,
@@ -276,11 +406,14 @@ class SamplingPipeline:
            samples,
            skip_encode=True,
            return_latents=return_latents,
            filter=None,
            filter=filter,
            add_noise=add_noise,
            device=self.device_manager,
        )


def get_guider_config(params: SamplingParams):
def get_guider_config(params: SamplingParams) -> Dict[str, Any]:
    guider_config: Dict[str, Any]
    if params.guider == Guider.IDENTITY:
        guider_config = {
            "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
@@ -306,7 +439,8 @@ def get_guider_config(params: SamplingParams):
    return guider_config


def get_discretization_config(params: SamplingParams):
def get_discretization_config(params: SamplingParams) -> Dict[str, Any]:
    discretization_config: Dict[str, Any]
    if params.discretization == Discretization.LEGACY_DDPM:
        discretization_config = {
            "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",

sgm/inference/helpers.py
@@ -1,3 +1,4 @@
import contextlib
import os
from typing import Union, List, Optional

@@ -8,7 +9,6 @@ from PIL import Image
from einops import rearrange
from imwatermark import WatermarkEncoder
from omegaconf import ListConfig
from torch import autocast

from sgm.util import append_dims

@@ -58,6 +58,73 @@ WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
embed_watermark = WatermarkEmbedder(WATERMARK_BITS)


class DeviceModelManager(object):
    """
    Default model loading class, should work for all device classes.
    """

    def __init__(
        self,
        device: Union[torch.device, str],
        swap_device: Optional[Union[torch.device, str]] = None,
    ) -> None:
        """
        Args:
            device (Union[torch.device, str]): The device to use for the model.
        """
        self.device = torch.device(device)
        self.swap_device = (
            torch.device(swap_device) if swap_device is not None else self.device
        )

    def load(self, model: torch.nn.Module) -> None:
        """
        Loads a model to the (swap) device.
        """
        model.to(self.swap_device)

    def autocast(self):
        """
        Context manager that enables autocast for the device if supported.
        """
        if self.device.type not in ("cuda", "cpu"):
            return contextlib.nullcontext()
        return torch.autocast(self.device.type)

    @contextlib.contextmanager
    def use(self, model: torch.nn.Module):
        """
        Context manager that ensures a model is on the correct device during use.
        The default model loader does not perform any swapping, so the model will
        stay on device.
        """
        try:
            model.to(self.device)
            yield
        finally:
            if self.device != self.swap_device:
                model.to(self.swap_device)


class CudaModelManager(DeviceModelManager):
    """
    Device manager that loads a model to a CUDA device, optionally swapping to CPU when not in use.
    """

    @contextlib.contextmanager
    def use(self, model: Union[torch.nn.Module, torch.Tensor]):
        """
        Context manager that ensures a model is on the correct device during use.
        If a swap device was provided, this will move the model to it after use and clear cache.
        """
        model.to(self.device)
        yield
        if self.device != self.swap_device:
            model.to(self.swap_device)
            if torch.cuda.is_available():
                torch.cuda.empty_cache()


def get_unique_embedder_keys_from_conditioner(conditioner):
    return list({x.input_key for x in conditioner.embedders})
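
A quick sanity sketch of the swap behavior defined above (editorial, not part of the diff; assumes a CUDA-capable machine, with a `Linear` module standing in for a real submodule):

import torch
from sgm.inference.helpers import CudaModelManager

manager = CudaModelManager(device="cuda", swap_device="cpu")
module = torch.nn.Linear(4, 4)
manager.load(module)                # parked on the swap device (CPU)
with manager.use(module):           # moved to CUDA only for this block
    assert next(module.parameters()).is_cuda
assert not next(module.parameters()).is_cuda  # swapped back out; CUDA cache cleared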

@@ -74,6 +141,20 @@ def perform_save_locally(save_path, samples):
        base_count += 1


def get_model_manager(
    device: Optional[Union[DeviceModelManager, str, torch.device]]
) -> DeviceModelManager:
    if isinstance(device, DeviceModelManager):
        return device
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)
    if device.type == "cuda":
        return CudaModelManager(device=device)
    else:
        return DeviceModelManager(device=device)
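
The resolution rules read directly off the code: an existing manager passes through unchanged, `None` defaults to CUDA when available, and CUDA devices get the swapping-capable manager. A small sketch (editorial, not part of the diff):

import torch
from sgm.inference.helpers import (
    CudaModelManager,
    DeviceModelManager,
    get_model_manager,
)

assert isinstance(get_model_manager("cpu"), DeviceModelManager)
if torch.cuda.is_available():
    assert isinstance(get_model_manager(None), CudaModelManager)
manager = DeviceModelManager(device="cpu")
assert get_model_manager(manager) is manager  # managers pass through unchanged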

class Img2ImgDiscretizationWrapper:
    """
    wraps a discretizer, and prunes the sigmas
@@ -98,6 +179,36 @@ class Img2ImgDiscretizationWrapper:
        return sigmas


class Txt2NoisyDiscretizationWrapper:
    """
    wraps a discretizer, and prunes the sigmas
    params:
        strength: float between 0.0 and 1.0. 0.0 means full sampling (all sigmas are returned)
    """

    def __init__(self, discretization, strength: float = 0.0, original_steps=None):
        self.discretization = discretization
        self.strength = strength
        self.original_steps = original_steps
        assert 0.0 <= self.strength <= 1.0

    def __call__(self, *args, **kwargs):
        # sigmas start large first, and decrease then
        sigmas = self.discretization(*args, **kwargs)
        print(f"sigmas after discretization, before pruning img2img: ", sigmas)
        sigmas = torch.flip(sigmas, (0,))
        if self.original_steps is None:
            steps = len(sigmas)
        else:
            steps = self.original_steps + 1
        prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0)
        sigmas = sigmas[prune_index:]
        print("prune index:", prune_index)
        sigmas = torch.flip(sigmas, (0,))
        print(f"sigmas after pruning: ", sigmas)
        return sigmas
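
A toy check of the pruning arithmetic above (editorial sketch; the lambda stands in for a real discretizer):

import torch
from sgm.inference.helpers import Txt2NoisyDiscretizationWrapper

fake_discretizer = lambda n: torch.linspace(4.0, 1.0, n)  # descending sigmas
wrapped = Txt2NoisyDiscretizationWrapper(fake_discretizer, strength=0.5)
# steps=4 -> prune_index = max(min(int(0.5 * 4) - 1, 3), 0) = 1,
# so the single smallest sigma is dropped:
assert torch.equal(wrapped(4), torch.tensor([4.0, 3.0, 2.0]))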

def do_sample(
    model,
    sampler,
@@ -111,39 +222,45 @@ def do_sample(
    batch2model_input: Optional[List] = None,
    return_latents=False,
    filter=None,
    device="cuda",
    device: Optional[Union[DeviceModelManager, str, torch.device]] = None,
):
    if force_uc_zero_embeddings is None:
        force_uc_zero_embeddings = []
    if batch2model_input is None:
        batch2model_input = []

    device_manager = get_model_manager(device=device)

    with torch.no_grad():
        with autocast(device) as precision_scope:
        with device_manager.autocast():
            with model.ema_scope():
                num_samples = [num_samples]
                batch, batch_uc = get_batch(
                    get_unique_embedder_keys_from_conditioner(model.conditioner),
                    value_dict,
                    num_samples,
                )
                for key in batch:
                    if isinstance(batch[key], torch.Tensor):
                        print(key, batch[key].shape)
                    elif isinstance(batch[key], list):
                        print(key, [len(l) for l in batch[key]])
                    else:
                        print(key, batch[key])
                c, uc = model.conditioner.get_unconditional_conditioning(
                    batch,
                    batch_uc=batch_uc,
                    force_uc_zero_embeddings=force_uc_zero_embeddings,
                )
                with device_manager.use(model.conditioner):
                    batch, batch_uc = get_batch(
                        get_unique_embedder_keys_from_conditioner(model.conditioner),
                        value_dict,
                        num_samples,
                    )
                    for key in batch:
                        if isinstance(batch[key], torch.Tensor):
                            print(key, batch[key].shape)
                        elif isinstance(batch[key], list):
                            print(key, [len(l) for l in batch[key]])
                        else:
                            print(key, batch[key])
                    c, uc = model.conditioner.get_unconditional_conditioning(
                        batch,
                        batch_uc=batch_uc,
                        force_uc_zero_embeddings=force_uc_zero_embeddings,
                    )

                for k in c:
                    if not k == "crossattn":
                        c[k], uc[k] = map(
                            lambda y: y[k][: math.prod(num_samples)].to(device), (c, uc)
                            lambda y: y[k][: math.prod(num_samples)].to(
                                device_manager.device
                            ),
                            (c, uc),
                        )

                additional_model_inputs = {}
@@ -151,16 +268,20 @@ def do_sample(
                    additional_model_inputs[k] = batch[k]

                shape = (math.prod(num_samples), C, H // F, W // F)
                randn = torch.randn(shape).to(device)
                randn = torch.randn(shape).to(device_manager.device)

                def denoiser(input, sigma, c):
                    return model.denoiser(
                        model.model, input, sigma, c, **additional_model_inputs
                    )

                samples_z = sampler(denoiser, randn, cond=c, uc=uc)
                samples_x = model.decode_first_stage(samples_z)
                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
                with device_manager.use(model.denoiser):
                    with device_manager.use(model.model):
                        samples_z = sampler(denoiser, randn, cond=c, uc=uc)

                with device_manager.use(model.first_stage_model):
                    samples_x = model.decode_first_stage(samples_z)
                    samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)

                if filter is not None:
                    samples = filter(samples)

@@ -252,32 +373,40 @@ def do_img2img(
    return_latents=False,
    skip_encode=False,
    filter=None,
    device="cuda",
    add_noise=True,
    device: Optional[Union[DeviceModelManager, str, torch.device]] = None,
):
    device_manager = get_model_manager(device)
    with torch.no_grad():
        with autocast(device) as precision_scope:
        with device_manager.autocast():
            with model.ema_scope():
                batch, batch_uc = get_batch(
                    get_unique_embedder_keys_from_conditioner(model.conditioner),
                    value_dict,
                    [num_samples],
                )
                c, uc = model.conditioner.get_unconditional_conditioning(
                    batch,
                    batch_uc=batch_uc,
                    force_uc_zero_embeddings=force_uc_zero_embeddings,
                )
                with device_manager.use(model.conditioner):
                    batch, batch_uc = get_batch(
                        get_unique_embedder_keys_from_conditioner(model.conditioner),
                        value_dict,
                        [num_samples],
                    )
                    c, uc = model.conditioner.get_unconditional_conditioning(
                        batch,
                        batch_uc=batch_uc,
                        force_uc_zero_embeddings=force_uc_zero_embeddings,
                    )

                for k in c:
                    c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc))
                    c[k], uc[k] = map(
                        lambda y: y[k][:num_samples].to(device_manager.device), (c, uc)
                    )

                for k in additional_kwargs:
                    c[k] = uc[k] = additional_kwargs[k]
                if skip_encode:
                    z = img
                else:
                    z = model.encode_first_stage(img)
                    with device_manager.use(model.first_stage_model):
                        z = model.encode_first_stage(img)

                noise = torch.randn_like(z)

                sigmas = sampler.discretization(sampler.num_steps)
                sigma = sigmas[0].to(z.device)

@@ -285,17 +414,24 @@ def do_img2img(
                    noise = noise + offset_noise_level * append_dims(
                        torch.randn(z.shape[0], device=z.device), z.ndim
                    )
                noised_z = z + noise * append_dims(sigma, z.ndim)
                noised_z = noised_z / torch.sqrt(
                    1.0 + sigmas[0] ** 2.0
                )  # Note: hardcoded to DDPM-like scaling. need to generalize later.
                if add_noise:
                    noised_z = z + noise * append_dims(sigma, z.ndim).cuda()
                    noised_z = noised_z / torch.sqrt(
                        1.0 + sigmas[0] ** 2.0
                    )  # Note: hardcoded to DDPM-like scaling. need to generalize later.
                else:
                    noised_z = z / torch.sqrt(1.0 + sigmas[0] ** 2.0)

                def denoiser(x, sigma, c):
                    return model.denoiser(model.model, x, sigma, c)

                samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
                samples_x = model.decode_first_stage(samples_z)
                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
                with device_manager.use(model.denoiser):
                    with device_manager.use(model.model):
                        samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)

                with device_manager.use(model.first_stage_model):
                    samples_x = model.decode_first_stage(samples_z)
                    samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)

                if filter is not None:
                    samples = filter(samples)

sgm/modules/diffusionmodules/discretizer.py
@@ -26,7 +26,7 @@ class Discretization:


class EDMDiscretization(Discretization):
    def __init__(self, sigma_min=0.02, sigma_max=80.0, rho=7.0):
    def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0):
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.rho = rho
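
For context, `EDMDiscretization` follows the Karras et al. (EDM) schedule, which interpolates between sigma_max and sigma_min in sigma^(1/rho) space; lowering sigma_min to 0.002 matches the EDM paper's default. A standalone restatement of that schedule (my own sketch, not the diffed file):

import torch

def edm_sigmas(n, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    # Linear ramp in sigma^(1/rho) space, raised back to the rho power.
    ramp = torch.linspace(0, 1, n)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho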

sgm/util.py
@@ -230,6 +230,24 @@ def load_model_from_config(config, ckpt, verbose=True, freeze=True):
    return model


def get_checkpoints_path() -> str:
    """
    Get the `checkpoints` directory.
    This could be in the root of the repository for a working copy,
    or in the cwd for other use cases.
    """
    this_dir = os.path.dirname(__file__)
    candidates = (
        os.path.join(this_dir, "checkpoints"),
        os.path.join(os.getcwd(), "checkpoints"),
    )
    for candidate in candidates:
        candidate = os.path.abspath(candidate)
        if os.path.isdir(candidate):
            return candidate
    raise FileNotFoundError(f"Could not find SGM checkpoints in {candidates}")


def get_configs_path() -> str:
    """
    Get the `configs` directory.
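
A hypothetical caller of the new checkpoints helper (editorial sketch; the checkpoint filename is illustrative, taken from the model specs earlier in this compare):

import os
from sgm.util import get_checkpoints_path

ckpt = os.path.join(get_checkpoints_path(), "sd_xl_base_1.0.safetensors")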

tests/inference/test_inference.py
@@ -27,7 +27,7 @@ class TestInference:
    @fixture(
        scope="class",
        params=[
            [ModelArchitecture.SDXL_V1_BASE, ModelArchitecture.SDXL_V1_REFINER],
            [ModelArchitecture.SDXL_V1_0_BASE, ModelArchitecture.SDXL_V1_0_REFINER],
            [ModelArchitecture.SDXL_V0_9_BASE, ModelArchitecture.SDXL_V0_9_REFINER],
        ],
        ids=["SDXL_V1", "SDXL_V0_9"],
@@ -68,9 +68,7 @@ class TestInference:
        assert output is not None

    @pytest.mark.parametrize("sampler_enum", Sampler)
    @pytest.mark.parametrize(
        "use_init_image", [True, False], ids=["img2img", "txt2img"]
    )
    @pytest.mark.parametrize("use_init_image", [True, False], ids=["img2img", "txt2img"])
    def test_sdxl_with_refiner(
        self,
        sdxl_pipelines: Tuple[SamplingPipeline, SamplingPipeline],
@@ -81,13 +79,12 @@ class TestInference:
        if use_init_image:
            output = base_pipeline.image_to_image(
                params=SamplingParams(sampler=sampler_enum.value, steps=10),
                image=self.create_init_image(
                    base_pipeline.specs.height, base_pipeline.specs.width
                ),
                image=self.create_init_image(base_pipeline.specs.height, base_pipeline.specs.width),
                prompt="A professional photograph of an astronaut riding a pig",
                negative_prompt="",
                samples=1,
                return_latents=True,
                noise_strength=0.15,
            )
        else:
            output = base_pipeline.text_to_image(
@@ -96,6 +93,7 @@ class TestInference:
                negative_prompt="",
                samples=1,
                return_latents=True,
                noise_strength=0.15,
            )

        assert isinstance(output, (tuple, list))
@@ -103,9 +101,9 @@ class TestInference:
        assert samples is not None
        assert samples_z is not None
        refiner_pipeline.refiner(
            params=SamplingParams(sampler=sampler_enum.value, steps=10),
            image=samples_z,
            prompt="A professional photograph of an astronaut riding a pig",
            params=SamplingParams(sampler=sampler_enum.value, steps=40, img2img_strength=0.15),
            negative_prompt="",
            samples=1,
        )

tests/inference/test_modelmanager.py (new file)
@@ -0,0 +1,44 @@
import pytest
import torch

from sgm.inference.api import (
    SamplingPipeline,
    ModelArchitecture,
)
import sgm.inference.helpers as helpers


def get_torch_device(model: torch.nn.Module) -> torch.device:
    param = next(model.parameters(), None)
    if param is not None:
        return param.device
    else:
        buf = next(model.buffers(), None)
        if buf is not None:
            return buf.device
        else:
            raise TypeError("Could not determine device of input model")


@pytest.mark.inference
def test_default_loading():
    pipeline = SamplingPipeline(model_id=ModelArchitecture.SD_2_1)
    assert get_torch_device(pipeline.model.model).type == "cuda"
    assert get_torch_device(pipeline.model.conditioner).type == "cuda"
    with pipeline.device_manager.use(pipeline.model.model):
        assert get_torch_device(pipeline.model.model).type == "cuda"
    assert get_torch_device(pipeline.model.model).type == "cuda"
    with pipeline.device_manager.use(pipeline.model.conditioner):
        assert get_torch_device(pipeline.model.conditioner).type == "cuda"
    assert get_torch_device(pipeline.model.conditioner).type == "cuda"


@pytest.mark.inference
def test_model_swapping():
    pipeline = SamplingPipeline(
        model_id=ModelArchitecture.SD_2_1,
        device=helpers.CudaModelManager(device="cuda", swap_device="cpu"),
    )
    assert get_torch_device(pipeline.model.model).type == "cpu"
    assert get_torch_device(pipeline.model.conditioner).type == "cpu"
    with pipeline.device_manager.use(pipeline.model.model):
        assert get_torch_device(pipeline.model.model).type == "cuda"
    assert get_torch_device(pipeline.model.model).type == "cpu"
    with pipeline.device_manager.use(pipeline.model.conditioner):
        assert get_torch_device(pipeline.model.conditioner).type == "cuda"
    assert get_torch_device(pipeline.model.conditioner).type == "cpu"