58 Commits

Author SHA1 Message Date
Stephan Auerhahn
7ef5489cea Merge branch 'main' into helpers-fixes 2023-08-17 10:21:35 -07:00
Vitaly Bondar
477d8b9a77 fix EDMDiscretization sigma_min for correct sampling noise scheduling (#114) 2023-08-17 08:48:30 -07:00
Stephan Auerhahn
e289621992 fix reference 2023-08-12 13:52:46 -07:00
Stephan Auerhahn
2fc4680bf9 Easier default params 2023-08-12 13:22:04 -07:00
Stephan Auerhahn
e32972b85b remove extra init 2023-08-12 05:42:22 -07:00
Stephan Auerhahn
65c6ec1cec run black 2023-08-12 05:40:25 -07:00
Stephan Auerhahn
5fde7e73b8 set a default scale 2023-08-12 05:35:36 -07:00
Stephan Auerhahn
fbe93fc53b PR fixes, model specific defaults 2023-08-12 05:33:16 -07:00
Stephan Auerhahn
c0655731d5 fix streamlit inputs 2023-08-12 04:25:56 -07:00
Stephan Auerhahn
f6704532a0 abstract device defaults 2023-08-12 07:27:25 +00:00
Stephan Auerhahn
98c4b7753b cleanup imports in test 2023-08-12 07:16:02 +00:00
Stephan Auerhahn
d4307bef5d Test model device manager and fix bugs 2023-08-12 07:15:36 +00:00
Stephan Auerhahn
fe4632034b fix for orig dimensions 2023-08-11 16:31:53 -07:00
Stephan Auerhahn
d6f2b78994 pass options into state2 init 2023-08-10 15:06:55 -07:00
Stephan Auerhahn
cd81956241 text updates 2023-08-10 13:31:03 -07:00
Stephan Auerhahn
5c17043434 change default 2023-08-10 13:15:23 -07:00
Stephan Auerhahn
2aebc8882d split fp16 and swapping functionality 2023-08-10 13:14:38 -07:00
Stephan Auerhahn
3816aaa639 simplify device_manager usage 2023-08-10 13:05:30 -07:00
Stephan Auerhahn
88395261d8 update helpers 2023-08-10 12:45:37 -07:00
Stephan Auerhahn
b3866d1218 move checkbox out of cached resource 2023-08-10 12:44:48 -07:00
Stephan Auerhahn
a25662e969 low vram checkbox fix, remove magic strings 2023-08-10 12:40:32 -07:00
Stephan Auerhahn
26b10f56f3 fix missing index 2023-08-10 12:24:12 -07:00
Stephan Auerhahn
3e7ada70c5 fix autocast 2023-08-10 05:42:31 -07:00
Stephan Auerhahn
de7a627978 more fixes and cleanup 2023-08-10 05:11:34 -07:00
Stephan Auerhahn
9b18e6fa19 update api module 2023-08-10 05:07:22 -07:00
Stephan Auerhahn
47805f233c finish device manager refactor 2023-08-10 04:55:43 -07:00
Stephan Auerhahn
e190ecc60b path helper & model swapping rewrite 2023-08-10 04:35:59 -07:00
Stephan Auerhahn
fc498bfaef remove duplicate imports 2023-08-10 03:20:56 -07:00
Stephan Auerhahn
8011d54ca1 some PR fixes 2023-08-10 03:19:37 -07:00
Stephan Auerhahn
b51c36b0df extract path resolution method, fix/improve device swapping support 2023-08-09 19:31:59 -07:00
Stephan Auerhahn
d245e2002f more types 2023-08-09 13:46:06 -07:00
Stephan Auerhahn
725bea9f75 pull in import fix 2023-08-09 13:29:16 -07:00
Stephan Auerhahn
a009aa8a9f adding some typing 2023-08-09 13:27:30 -07:00
Stephan Auerhahn
f86ffac274 context manager 2023-08-09 12:38:44 -07:00
Stephan Auerhahn
a726ce3eb7 replace usage of get 2023-08-09 12:30:43 -07:00
Stephan Auerhahn
c4b7baf896 Streamlit refactor (#105)
* initial streamlit refactoring pass

* cleanup and fixes

* fix refiner strength

* Modify params correctly

* fix exception
2023-08-06 19:58:52 -07:00
Stephan Auerhahn
7e7fee3f0f system env var 2023-08-06 19:22:59 -07:00
Stephan Auerhahn
49fe53c165 use env var for sgm checkpoints path 2023-08-06 19:21:17 -07:00
Stephan Auerhahn
6c18c8443a rename ModelOnDevice to SwapToDevice 2023-08-06 23:46:20 +00:00
Stephan Auerhahn
ced97f0e84 update defaults 2023-08-06 23:24:14 +00:00
Stephan Auerhahn
76ca428422 fix path resolution bug 2023-08-06 21:39:18 +00:00
Stephan Auerhahn
8f8757b4ff version bump for changes to inference helpers 2023-08-06 21:09:09 +00:00
Stephan Auerhahn
f2fba1dfa2 fix noisy latent handling 2023-08-06 21:08:19 +00:00
Stephan Auerhahn
451c76ada1 format 2023-08-06 12:26:16 +00:00
Stephan Auerhahn
0c2c5c66a2 fix device check 2023-08-06 12:26:01 +00:00
Stephan Auerhahn
ea5f232d5d move conditioner to device 2023-08-06 11:42:39 +00:00
Stephan Auerhahn
f06c67c206 formatting, remove reference 2023-08-06 11:30:40 +00:00
Stephan Auerhahn
b216934b7e align with streamlit helpers and re-de-deuplicate 2023-08-06 11:20:22 +00:00
Stephan Auerhahn
77d0e27747 format 2023-08-03 17:57:55 -07:00
Stephan Auerhahn
4aea6fa2a4 Fix checkpoint loading too 2023-08-03 17:56:24 -07:00
Stephan Auerhahn
84d3a7f6f5 fix fallback logic for config path 2023-08-03 17:50:10 -07:00
Stephan Auerhahn
19fa4da3de run black again 2023-08-04 00:16:29 +00:00
Stephan Auerhahn
4e2236f67d Fix path logic for development installs 2023-08-04 00:15:22 +00:00
Stephan Auerhahn
baf79d2d79 black 2023-08-04 00:00:51 +00:00
Stephan Auerhahn
44943df4f2 Allow loading custom models and improve path logic 2023-08-03 23:59:42 +00:00
Stephan Auerhahn
73287ec3a3 Extract method for img2img wrapper 2023-08-03 23:42:11 +00:00
Stephan Auerhahn
853adb4022 Add defaults to refiner function 2023-08-03 12:50:23 -07:00
Stephan Auerhahn
45feb6cb9c Use wrapper correctly in refiner helper 2023-08-02 23:14:30 +00:00
11 changed files with 730 additions and 889 deletions

View File

@@ -15,7 +15,7 @@ jobs:
     steps:
       - uses: actions/checkout@v3
       - name: "Symlink checkpoints"
-        run: ln -s ${{vars.SGM_CHECKPOINTS_PATH}} checkpoints
+        run: ln -s $SGM_CHECKPOINTS checkpoints
       - name: "Setup python"
        uses: actions/setup-python@v4
        with:

View File

@@ -44,5 +44,5 @@ dependencies = [
 test-inference = [
     "pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 --index-url https://download.pytorch.org/whl/cu118",
     "pip install -r requirements/pt2.txt",
-    "pytest -v tests/inference/test_inference.py {args}",
+    "pytest -v tests/inference {args}",
 ]

View File

@@ -1,6 +1,30 @@
+import os
+import numpy as np
+import streamlit as st
+import torch
+from einops import repeat
 from pytorch_lightning import seed_everything
-from scripts.demo.streamlit_helpers import *
+from sgm.inference.api import (
+    SamplingSpec,
+    SamplingParams,
+    ModelArchitecture,
+    SamplingPipeline,
+    model_specs,
+)
+from sgm.inference.helpers import (
+    get_unique_embedder_keys_from_conditioner,
+    perform_save_locally,
+)
+from scripts.demo.streamlit_helpers import (
+    get_interactive_image,
+    init_embedder_options,
+    init_sampling,
+    init_save_locally,
+    init_st,
+    show_samples,
+)
 
 SAVE_PATH = "outputs/demo/txt2img/"
@@ -33,63 +57,6 @@ SD_XL_BASE_RATIOS = {
     "3.0": (1728, 576),
 }
 
-VERSION2SPECS = {
-    "SDXL-base-1.0": {
-        "H": 1024,
-        "W": 1024,
-        "C": 4,
-        "f": 8,
-        "is_legacy": False,
-        "config": "configs/inference/sd_xl_base.yaml",
-        "ckpt": "checkpoints/sd_xl_base_1.0.safetensors",
-    },
-    "SDXL-base-0.9": {
-        "H": 1024,
-        "W": 1024,
-        "C": 4,
-        "f": 8,
-        "is_legacy": False,
-        "config": "configs/inference/sd_xl_base.yaml",
-        "ckpt": "checkpoints/sd_xl_base_0.9.safetensors",
-    },
-    "SD-2.1": {
-        "H": 512,
-        "W": 512,
-        "C": 4,
-        "f": 8,
-        "is_legacy": True,
-        "config": "configs/inference/sd_2_1.yaml",
-        "ckpt": "checkpoints/v2-1_512-ema-pruned.safetensors",
-    },
-    "SD-2.1-768": {
-        "H": 768,
-        "W": 768,
-        "C": 4,
-        "f": 8,
-        "is_legacy": True,
-        "config": "configs/inference/sd_2_1_768.yaml",
-        "ckpt": "checkpoints/v2-1_768-ema-pruned.safetensors",
-    },
-    "SDXL-refiner-0.9": {
-        "H": 1024,
-        "W": 1024,
-        "C": 4,
-        "f": 8,
-        "is_legacy": True,
-        "config": "configs/inference/sd_xl_refiner.yaml",
-        "ckpt": "checkpoints/sd_xl_refiner_0.9.safetensors",
-    },
-    "SDXL-refiner-1.0": {
-        "H": 1024,
-        "W": 1024,
-        "C": 4,
-        "f": 8,
-        "is_legacy": True,
-        "config": "configs/inference/sd_xl_refiner.yaml",
-        "ckpt": "checkpoints/sd_xl_refiner_1.0.safetensors",
-    },
-}
-
 
 def load_img(display=True, key=None, device="cuda"):
     image = get_interactive_image(key=key)
@@ -111,170 +78,181 @@ def load_img(display=True, key=None, device="cuda"):
 def run_txt2img(
     state,
-    version,
-    version_dict,
-    is_legacy=False,
+    model_id: ModelArchitecture,
+    prompt: str,
+    negative_prompt: str,
     return_latents=False,
-    filter=None,
     stage2strength=None,
 ):
-    if version.startswith("SDXL-base"):
-        W, H = st.selectbox("Resolution:", list(SD_XL_BASE_RATIOS.values()), 10)
+    model: SamplingPipeline = state["model"]
+    params: SamplingParams = state["params"]
+    if model_id in sdxl_base_model_list:
+        width, height = st.selectbox(
+            "Resolution:", list(SD_XL_BASE_RATIOS.values()), 10
+        )
     else:
-        H = st.number_input("H", value=version_dict["H"], min_value=64, max_value=2048)
-        W = st.number_input("W", value=version_dict["W"], min_value=64, max_value=2048)
-    C = version_dict["C"]
-    F = version_dict["f"]
+        height = int(
+            st.number_input("H", value=params.height, min_value=64, max_value=2048)
+        )
+        width = int(
+            st.number_input("W", value=params.width, min_value=64, max_value=2048)
+        )
 
-    init_dict = {
-        "orig_width": W,
-        "orig_height": H,
-        "target_width": W,
-        "target_height": H,
-    }
-    value_dict = init_embedder_options(
-        get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
-        init_dict,
+    params = init_embedder_options(
+        get_unique_embedder_keys_from_conditioner(model.model.conditioner),
+        params=params,
         prompt=prompt,
         negative_prompt=negative_prompt,
     )
-    sampler, num_rows, num_cols = init_sampling(stage2strength=stage2strength)
+    params, num_rows, num_cols = init_sampling(params=params)
     num_samples = num_rows * num_cols
+    params.height = height
+    params.width = width
 
     if st.button("Sample"):
         st.write(f"**Model I:** {version}")
-        out = do_sample(
-            state["model"],
-            sampler,
-            value_dict,
-            num_samples,
-            H,
-            W,
-            C,
-            F,
-            force_uc_zero_embeddings=["txt"] if not is_legacy else [],
+        outputs = st.empty()
+        st.text("Sampling")
+        out = model.text_to_image(
+            params=params,
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            samples=int(num_samples),
             return_latents=return_latents,
-            filter=filter,
+            noise_strength=stage2strength,
+            filter=state["filter"],
         )
+        show_samples(out, outputs)
         return out
 
 
 def run_img2img(
     state,
-    version_dict,
-    is_legacy=False,
+    prompt: str,
+    negative_prompt: str,
     return_latents=False,
-    filter=None,
     stage2strength=None,
 ):
+    model: SamplingPipeline = state["model"]
+    params: SamplingParams = state["params"]
     img = load_img()
     if img is None:
         return None
-    H, W = img.shape[2], img.shape[3]
+    params.height, params.width = img.shape[2], img.shape[3]
 
-    init_dict = {
-        "orig_width": W,
-        "orig_height": H,
-        "target_width": W,
-        "target_height": H,
-    }
-    value_dict = init_embedder_options(
-        get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
-        init_dict,
+    params = init_embedder_options(
+        get_unique_embedder_keys_from_conditioner(model.model.conditioner),
+        params=params,
         prompt=prompt,
         negative_prompt=negative_prompt,
     )
-    strength = st.number_input(
+    params.img2img_strength = st.number_input(
         "**Img2Img Strength**", value=0.75, min_value=0.0, max_value=1.0
     )
-    sampler, num_rows, num_cols = init_sampling(
-        img2img_strength=strength,
-        stage2strength=stage2strength,
-    )
+    params, num_rows, num_cols = init_sampling(params=params)
     num_samples = num_rows * num_cols
 
     if st.button("Sample"):
-        out = do_img2img(
-            repeat(img, "1 ... -> n ...", n=num_samples),
-            state["model"],
-            sampler,
-            value_dict,
-            num_samples,
-            force_uc_zero_embeddings=["txt"] if not is_legacy else [],
+        outputs = st.empty()
+        st.text("Sampling")
+        out = model.image_to_image(
+            image=repeat(img, "1 ... -> n ...", n=num_samples),
+            params=params,
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            samples=int(num_samples),
             return_latents=return_latents,
-            filter=filter,
+            noise_strength=stage2strength,
+            filter=state["filter"],
         )
+        show_samples(out, outputs)
         return out
 
 
 def apply_refiner(
     input,
     state,
-    sampler,
-    num_samples,
-    prompt,
-    negative_prompt,
-    filter=None,
+    num_samples: int,
+    prompt: str,
+    negative_prompt: str,
     finish_denoising=False,
 ):
-    init_dict = {
-        "orig_width": input.shape[3] * 8,
-        "orig_height": input.shape[2] * 8,
-        "target_width": input.shape[3] * 8,
-        "target_height": input.shape[2] * 8,
-    }
+    model: SamplingPipeline = state["model"]
+    params: SamplingParams = state["params"]
 
-    value_dict = init_dict
-    value_dict["prompt"] = prompt
-    value_dict["negative_prompt"] = negative_prompt
-
-    value_dict["crop_coords_top"] = 0
-    value_dict["crop_coords_left"] = 0
-
-    value_dict["aesthetic_score"] = 6.0
-    value_dict["negative_aesthetic_score"] = 2.5
+    params.orig_width = input.shape[3] * 8
+    params.orig_height = input.shape[2] * 8
+    params.width = input.shape[3] * 8
+    params.height = input.shape[2] * 8
 
     st.warning(f"refiner input shape: {input.shape}")
-    samples = do_img2img(
-        input,
-        state["model"],
-        sampler,
-        value_dict,
-        num_samples,
-        skip_encode=True,
-        filter=filter,
+
+    samples = model.refiner(
+        image=input,
+        params=params,
+        prompt=prompt,
+        negative_prompt=negative_prompt,
+        samples=num_samples,
+        return_latents=False,
+        filter=state["filter"],
         add_noise=not finish_denoising,
     )
 
     return samples
 
 
+sdxl_base_model_list = [
+    ModelArchitecture.SDXL_V1_0_BASE,
+    ModelArchitecture.SDXL_V0_9_BASE,
+]
+
+sdxl_refiner_model_list = [
+    ModelArchitecture.SDXL_V1_0_REFINER,
+    ModelArchitecture.SDXL_V0_9_REFINER,
+]
+
 if __name__ == "__main__":
     st.title("Stable Diffusion")
-    version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
-    version_dict = VERSION2SPECS[version]
+    version = st.selectbox(
+        "Model Version",
+        [member.value for member in ModelArchitecture],
+        0,
+    )
+    version_enum = ModelArchitecture(version)
+    specs = model_specs[version_enum]
     mode = st.radio("Mode", ("txt2img", "img2img"), 0)
     st.write("__________________________")
-    set_lowvram_mode(st.checkbox("Low vram mode", True))
+    st.write("**Performance Options:**")
+    use_fp16 = st.checkbox("Use fp16 (Saves VRAM)", True)
+    enable_swap = st.checkbox("Swap models to CPU (Saves VRAM, uses RAM)", True)
+    st.write("__________________________")
 
-    if version.startswith("SDXL-base"):
+    if version_enum in sdxl_base_model_list:
         add_pipeline = st.checkbox("Load SDXL-refiner?", False)
         st.write("__________________________")
     else:
         add_pipeline = False
 
-    seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
+    seed = int(
+        st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
+    )
     seed_everything(seed)
 
-    save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))
+    save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, str(version)))
 
-    state = init_st(version_dict, load_filter=True)
-    if state["msg"]:
-        st.info(state["msg"])
+    state = init_st(
+        model_specs[version_enum],
+        load_filter=True,
+        use_fp16=use_fp16,
+        enable_swap=enable_swap,
+    )
     model = state["model"]
 
-    is_legacy = version_dict["is_legacy"]
+    is_legacy = specs.is_legacy
 
     prompt = st.text_input(
         "prompt",
@@ -290,47 +268,58 @@ if __name__ == "__main__":
     if add_pipeline:
         st.write("__________________________")
-        version2 = st.selectbox("Refiner:", ["SDXL-refiner-1.0", "SDXL-refiner-0.9"])
+        version2 = ModelArchitecture(
+            st.selectbox(
+                "Refiner:",
+                [member.value for member in sdxl_refiner_model_list],
+            )
+        )
         st.warning(
             f"Running with {version2} as the second stage model. Make sure to provide (V)RAM :) "
         )
         st.write("**Refiner Options:**")
 
-        version_dict2 = VERSION2SPECS[version2]
-        state2 = init_st(version_dict2, load_filter=False)
-        st.info(state2["msg"])
+        specs2 = model_specs[version2]
+        state2 = init_st(
+            specs2, load_filter=False, use_fp16=use_fp16, enable_swap=enable_swap
+        )
+        params2 = state2["params"]
 
-        stage2strength = st.number_input(
+        params2.img2img_strength = st.number_input(
             "**Refinement strength**", value=0.15, min_value=0.0, max_value=1.0
         )
 
-        sampler2, *_ = init_sampling(
+        params2, *_ = init_sampling(
+            params=state2["params"],
             key=2,
-            img2img_strength=stage2strength,
             specify_num_samples=False,
         )
         st.write("__________________________")
         finish_denoising = st.checkbox("Finish denoising with refiner.", True)
-        if not finish_denoising:
+        if finish_denoising:
+            stage2strength = params2.img2img_strength
+        else:
             stage2strength = None
+    else:
+        state2 = None
+        params2 = None
+        stage2strength = None
 
     if mode == "txt2img":
         out = run_txt2img(
-            state,
-            version,
-            version_dict,
-            is_legacy=is_legacy,
+            state=state,
+            model_id=version_enum,
+            prompt=prompt,
+            negative_prompt=negative_prompt,
             return_latents=add_pipeline,
-            filter=state.get("filter"),
             stage2strength=stage2strength,
         )
     elif mode == "img2img":
         out = run_img2img(
-            state,
-            version_dict,
-            is_legacy=is_legacy,
+            state=state,
+            prompt=prompt,
+            negative_prompt=negative_prompt,
             return_latents=add_pipeline,
-            filter=state.get("filter"),
             stage2strength=stage2strength,
         )
     else:
@@ -342,17 +331,17 @@ if __name__ == "__main__":
         samples_z = None
 
     if add_pipeline and samples_z is not None:
+        outputs = st.empty()
         st.write("**Running Refinement Stage**")
         samples = apply_refiner(
-            samples_z,
-            state2,
-            sampler2,
-            samples_z.shape[0],
+            input=samples_z,
+            state=state2,
+            num_samples=samples_z.shape[0],
             prompt=prompt,
             negative_prompt=negative_prompt if is_legacy else "",
-            filter=state.get("filter"),
             finish_denoising=finish_denoising,
         )
+        show_samples(samples, outputs)
 
     if save_locally and samples is not None:
         perform_save_locally(save_path, samples)
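
For reference, a minimal sketch of how the refactored two-stage flow above can be driven outside Streamlit. This is not part of the diff; it assumes checkpoints and configs resolve via the default get_checkpoints_path()/get_configs_path() locations, and the prompt string is illustrative.

from sgm.inference.api import ModelArchitecture, SamplingPipeline, model_specs

base_spec = model_specs[ModelArchitecture.SDXL_V1_0_BASE]
refiner_spec = model_specs[ModelArchitecture.SDXL_V1_0_REFINER]

base = SamplingPipeline(model_spec=base_spec)
refiner_pipeline = SamplingPipeline(model_spec=refiner_spec)

# return_latents=True yields (decoded samples, latents); noise_strength wraps the
# base discretization with Txt2NoisyDiscretizationWrapper, mirroring the demo's
# "Finish denoising with refiner" path.
samples, samples_z = base.text_to_image(
    prompt="a photo of a red panda",  # illustrative prompt
    params=base_spec.default_params,
    samples=1,
    return_latents=True,
    noise_strength=refiner_spec.default_params.img2img_strength,
)

# The refiner consumes the base latents directly (do_img2img runs with skip_encode=True).
refined = refiner_pipeline.refiner(
    image=samples_z,
    prompt="a photo of a red panda",
    params=refiner_spec.default_params,
    add_noise=False,
)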

View File

@@ -1,166 +1,68 @@
-import math
 import os
-from typing import List, Union
 
 import numpy as np
 import streamlit as st
 import torch
 from einops import rearrange, repeat
-from imwatermark import WatermarkEncoder
-from omegaconf import ListConfig, OmegaConf
 from PIL import Image
-from safetensors.torch import load_file as load_safetensors
-from torch import autocast
 from torchvision import transforms
-from torchvision.utils import make_grid
+from typing import Optional, Tuple, Dict, Any
 
 from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
-from sgm.modules.diffusionmodules.sampling import (
-    DPMPP2MSampler,
-    DPMPP2SAncestralSampler,
-    EulerAncestralSampler,
-    EulerEDMSampler,
-    HeunEDMSampler,
-    LinearMultistepSampler,
+
+from sgm.inference.api import (
+    Discretization,
+    Guider,
+    Sampler,
+    SamplingParams,
+    SamplingSpec,
+    SamplingPipeline,
+    Thresholder,
 )
-from sgm.util import append_dims, instantiate_from_config
+from sgm.inference.helpers import embed_watermark, CudaModelManager
 
-
-class WatermarkEmbedder:
-    def __init__(self, watermark):
-        self.watermark = watermark
-        self.num_bits = len(WATERMARK_BITS)
-        self.encoder = WatermarkEncoder()
-        self.encoder.set_watermark("bits", self.watermark)
-
-    def __call__(self, image: torch.Tensor):
-        """
-        Adds a predefined watermark to the input image
-
-        Args:
-            image: ([N,] B, C, H, W) in range [0, 1]
-
-        Returns:
-            same as input but watermarked
-        """
-        # watermarking libary expects input as cv2 BGR format
-        squeeze = len(image.shape) == 4
-        if squeeze:
-            image = image[None, ...]
-        n = image.shape[0]
-        image_np = rearrange(
-            (255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
-        ).numpy()[:, :, :, ::-1]
-        # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
-        for k in range(image_np.shape[0]):
-            image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
-        image = torch.from_numpy(
-            rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)
-        ).to(image.device)
-        image = torch.clamp(image / 255, min=0.0, max=1.0)
-        if squeeze:
-            image = image[0]
-        return image
-
-
-# A fixed 48-bit message that was choosen at random
-# WATERMARK_MESSAGE = 0xB3EC907BB19E
-WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
-# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
-WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
-embed_watemark = WatermarkEmbedder(WATERMARK_BITS)
-
 
 @st.cache_resource()
-def init_st(version_dict, load_ckpt=True, load_filter=True):
-    state = dict()
-    if not "model" in state:
-        config = version_dict["config"]
-        ckpt = version_dict["ckpt"]
-
-        config = OmegaConf.load(config)
-        model, msg = load_model_from_config(config, ckpt if load_ckpt else None)
-
-        state["msg"] = msg
-        state["model"] = model
-        state["ckpt"] = ckpt if load_ckpt else None
-        state["config"] = config
-        if load_filter:
-            state["filter"] = DeepFloydDataFiltering(verbose=False)
+def init_st(
+    spec: SamplingSpec,
+    load_ckpt=True,
+    load_filter=True,
+    use_fp16=True,
+    enable_swap=True,
+) -> Dict[str, Any]:
+    state: Dict[str, Any] = dict()
+    config = spec.config
+    ckpt = spec.ckpt
+
+    if enable_swap:
+        pipeline = SamplingPipeline(
+            model_spec=spec,
+            use_fp16=use_fp16,
+            device=CudaModelManager(device="cuda", swap_device="cpu"),
+        )
+    else:
+        pipeline = SamplingPipeline(model_spec=spec, use_fp16=use_fp16)
+
+    state["spec"] = spec
+    state["model"] = pipeline
+    state["ckpt"] = ckpt if load_ckpt else None
+    state["config"] = config
+    state["params"] = spec.default_params
+    if load_filter:
+        state["filter"] = DeepFloydDataFiltering(verbose=False)
+    else:
+        state["filter"] = None
     return state
 
 
-def load_model(model):
-    model.cuda()
-
-
-lowvram_mode = False
-
-
-def set_lowvram_mode(mode):
-    global lowvram_mode
-    lowvram_mode = mode
-
-
-def initial_model_load(model):
-    global lowvram_mode
-    if lowvram_mode:
-        model.model.half()
-    else:
-        model.cuda()
-    return model
-
-
-def unload_model(model):
-    global lowvram_mode
-    if lowvram_mode:
-        model.cpu()
-        torch.cuda.empty_cache()
-
-
-def load_model_from_config(config, ckpt=None, verbose=True):
-    model = instantiate_from_config(config.model)
-
-    if ckpt is not None:
-        print(f"Loading model from {ckpt}")
-        if ckpt.endswith("ckpt"):
-            pl_sd = torch.load(ckpt, map_location="cpu")
-            if "global_step" in pl_sd:
-                global_step = pl_sd["global_step"]
-                st.info(f"loaded ckpt from global step {global_step}")
-                print(f"Global Step: {pl_sd['global_step']}")
-            sd = pl_sd["state_dict"]
-        elif ckpt.endswith("safetensors"):
-            sd = load_safetensors(ckpt)
-        else:
-            raise NotImplementedError
-
-        msg = None
-
-        m, u = model.load_state_dict(sd, strict=False)
-
-        if len(m) > 0 and verbose:
-            print("missing keys:")
-            print(m)
-        if len(u) > 0 and verbose:
-            print("unexpected keys:")
-            print(u)
-    else:
-        msg = None
-
-    model = initial_model_load(model)
-    model.eval()
-    return model, msg
-
-
 def get_unique_embedder_keys_from_conditioner(conditioner):
     return list(set([x.input_key for x in conditioner.embedders]))
 
 
-def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
-    # Hardcoded demo settings; might undergo some changes in the future
-
-    value_dict = {}
+def init_embedder_options(
+    keys, params: SamplingParams, prompt=None, negative_prompt=None
+) -> SamplingParams:
     for key in keys:
         if key == "txt":
             if prompt is None:
@@ -170,46 +72,38 @@ def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
             if negative_prompt is None:
                 negative_prompt = st.text_input("Negative prompt", "")
 
-            value_dict["prompt"] = prompt
-            value_dict["negative_prompt"] = negative_prompt
-
         if key == "original_size_as_tuple":
             orig_width = st.number_input(
                 "orig_width",
-                value=init_dict["orig_width"],
+                value=params.orig_width,
                 min_value=16,
             )
             orig_height = st.number_input(
                 "orig_height",
-                value=init_dict["orig_height"],
+                value=params.orig_height,
                 min_value=16,
             )
 
-            value_dict["orig_width"] = orig_width
-            value_dict["orig_height"] = orig_height
+            params.orig_width = int(orig_width)
+            params.orig_height = int(orig_height)
 
         if key == "crop_coords_top_left":
-            crop_coord_top = st.number_input("crop_coords_top", value=0, min_value=0)
-            crop_coord_left = st.number_input("crop_coords_left", value=0, min_value=0)
+            crop_coord_top = st.number_input(
+                "crop_coords_top", value=params.crop_coords_top, min_value=0
+            )
+            crop_coord_left = st.number_input(
+                "crop_coords_left", value=params.crop_coords_left, min_value=0
+            )
 
-            value_dict["crop_coords_top"] = crop_coord_top
-            value_dict["crop_coords_left"] = crop_coord_left
-
-        if key == "aesthetic_score":
-            value_dict["aesthetic_score"] = 6.0
-            value_dict["negative_aesthetic_score"] = 2.5
-
-        if key == "target_size_as_tuple":
-            value_dict["target_width"] = init_dict["target_width"]
-            value_dict["target_height"] = init_dict["target_height"]
-
-    return value_dict
+            params.crop_coords_top = int(crop_coord_top)
+            params.crop_coords_left = int(crop_coord_left)
+
+    return params
 
 
 def perform_save_locally(save_path, samples):
     os.makedirs(os.path.join(save_path), exist_ok=True)
     base_count = len(os.listdir(os.path.join(save_path)))
-    samples = embed_watemark(samples)
+    samples = embed_watermark(samples)
     for sample in samples:
         sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
         Image.fromarray(sample.astype(np.uint8)).save(
@@ -228,78 +122,26 @@ def init_save_locally(_dir, init_value: bool = False):
     return save_locally, save_path
 
 
-class Img2ImgDiscretizationWrapper:
-    """
-    wraps a discretizer, and prunes the sigmas
-    params:
-        strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
-    """
-
-    def __init__(self, discretization, strength: float = 1.0):
-        self.discretization = discretization
-        self.strength = strength
-        assert 0.0 <= self.strength <= 1.0
-
-    def __call__(self, *args, **kwargs):
-        # sigmas start large first, and decrease then
-        sigmas = self.discretization(*args, **kwargs)
-        print(f"sigmas after discretization, before pruning img2img: ", sigmas)
-        sigmas = torch.flip(sigmas, (0,))
-        sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
-        print("prune index:", max(int(self.strength * len(sigmas)), 1))
-        sigmas = torch.flip(sigmas, (0,))
-        print(f"sigmas after pruning: ", sigmas)
-        return sigmas
-
-
-class Txt2NoisyDiscretizationWrapper:
-    """
-    wraps a discretizer, and prunes the sigmas
-    params:
-        strength: float between 0.0 and 1.0. 0.0 means full sampling (all sigmas are returned)
-    """
-
-    def __init__(self, discretization, strength: float = 0.0, original_steps=None):
-        self.discretization = discretization
-        self.strength = strength
-        self.original_steps = original_steps
-        assert 0.0 <= self.strength <= 1.0
-
-    def __call__(self, *args, **kwargs):
-        # sigmas start large first, and decrease then
-        sigmas = self.discretization(*args, **kwargs)
-        print(f"sigmas after discretization, before pruning img2img: ", sigmas)
-        sigmas = torch.flip(sigmas, (0,))
-        if self.original_steps is None:
-            steps = len(sigmas)
-        else:
-            steps = self.original_steps + 1
-        prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0)
-        sigmas = sigmas[prune_index:]
-        print("prune index:", prune_index)
-        sigmas = torch.flip(sigmas, (0,))
-        print(f"sigmas after pruning: ", sigmas)
-        return sigmas
+def show_samples(samples, outputs):
+    if isinstance(samples, tuple):
+        samples, _ = samples
+    grid = embed_watermark(torch.stack([samples]))
+    grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
+    outputs.image(grid.cpu().numpy())
 
 
-def get_guider(key):
-    guider = st.sidebar.selectbox(
-        f"Discretization #{key}",
-        [
-            "VanillaCFG",
-            "IdentityGuider",
-        ],
+def get_guider(params: SamplingParams, key=1) -> SamplingParams:
+    params.guider = Guider(
+        st.sidebar.selectbox(
+            f"Discretization #{key}", [member.value for member in Guider]
+        )
     )
 
-    if guider == "IdentityGuider":
-        guider_config = {
-            "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
-        }
-    elif guider == "VanillaCFG":
+    if params.guider == Guider.VANILLA:
         scale = st.number_input(
-            f"cfg-scale #{key}", value=5.0, min_value=0.0, max_value=100.0
+            f"cfg-scale #{key}", value=params.scale, min_value=0.0, max_value=100.0
         )
+        params.scale = scale
 
         thresholder = st.sidebar.selectbox(
             f"Thresholder #{key}",
             [
@@ -308,182 +150,97 @@ def get_guider(key):
         )
 
         if thresholder == "None":
-            dyn_thresh_config = {
-                "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
-            }
+            params.thresholder = Thresholder.NONE
         else:
             raise NotImplementedError
+    return params
 
-        guider_config = {
-            "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
-            "params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
-        }
-    else:
-        raise NotImplementedError
-    return guider_config
-
 
 def init_sampling(
+    params: SamplingParams,
     key=1,
-    img2img_strength=1.0,
     specify_num_samples=True,
-    stage2strength=None,
-):
+) -> Tuple[SamplingParams, int, int]:
     num_rows, num_cols = 1, 1
     if specify_num_samples:
         num_cols = st.number_input(
             f"num cols #{key}", value=2, min_value=1, max_value=10
         )
 
-    steps = st.sidebar.number_input(
-        f"steps #{key}", value=40, min_value=1, max_value=1000
-    )
-    sampler = st.sidebar.selectbox(
-        f"Sampler #{key}",
-        [
-            "EulerEDMSampler",
-            "HeunEDMSampler",
-            "EulerAncestralSampler",
-            "DPMPP2SAncestralSampler",
-            "DPMPP2MSampler",
-            "LinearMultistepSampler",
-        ],
-        0,
-    )
-    discretization = st.sidebar.selectbox(
-        f"Discretization #{key}",
-        [
-            "LegacyDDPMDiscretization",
-            "EDMDiscretization",
-        ],
-    )
-
-    discretization_config = get_discretization(discretization, key=key)
-
-    guider_config = get_guider(key=key)
-
-    sampler = get_sampler(sampler, steps, discretization_config, guider_config, key=key)
-    if img2img_strength < 1.0:
-        st.warning(
-            f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper"
-        )
-        sampler.discretization = Img2ImgDiscretizationWrapper(
-            sampler.discretization, strength=img2img_strength
-        )
-    if stage2strength is not None:
-        sampler.discretization = Txt2NoisyDiscretizationWrapper(
-            sampler.discretization, strength=stage2strength, original_steps=steps
-        )
-    return sampler, num_rows, num_cols
+    params.steps = int(
+        st.sidebar.number_input(
+            f"steps #{key}", value=params.steps, min_value=1, max_value=1000
+        )
+    )
+    params.sampler = Sampler(
+        st.sidebar.selectbox(
+            f"Sampler #{key}",
+            [member.value for member in Sampler],
+            0,
+        )
+    )
+    params.discretization = Discretization(
+        st.sidebar.selectbox(
+            f"Discretization #{key}",
+            [member.value for member in Discretization],
+        )
+    )
+
+    params = get_discretization(params=params, key=key)
+    params = get_guider(params=params, key=key)
+    params = get_sampler(params=params, key=key)
+
+    return params, num_rows, num_cols
 
 
-def get_discretization(discretization, key=1):
-    if discretization == "LegacyDDPMDiscretization":
-        discretization_config = {
-            "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
-        }
-    elif discretization == "EDMDiscretization":
-        sigma_min = st.number_input(f"sigma_min #{key}", value=0.03)  # 0.0292
-        sigma_max = st.number_input(f"sigma_max #{key}", value=14.61)  # 14.6146
-        rho = st.number_input(f"rho #{key}", value=3.0)
-        discretization_config = {
-            "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
-            "params": {
-                "sigma_min": sigma_min,
-                "sigma_max": sigma_max,
-                "rho": rho,
-            },
-        }
-    return discretization_config
+def get_discretization(params: SamplingParams, key=1) -> SamplingParams:
+    if params.discretization == Discretization.EDM:
+        params.sigma_min = st.number_input(f"sigma_min #{key}", value=params.sigma_min)
+        params.sigma_max = st.number_input(f"sigma_max #{key}", value=params.sigma_max)
+        params.rho = st.number_input(f"rho #{key}", value=params.rho)
+    return params
 
 
-def get_sampler(sampler_name, steps, discretization_config, guider_config, key=1):
-    if sampler_name == "EulerEDMSampler" or sampler_name == "HeunEDMSampler":
-        s_churn = st.sidebar.number_input(f"s_churn #{key}", value=0.0, min_value=0.0)
-        s_tmin = st.sidebar.number_input(f"s_tmin #{key}", value=0.0, min_value=0.0)
-        s_tmax = st.sidebar.number_input(f"s_tmax #{key}", value=999.0, min_value=0.0)
-        s_noise = st.sidebar.number_input(f"s_noise #{key}", value=1.0, min_value=0.0)
-
-        if sampler_name == "EulerEDMSampler":
-            sampler = EulerEDMSampler(
-                num_steps=steps,
-                discretization_config=discretization_config,
-                guider_config=guider_config,
-                s_churn=s_churn,
-                s_tmin=s_tmin,
-                s_tmax=s_tmax,
-                s_noise=s_noise,
-                verbose=True,
-            )
-        elif sampler_name == "HeunEDMSampler":
-            sampler = HeunEDMSampler(
-                num_steps=steps,
-                discretization_config=discretization_config,
-                guider_config=guider_config,
-                s_churn=s_churn,
-                s_tmin=s_tmin,
-                s_tmax=s_tmax,
-                s_noise=s_noise,
-                verbose=True,
-            )
-    elif (
-        sampler_name == "EulerAncestralSampler"
-        or sampler_name == "DPMPP2SAncestralSampler"
-    ):
-        s_noise = st.sidebar.number_input("s_noise", value=1.0, min_value=0.0)
-        eta = st.sidebar.number_input("eta", value=1.0, min_value=0.0)
-
-        if sampler_name == "EulerAncestralSampler":
-            sampler = EulerAncestralSampler(
-                num_steps=steps,
-                discretization_config=discretization_config,
-                guider_config=guider_config,
-                eta=eta,
-                s_noise=s_noise,
-                verbose=True,
-            )
-        elif sampler_name == "DPMPP2SAncestralSampler":
-            sampler = DPMPP2SAncestralSampler(
-                num_steps=steps,
-                discretization_config=discretization_config,
-                guider_config=guider_config,
-                eta=eta,
-                s_noise=s_noise,
-                verbose=True,
-            )
-    elif sampler_name == "DPMPP2MSampler":
-        sampler = DPMPP2MSampler(
-            num_steps=steps,
-            discretization_config=discretization_config,
-            guider_config=guider_config,
-            verbose=True,
-        )
-    elif sampler_name == "LinearMultistepSampler":
-        order = st.sidebar.number_input("order", value=4, min_value=1)
-        sampler = LinearMultistepSampler(
-            num_steps=steps,
-            discretization_config=discretization_config,
-            guider_config=guider_config,
-            order=order,
-            verbose=True,
-        )
-    else:
-        raise ValueError(f"unknown sampler {sampler_name}!")
-
-    return sampler
+def get_sampler(params: SamplingParams, key=1) -> SamplingParams:
+    if params.sampler in (Sampler.EULER_EDM, Sampler.HEUN_EDM):
+        params.s_churn = st.sidebar.number_input(
+            f"s_churn #{key}", value=params.s_churn, min_value=0.0
+        )
+        params.s_tmin = st.sidebar.number_input(
+            f"s_tmin #{key}", value=params.s_tmin, min_value=0.0
+        )
+        params.s_tmax = st.sidebar.number_input(
+            f"s_tmax #{key}", value=params.s_tmax, min_value=0.0
+        )
+        params.s_noise = st.sidebar.number_input(
+            f"s_noise #{key}", value=params.s_noise, min_value=0.0
+        )
+    elif params.sampler in (Sampler.EULER_ANCESTRAL, Sampler.DPMPP2S_ANCESTRAL):
+        params.s_noise = st.sidebar.number_input(
+            "s_noise", value=params.s_noise, min_value=0.0
+        )
+        params.eta = st.sidebar.number_input("eta", value=params.eta, min_value=0.0)
+    elif params.sampler == Sampler.LINEAR_MULTISTEP:
+        params.order = int(
+            st.sidebar.number_input("order", value=params.order, min_value=1)
+        )
+    return params
 
 
-def get_interactive_image(key=None) -> Image.Image:
+def get_interactive_image(key=None) -> Optional[Image.Image]:
     image = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
     if image is not None:
         image = Image.open(image)
         if not image.mode == "RGB":
             image = image.convert("RGB")
         return image
+    return None
 
 
-def load_img(display=True, key=None):
+def load_img(display=True, key=None) -> Optional[torch.Tensor]:
     image = get_interactive_image(key=key)
     if image is None:
         return None
@@ -507,238 +264,3 @@ def get_init_img(batch_size=1, key=None):
     init_image = load_img(key=key).cuda()
     init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
     return init_image
-
-
-def do_sample(
-    model,
-    sampler,
-    value_dict,
-    num_samples,
-    H,
-    W,
-    C,
-    F,
-    force_uc_zero_embeddings: List = None,
-    batch2model_input: List = None,
-    return_latents=False,
-    filter=None,
-):
-    if force_uc_zero_embeddings is None:
-        force_uc_zero_embeddings = []
-    if batch2model_input is None:
-        batch2model_input = []
-
-    st.text("Sampling")
-
-    outputs = st.empty()
-    precision_scope = autocast
-    with torch.no_grad():
-        with precision_scope("cuda"):
-            with model.ema_scope():
-                num_samples = [num_samples]
-                load_model(model.conditioner)
-                batch, batch_uc = get_batch(
-                    get_unique_embedder_keys_from_conditioner(model.conditioner),
-                    value_dict,
-                    num_samples,
-                )
-                for key in batch:
-                    if isinstance(batch[key], torch.Tensor):
-                        print(key, batch[key].shape)
-                    elif isinstance(batch[key], list):
-                        print(key, [len(l) for l in batch[key]])
-                    else:
-                        print(key, batch[key])
-                c, uc = model.conditioner.get_unconditional_conditioning(
-                    batch,
-                    batch_uc=batch_uc,
-                    force_uc_zero_embeddings=force_uc_zero_embeddings,
-                )
-                unload_model(model.conditioner)
-
-                for k in c:
-                    if not k == "crossattn":
-                        c[k], uc[k] = map(
-                            lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc)
-                        )
-
-                additional_model_inputs = {}
-                for k in batch2model_input:
-                    additional_model_inputs[k] = batch[k]
-
-                shape = (math.prod(num_samples), C, H // F, W // F)
-                randn = torch.randn(shape).to("cuda")
-
-                def denoiser(input, sigma, c):
-                    return model.denoiser(
-                        model.model, input, sigma, c, **additional_model_inputs
-                    )
-
-                load_model(model.denoiser)
-                load_model(model.model)
-                samples_z = sampler(denoiser, randn, cond=c, uc=uc)
-                unload_model(model.model)
-                unload_model(model.denoiser)
-
-                load_model(model.first_stage_model)
-                samples_x = model.decode_first_stage(samples_z)
-                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
-                unload_model(model.first_stage_model)
-
-                if filter is not None:
-                    samples = filter(samples)
-
-                grid = torch.stack([samples])
-                grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
-                outputs.image(grid.cpu().numpy())
-
-                if return_latents:
-                    return samples, samples_z
-                return samples
-
-
-def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
-    # Hardcoded demo setups; might undergo some changes in the future
-    batch = {}
-    batch_uc = {}
-
-    for key in keys:
-        if key == "txt":
-            batch["txt"] = (
-                np.repeat([value_dict["prompt"]], repeats=math.prod(N))
-                .reshape(N)
-                .tolist()
-            )
-            batch_uc["txt"] = (
-                np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
-                .reshape(N)
-                .tolist()
-            )
-        elif key == "original_size_as_tuple":
-            batch["original_size_as_tuple"] = (
-                torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
-                .to(device)
-                .repeat(*N, 1)
-            )
-        elif key == "crop_coords_top_left":
-            batch["crop_coords_top_left"] = (
-                torch.tensor(
-                    [value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
-                )
-                .to(device)
-                .repeat(*N, 1)
-            )
-        elif key == "aesthetic_score":
-            batch["aesthetic_score"] = (
-                torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
-            )
-            batch_uc["aesthetic_score"] = (
-                torch.tensor([value_dict["negative_aesthetic_score"]])
-                .to(device)
-                .repeat(*N, 1)
-            )
-        elif key == "target_size_as_tuple":
-            batch["target_size_as_tuple"] = (
-                torch.tensor([value_dict["target_height"], value_dict["target_width"]])
-                .to(device)
-                .repeat(*N, 1)
-            )
-        else:
-            batch[key] = value_dict[key]
-
-    for key in batch.keys():
-        if key not in batch_uc and isinstance(batch[key], torch.Tensor):
-            batch_uc[key] = torch.clone(batch[key])
-    return batch, batch_uc
-
-
-@torch.no_grad()
-def do_img2img(
-    img,
-    model,
-    sampler,
-    value_dict,
-    num_samples,
-    force_uc_zero_embeddings=[],
-    additional_kwargs={},
-    offset_noise_level: int = 0.0,
-    return_latents=False,
-    skip_encode=False,
-    filter=None,
-    add_noise=True,
-):
-    st.text("Sampling")
-
-    outputs = st.empty()
-    precision_scope = autocast
-    with torch.no_grad():
-        with precision_scope("cuda"):
-            with model.ema_scope():
-                load_model(model.conditioner)
-                batch, batch_uc = get_batch(
-                    get_unique_embedder_keys_from_conditioner(model.conditioner),
-                    value_dict,
-                    [num_samples],
-                )
-                c, uc = model.conditioner.get_unconditional_conditioning(
-                    batch,
-                    batch_uc=batch_uc,
-                    force_uc_zero_embeddings=force_uc_zero_embeddings,
-                )
-                unload_model(model.conditioner)
-                for k in c:
-                    c[k], uc[k] = map(lambda y: y[k][:num_samples].to("cuda"), (c, uc))
-
-                for k in additional_kwargs:
-                    c[k] = uc[k] = additional_kwargs[k]
-                if skip_encode:
-                    z = img
-                else:
-                    load_model(model.first_stage_model)
-                    z = model.encode_first_stage(img)
-                    unload_model(model.first_stage_model)
-                noise = torch.randn_like(z)
-                sigmas = sampler.discretization(sampler.num_steps).cuda()
-                sigma = sigmas[0]
-                st.info(f"all sigmas: {sigmas}")
-                st.info(f"noising sigma: {sigma}")
-
-                if offset_noise_level > 0.0:
-                    noise = noise + offset_noise_level * append_dims(
-                        torch.randn(z.shape[0], device=z.device), z.ndim
-                    )
-                if add_noise:
-                    noised_z = z + noise * append_dims(sigma, z.ndim).cuda()
-                    noised_z = noised_z / torch.sqrt(
-                        1.0 + sigmas[0] ** 2.0
-                    )  # Note: hardcoded to DDPM-like scaling. need to generalize later.
-                else:
-                    noised_z = z / torch.sqrt(1.0 + sigmas[0] ** 2.0)
-
-                def denoiser(x, sigma, c):
-                    return model.denoiser(model.model, x, sigma, c)
-
-                load_model(model.denoiser)
-                load_model(model.model)
-                samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
-                unload_model(model.model)
-                unload_model(model.denoiser)
-
-                load_model(model.first_stage_model)
-                samples_x = model.decode_first_stage(samples_z)
-                unload_model(model.first_stage_model)
-                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
-
-                if filter is not None:
-                    samples = filter(samples)
-
-                grid = embed_watemark(torch.stack([samples]))
-                grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
-                outputs.image(grid.cpu().numpy())
-                if return_latents:
-                    return samples, samples_z
-                return samples
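
The two discretization wrappers removed above now live in sgm.inference.helpers (sgm/inference/api.py below imports them from there). A small self-contained sketch of the pruning rule Img2ImgDiscretizationWrapper applies, using the same flip-truncate-flip logic as the deleted code:

import torch

def prune_sigmas(sigmas: torch.Tensor, strength: float) -> torch.Tensor:
    # sigmas arrive in descending order; flip to ascending, keep a
    # strength-sized prefix (at least one step), then flip back
    sigmas = torch.flip(sigmas, (0,))
    sigmas = sigmas[: max(int(strength * len(sigmas)), 1)]
    return torch.flip(sigmas, (0,))

sigmas = torch.linspace(14.61, 0.03, steps=40)
print(prune_sigmas(sigmas, strength=0.75).shape)  # torch.Size([30]): only the 30 smallest sigmas survive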

View File

@@ -1,4 +1,4 @@
 from .models import AutoencodingEngine, DiffusionEngine
 from .util import get_configs_path, instantiate_from_config
 
-__version__ = "0.1.0"
+__version__ = "0.1.1"

View File

@@ -1,11 +1,14 @@
 from dataclasses import dataclass, asdict
 from enum import Enum
 from omegaconf import OmegaConf
-import pathlib
+import os
 from sgm.inference.helpers import (
     do_sample,
     do_img2img,
+    DeviceModelManager,
+    get_model_manager,
     Img2ImgDiscretizationWrapper,
+    Txt2NoisyDiscretizationWrapper,
 )
 from sgm.modules.diffusionmodules.sampling import (
     EulerEDMSampler,
@@ -15,17 +18,18 @@ from sgm.modules.diffusionmodules.sampling import (
     DPMPP2MSampler,
     LinearMultistepSampler,
 )
-from sgm.util import load_model_from_config
-from typing import Optional
+from sgm.util import load_model_from_config, get_configs_path, get_checkpoints_path
+import torch
+from typing import Optional, Dict, Any, Union
 
 
 class ModelArchitecture(str, Enum):
-    SD_2_1 = "stable-diffusion-v2-1"
-    SD_2_1_768 = "stable-diffusion-v2-1-768"
+    SDXL_V1_0_BASE = "stable-diffusion-xl-v1-base"
+    SDXL_V1_0_REFINER = "stable-diffusion-xl-v1-refiner"
     SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base"
     SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner"
-    SDXL_V1_BASE = "stable-diffusion-xl-v1-base"
-    SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner"
+    SD_2_1 = "stable-diffusion-v2-1"
+    SD_2_1_768 = "stable-diffusion-v2-1-768"
 
 
 class Sampler(str, Enum):
@@ -53,16 +57,20 @@ class Thresholder(str, Enum):
 
 @dataclass
 class SamplingParams:
-    width: int = 1024
-    height: int = 1024
-    steps: int = 50
-    sampler: Sampler = Sampler.DPMPP2M
+    """
+    Parameters for sampling.
+    """
+
+    width: Optional[int] = None
+    height: Optional[int] = None
+    steps: Optional[int] = None
+    sampler: Sampler = Sampler.EULER_EDM
     discretization: Discretization = Discretization.LEGACY_DDPM
     guider: Guider = Guider.VANILLA
     thresholder: Thresholder = Thresholder.NONE
-    scale: float = 6.0
-    aesthetic_score: float = 5.0
-    negative_aesthetic_score: float = 5.0
+    scale: float = 5.0
+    aesthetic_score: float = 6.0
+    negative_aesthetic_score: float = 2.5
     img2img_strength: float = 1.0
     orig_width: int = 1024
     orig_height: int = 1024
@@ -89,8 +97,10 @@ class SamplingSpec:
     config: str
     ckpt: str
     is_guided: bool
+    default_params: SamplingParams
 
 
+# The defaults here are derived from user preference testing.
 model_specs = {
     ModelArchitecture.SD_2_1: SamplingSpec(
         height=512,
@@ -101,6 +111,12 @@ model_specs = {
         config="sd_2_1.yaml",
         ckpt="v2-1_512-ema-pruned.safetensors",
         is_guided=True,
+        default_params=SamplingParams(
+            width=512,
+            height=512,
+            steps=40,
+            scale=7.0,
+        ),
     ),
     ModelArchitecture.SD_2_1_768: SamplingSpec(
         height=768,
@@ -111,6 +127,12 @@ model_specs = {
         config="sd_2_1_768.yaml",
         ckpt="v2-1_768-ema-pruned.safetensors",
         is_guided=True,
+        default_params=SamplingParams(
+            width=768,
+            height=768,
+            steps=40,
+            scale=7.0,
+        ),
     ),
     ModelArchitecture.SDXL_V0_9_BASE: SamplingSpec(
         height=1024,
@@ -121,6 +143,7 @@ model_specs = {
         config="sd_xl_base.yaml",
         ckpt="sd_xl_base_0.9.safetensors",
         is_guided=True,
+        default_params=SamplingParams(width=1024, height=1024, steps=40, scale=5.0),
     ),
     ModelArchitecture.SDXL_V0_9_REFINER: SamplingSpec(
         height=1024,
@@ -131,8 +154,11 @@ model_specs = {
         config="sd_xl_refiner.yaml",
         ckpt="sd_xl_refiner_0.9.safetensors",
         is_guided=True,
+        default_params=SamplingParams(
+            width=1024, height=1024, steps=40, scale=5.0, img2img_strength=0.15
+        ),
     ),
-    ModelArchitecture.SDXL_V1_BASE: SamplingSpec(
+    ModelArchitecture.SDXL_V1_0_BASE: SamplingSpec(
         height=1024,
         width=1024,
         channels=4,
@@ -141,8 +167,9 @@ model_specs = {
         config="sd_xl_base.yaml",
         ckpt="sd_xl_base_1.0.safetensors",
         is_guided=True,
+        default_params=SamplingParams(width=1024, height=1024, steps=40, scale=5.0),
     ),
-    ModelArchitecture.SDXL_V1_REFINER: SamplingSpec(
+    ModelArchitecture.SDXL_V1_0_REFINER: SamplingSpec(
         height=1024,
         width=1024,
         channels=4,
@@ -151,34 +178,97 @@ model_specs = {
         config="sd_xl_refiner.yaml",
         ckpt="sd_xl_refiner_1.0.safetensors",
         is_guided=True,
+        default_params=SamplingParams(
+            width=1024, height=1024, steps=40, scale=5.0, img2img_strength=0.15
+        ),
     ),
 }
 
 
+def wrap_discretization(
+    discretization, image_strength=None, noise_strength=None, steps=None
+):
+    if isinstance(discretization, Img2ImgDiscretizationWrapper) or isinstance(
+        discretization, Txt2NoisyDiscretizationWrapper
+    ):
+        return discretization  # Already wrapped
+    if image_strength is not None and image_strength < 1.0 and image_strength > 0.0:
+        discretization = Img2ImgDiscretizationWrapper(
+            discretization, strength=image_strength
+        )
+
+    if (
+        noise_strength is not None
+        and noise_strength < 1.0
+        and noise_strength > 0.0
+        and steps is not None
+    ):
+        discretization = Txt2NoisyDiscretizationWrapper(
+            discretization,
+            strength=noise_strength,
+            original_steps=steps,
+        )
+    return discretization
+
+
 class SamplingPipeline:
     def __init__(
         self,
-        model_id: ModelArchitecture,
-        model_path="checkpoints",
-        config_path="configs/inference",
-        device="cuda",
-        use_fp16=True,
+        model_id: Optional[ModelArchitecture] = None,
+        model_spec: Optional[SamplingSpec] = None,
+        model_path: Optional[str] = None,
+        config_path: Optional[str] = None,
+        use_fp16: bool = True,
+        device: Optional[Union[DeviceModelManager, str, torch.device]] = None,
     ) -> None:
-        if model_id not in model_specs:
-            raise ValueError(f"Model {model_id} not supported")
-        self.model_id = model_id
-        self.specs = model_specs[self.model_id]
-        self.config = str(pathlib.Path(config_path, self.specs.config))
-        self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt))
-        self.device = device
-        self.model = self._load_model(device=device, use_fp16=use_fp16)
+        """
+        Sampling pipeline for generating images from a model.
 
-    def _load_model(self, device="cuda", use_fp16=True):
+        @param model_id: Model architecture to use. If not specified, model_spec must be specified.
+        @param model_spec: Model specification to use. If not specified, model_id must be specified.
+        @param model_path: Path to model checkpoints folder.
+        @param config_path: Path to model config folder.
+        @param use_fp16: Whether to use fp16 for sampling.
+        @param device: Device manager to use with this pipeline. If a string or torch.device is passed, a device manager will be created based on device type if possible.
+        """
+        self.model_id = model_id
+        if model_spec is not None:
+            self.specs = model_spec
+        elif model_id is not None:
+            if model_id not in model_specs:
+                raise ValueError(f"Model {model_id} not supported")
+            self.specs = model_specs[model_id]
+        else:
+            raise ValueError("Either model_id or model_spec should be provided")
+
+        if model_path is None:
+            model_path = get_checkpoints_path()
+        if config_path is None:
+            config_path = get_configs_path()
+        self.config = os.path.join(config_path, "inference", self.specs.config)
+        self.ckpt = os.path.join(model_path, self.specs.ckpt)
+        if not os.path.exists(self.config):
+            raise ValueError(
+                f"Config {self.config} not found, check model spec or config_path"
+            )
+        if not os.path.exists(self.ckpt):
+            raise ValueError(
+                f"Checkpoint {self.ckpt} not found, check model spec or config_path"
+            )
+        self.device_manager = get_model_manager(device)
+        self.model = self._load_model(
+            device_manager=self.device_manager, use_fp16=use_fp16
+        )
+
+    def _load_model(self, device_manager: DeviceModelManager, use_fp16=True):
         config = OmegaConf.load(self.config)
         model = load_model_from_config(config, self.ckpt)
         if model is None:
             raise ValueError(f"Model {self.model_id} could not be loaded")
-        model.to(device)
+        device_manager.load(model)
         if use_fp16:
             model.conditioner.half()
             model.model.half()
@@ -186,13 +276,34 @@ class SamplingPipeline:
     def text_to_image(
         self,
-        params: SamplingParams,
         prompt: str,
+        params: Optional[SamplingParams] = None,
         negative_prompt: str = "",
         samples: int = 1,
         return_latents: bool = False,
+        noise_strength: Optional[float] = None,
+        filter=None,
     ):
+        if params is None:
+            params = self.specs.default_params
+        else:
+            # Set defaults if optional params are not specified
+            if params.width is None:
+                params.width = self.specs.default_params.width
+            if params.height is None:
+                params.height = self.specs.default_params.height
+            if params.steps is None:
+                params.steps = self.specs.default_params.steps
+
         sampler = get_sampler_config(params)
+
+        sampler.discretization = wrap_discretization(
+            sampler.discretization,
+            image_strength=None,
+            noise_strength=noise_strength,
+            steps=params.steps,
+        )
+
         value_dict = asdict(params)
         value_dict["prompt"] = prompt
         value_dict["negative_prompt"] = negative_prompt
@@ -209,31 +320,40 @@ class SamplingPipeline:
             self.specs.factor,
             force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
             return_latents=return_latents,
-            filter=None,
+            filter=filter,
+            device=self.device_manager,
         )
 
     def image_to_image(
         self,
-        params: SamplingParams,
-        image,
+        image: torch.Tensor,
         prompt: str,
+        params: Optional[SamplingParams] = None,
         negative_prompt: str = "",
         samples: int = 1,
         return_latents: bool = False,
+        noise_strength: Optional[float] = None,
+        filter=None,
     ):
+        if params is None:
+            params = self.specs.default_params
         sampler = get_sampler_config(params)
 
-        if params.img2img_strength < 1.0:
-            sampler.discretization = Img2ImgDiscretizationWrapper(
-                sampler.discretization,
-                strength=params.img2img_strength,
-            )
+        sampler.discretization = wrap_discretization(
+            sampler.discretization,
+            image_strength=params.img2img_strength,
+            noise_strength=noise_strength,
+            steps=params.steps,
+        )
+
         height, width = image.shape[2], image.shape[3]
         value_dict = asdict(params)
         value_dict["prompt"] = prompt
         value_dict["negative_prompt"] = negative_prompt
         value_dict["target_width"] = width
         value_dict["target_height"] = height
+        value_dict["orig_width"] = width
+        value_dict["orig_height"] = height
+
         return do_img2img(
             image,
             self.model,
@@ -242,18 +362,24 @@ class SamplingPipeline:
             samples,
             force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
             return_latents=return_latents,
-            filter=None,
+            filter=filter,
+            device=self.device_manager,
         )

     def refiner(
         self,
-        params: SamplingParams,
-        image,
+        image: torch.Tensor,
         prompt: str,
-        negative_prompt: Optional[str] = None,
+        negative_prompt: str = "",
+        params: Optional[SamplingParams] = None,
         samples: int = 1,
         return_latents: bool = False,
+        filter: Any = None,
+        add_noise: bool = False,
     ):
+        if params is None:
+            params = self.specs.default_params
         sampler = get_sampler_config(params)
         value_dict = {
             "orig_width": image.shape[3] * 8,
@@ -268,6 +394,10 @@ class SamplingPipeline:
             "negative_aesthetic_score": 2.5,
         }

+        sampler.discretization = wrap_discretization(
+            sampler.discretization, image_strength=params.img2img_strength
+        )
+
         return do_img2img(
             image,
             self.model,
@@ -276,11 +406,14 @@ class SamplingPipeline:
             samples,
             skip_encode=True,
             return_latents=return_latents,
-            filter=None,
+            filter=filter,
+            add_noise=add_noise,
+            device=self.device_manager,
         )

-def get_guider_config(params: SamplingParams):
+def get_guider_config(params: SamplingParams) -> Dict[str, Any]:
+    guider_config: Dict[str, Any]
     if params.guider == Guider.IDENTITY:
         guider_config = {
             "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
@@ -306,7 +439,8 @@ def get_guider_config(params: SamplingParams):
     return guider_config

-def get_discretization_config(params: SamplingParams):
+def get_discretization_config(params: SamplingParams) -> Dict[str, Any]:
+    discretization_config: Dict[str, Any]
     if params.discretization == Discretization.LEGACY_DDPM:
         discretization_config = {
             "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",

View File

@@ -1,3 +1,4 @@
+import contextlib
 import os
 from typing import Union, List, Optional
@@ -8,7 +9,6 @@ from PIL import Image
 from einops import rearrange
 from imwatermark import WatermarkEncoder
 from omegaconf import ListConfig
-from torch import autocast
 from sgm.util import append_dims
@@ -58,6 +58,73 @@ WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
 embed_watermark = WatermarkEmbedder(WATERMARK_BITS)

+class DeviceModelManager(object):
+    """
+    Default model loading class; should work for all device classes.
+    """
+
+    def __init__(
+        self,
+        device: Union[torch.device, str],
+        swap_device: Optional[Union[torch.device, str]] = None,
+    ) -> None:
+        """
+        Args:
+            device (Union[torch.device, str]): The device to use for the model.
+            swap_device (Optional[Union[torch.device, str]]): Device to hold the
+                model when it is not in use; defaults to ``device``.
+        """
+        self.device = torch.device(device)
+        self.swap_device = (
+            torch.device(swap_device) if swap_device is not None else self.device
+        )
+
+    def load(self, model: torch.nn.Module) -> None:
+        """
+        Loads a model to the (swap) device.
+        """
+        model.to(self.swap_device)
+
+    def autocast(self):
+        """
+        Context manager that enables autocast for the device if supported.
+        """
+        if self.device.type not in ("cuda", "cpu"):
+            return contextlib.nullcontext()
+        return torch.autocast(self.device.type)
+
+    @contextlib.contextmanager
+    def use(self, model: torch.nn.Module):
+        """
+        Context manager that ensures a model is on the correct device during use.
+        The default model loader does not perform any swapping, so the model will
+        stay on device.
+        """
+        try:
+            model.to(self.device)
+            yield
+        finally:
+            if self.device != self.swap_device:
+                model.to(self.swap_device)
+
+
+class CudaModelManager(DeviceModelManager):
+    """
+    Device manager that loads a model to a CUDA device, optionally swapping to
+    CPU when not in use.
+    """
+
+    @contextlib.contextmanager
+    def use(self, model: Union[torch.nn.Module, torch.Tensor]):
+        """
+        Context manager that ensures a model is on the correct device during use.
+        If a swap device was provided, this will move the model to it after use
+        and clear the CUDA cache.
+        """
+        model.to(self.device)
+        yield
+        if self.device != self.swap_device:
+            model.to(self.swap_device)
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
+
 def get_unique_embedder_keys_from_conditioner(conditioner):
     return list({x.input_key for x in conditioner.embedders})
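Taken together, the two managers give a low-VRAM pattern: park weights on the swap device and hop onto the GPU only inside `use()`. A hedged sketch (`unet` and its forward arguments are placeholders, not names from this diff):

    manager = CudaModelManager(device="cuda", swap_device="cpu")
    manager.load(unet)                # parks the module on the swap (CPU) device
    with manager.use(unet):           # moved to CUDA for the duration of the block
        out = unet(latents, timesteps, context)  # hypothetical forward call
    # back on CPU here, and the CUDA cache has been emptied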
@@ -74,6 +141,20 @@ def perform_save_locally(save_path, samples):
         base_count += 1

+def get_model_manager(
+    device: Optional[Union[DeviceModelManager, str, torch.device]]
+) -> DeviceModelManager:
+    if isinstance(device, DeviceModelManager):
+        return device
+
+    if device is None:
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+    device = torch.device(device)
+
+    if device.type == "cuda":
+        return CudaModelManager(device=device)
+    else:
+        return DeviceModelManager(device=device)
+
+
 class Img2ImgDiscretizationWrapper:
     """
     wraps a discretizer, and prunes the sigmas
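`get_model_manager` normalizes whatever callers pass: an existing manager is returned as-is, strings and `torch.device` values are wrapped, and `None` falls back to CUDA when available. For instance:

    manager = get_model_manager(None)      # CudaModelManager on a CUDA machine, else DeviceModelManager
    manager = get_model_manager("cuda:1")  # pinned to a specific GPU
    manager = get_model_manager(manager)   # idempotent: the same instance comes back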
@@ -98,6 +179,36 @@ class Img2ImgDiscretizationWrapper:
         return sigmas

+class Txt2NoisyDiscretizationWrapper:
+    """
+    wraps a discretizer, and prunes the sigmas
+    params:
+        strength: float between 0.0 and 1.0. 0.0 means full sampling (all sigmas are returned)
+    """
+
+    def __init__(self, discretization, strength: float = 0.0, original_steps=None):
+        self.discretization = discretization
+        self.strength = strength
+        self.original_steps = original_steps
+        assert 0.0 <= self.strength <= 1.0
+
+    def __call__(self, *args, **kwargs):
+        # sigmas start large first, and decrease then
+        sigmas = self.discretization(*args, **kwargs)
+        print(f"sigmas after discretization, before pruning img2img: ", sigmas)
+        sigmas = torch.flip(sigmas, (0,))
+        if self.original_steps is None:
+            steps = len(sigmas)
+        else:
+            steps = self.original_steps + 1
+        prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0)
+        sigmas = sigmas[prune_index:]
+        print("prune index:", prune_index)
+        sigmas = torch.flip(sigmas, (0,))
+        print(f"sigmas after pruning: ", sigmas)
+        return sigmas
+
+
 def do_sample(
     model,
     sampler,
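The pruning arithmetic above is easiest to see with concrete numbers; a quick sketch with a stand-in discretizer (the linear ramp is illustrative only):

    import torch

    wrapper = Txt2NoisyDiscretizationWrapper(
        lambda n: torch.linspace(14.6, 0.0, n),  # stand-in: descending sigmas
        strength=0.15,
        original_steps=40,
    )
    sigmas = wrapper(41)
    # steps = 40 + 1 = 41; prune_index = max(min(int(0.15 * 41) - 1, 40), 0) = 5,
    # so the five smallest sigmas are dropped and 36 of the 41 values remain.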
@@ -111,39 +222,45 @@ def do_sample(
     batch2model_input: Optional[List] = None,
     return_latents=False,
     filter=None,
-    device="cuda",
+    device: Optional[Union[DeviceModelManager, str, torch.device]] = None,
 ):
     if force_uc_zero_embeddings is None:
         force_uc_zero_embeddings = []
     if batch2model_input is None:
         batch2model_input = []
+    device_manager = get_model_manager(device=device)
     with torch.no_grad():
-        with autocast(device) as precision_scope:
+        with device_manager.autocast():
             with model.ema_scope():
                 num_samples = [num_samples]
-                batch, batch_uc = get_batch(
-                    get_unique_embedder_keys_from_conditioner(model.conditioner),
-                    value_dict,
-                    num_samples,
-                )
-                for key in batch:
-                    if isinstance(batch[key], torch.Tensor):
-                        print(key, batch[key].shape)
-                    elif isinstance(batch[key], list):
-                        print(key, [len(l) for l in batch[key]])
-                    else:
-                        print(key, batch[key])
-                c, uc = model.conditioner.get_unconditional_conditioning(
-                    batch,
-                    batch_uc=batch_uc,
-                    force_uc_zero_embeddings=force_uc_zero_embeddings,
-                )
+                with device_manager.use(model.conditioner):
+                    batch, batch_uc = get_batch(
+                        get_unique_embedder_keys_from_conditioner(model.conditioner),
+                        value_dict,
+                        num_samples,
+                    )
+                    for key in batch:
+                        if isinstance(batch[key], torch.Tensor):
+                            print(key, batch[key].shape)
+                        elif isinstance(batch[key], list):
+                            print(key, [len(l) for l in batch[key]])
+                        else:
+                            print(key, batch[key])
+                    c, uc = model.conditioner.get_unconditional_conditioning(
+                        batch,
+                        batch_uc=batch_uc,
+                        force_uc_zero_embeddings=force_uc_zero_embeddings,
+                    )

                 for k in c:
                     if not k == "crossattn":
                         c[k], uc[k] = map(
-                            lambda y: y[k][: math.prod(num_samples)].to(device), (c, uc)
+                            lambda y: y[k][: math.prod(num_samples)].to(
+                                device_manager.device
+                            ),
+                            (c, uc),
                         )

                 additional_model_inputs = {}
@@ -151,16 +268,20 @@ def do_sample(
                     additional_model_inputs[k] = batch[k]

                 shape = (math.prod(num_samples), C, H // F, W // F)
-                randn = torch.randn(shape).to(device)
+                randn = torch.randn(shape).to(device_manager.device)

                 def denoiser(input, sigma, c):
                     return model.denoiser(
                         model.model, input, sigma, c, **additional_model_inputs
                     )

-                samples_z = sampler(denoiser, randn, cond=c, uc=uc)
-                samples_x = model.decode_first_stage(samples_z)
-                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
+                with device_manager.use(model.denoiser):
+                    with device_manager.use(model.model):
+                        samples_z = sampler(denoiser, randn, cond=c, uc=uc)
+
+                with device_manager.use(model.first_stage_model):
+                    samples_x = model.decode_first_stage(samples_z)
+                    samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)

                 if filter is not None:
                     samples = filter(samples)
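With the new `device` parameter, `do_sample` accepts a manager, a device string, or `None`. A hedged call sketch (model, sampler, and `value_dict` construction elided; the `H`/`W`/`C`/`F` names come from the shape computation above, the sizes are illustrative):

    samples = do_sample(
        model, sampler, value_dict,
        num_samples=1, H=1024, W=1024, C=4, F=8,
        device=CudaModelManager(device="cuda", swap_device="cpu"),
    )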
@@ -252,32 +373,40 @@ def do_img2img(
     return_latents=False,
     skip_encode=False,
     filter=None,
-    device="cuda",
+    add_noise=True,
+    device: Optional[Union[DeviceModelManager, str, torch.device]] = None,
 ):
+    device_manager = get_model_manager(device)
     with torch.no_grad():
-        with autocast(device) as precision_scope:
+        with device_manager.autocast():
             with model.ema_scope():
-                batch, batch_uc = get_batch(
-                    get_unique_embedder_keys_from_conditioner(model.conditioner),
-                    value_dict,
-                    [num_samples],
-                )
-                c, uc = model.conditioner.get_unconditional_conditioning(
-                    batch,
-                    batch_uc=batch_uc,
-                    force_uc_zero_embeddings=force_uc_zero_embeddings,
-                )
+                with device_manager.use(model.conditioner):
+                    batch, batch_uc = get_batch(
+                        get_unique_embedder_keys_from_conditioner(model.conditioner),
+                        value_dict,
+                        [num_samples],
+                    )
+                    c, uc = model.conditioner.get_unconditional_conditioning(
+                        batch,
+                        batch_uc=batch_uc,
+                        force_uc_zero_embeddings=force_uc_zero_embeddings,
+                    )

                 for k in c:
-                    c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc))
+                    c[k], uc[k] = map(
+                        lambda y: y[k][:num_samples].to(device_manager.device), (c, uc)
+                    )

                 for k in additional_kwargs:
                     c[k] = uc[k] = additional_kwargs[k]

                 if skip_encode:
                     z = img
                 else:
-                    z = model.encode_first_stage(img)
+                    with device_manager.use(model.first_stage_model):
+                        z = model.encode_first_stage(img)
+
                 noise = torch.randn_like(z)
                 sigmas = sampler.discretization(sampler.num_steps)
                 sigma = sigmas[0].to(z.device)
@@ -285,17 +414,24 @@ def do_img2img(
                     noise = noise + offset_noise_level * append_dims(
                         torch.randn(z.shape[0], device=z.device), z.ndim
                     )
-                noised_z = z + noise * append_dims(sigma, z.ndim)
-                noised_z = noised_z / torch.sqrt(
-                    1.0 + sigmas[0] ** 2.0
-                )  # Note: hardcoded to DDPM-like scaling. need to generalize later.
+                if add_noise:
+                    noised_z = z + noise * append_dims(sigma, z.ndim).cuda()
+                    noised_z = noised_z / torch.sqrt(
+                        1.0 + sigmas[0] ** 2.0
+                    )  # Note: hardcoded to DDPM-like scaling. need to generalize later.
+                else:
+                    noised_z = z / torch.sqrt(1.0 + sigmas[0] ** 2.0)

                 def denoiser(x, sigma, c):
                     return model.denoiser(model.model, x, sigma, c)

-                samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
-                samples_x = model.decode_first_stage(samples_z)
-                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
+                with device_manager.use(model.denoiser):
+                    with device_manager.use(model.model):
+                        samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
+
+                with device_manager.use(model.first_stage_model):
+                    samples_x = model.decode_first_stage(samples_z)
+                    samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)

                 if filter is not None:
                     samples = filter(samples)
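Both branches divide by sqrt(1 + sigma_0**2), the DDPM-like input scaling the inline note flags; with `add_noise=False` (the refiner path, which already receives noised latents via `skip_encode=True`) only the rescaling is applied. Schematically, with sigma_0 = sigmas[0] and eps = randn_like(z):

    noised_z = (z + sigma_0 * eps) / sqrt(1 + sigma_0**2)   # add_noise=True
    noised_z =  z                  / sqrt(1 + sigma_0**2)   # add_noise=False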

View File

@@ -26,7 +26,7 @@ class Discretization:

 class EDMDiscretization(Discretization):
-    def __init__(self, sigma_min=0.02, sigma_max=80.0, rho=7.0):
+    def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0):
         self.sigma_min = sigma_min
         self.sigma_max = sigma_max
         self.rho = rho
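This matches the EDM paper default (Karras et al. 2022, sigma_min = 0.002); since the rho-spaced schedule ends exactly at sigma_min, the old value left the final sampling step ten times too noisy. A quick sketch of the standard EDM schedule, written out here only for illustration:

    import torch

    def edm_sigmas(n, sigma_min, sigma_max=80.0, rho=7.0):
        # rho-spaced noise levels: sigma_i = (max^(1/rho) + t_i * (min^(1/rho) - max^(1/rho)))^rho
        ramp = torch.linspace(0, 1, n)
        min_r, max_r = sigma_min ** (1 / rho), sigma_max ** (1 / rho)
        return (max_r + ramp * (min_r - max_r)) ** rho

    print(edm_sigmas(5, 0.02)[-1])   # tensor(0.0200) -- old default
    print(edm_sigmas(5, 0.002)[-1])  # tensor(0.0020) -- corrected default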

View File

@@ -230,6 +230,24 @@ def load_model_from_config(config, ckpt, verbose=True, freeze=True):
     return model

+def get_checkpoints_path() -> str:
+    """
+    Get the `checkpoints` directory.
+    This could be in the root of the repository for a working copy,
+    or in the cwd for other use cases.
+    """
+    this_dir = os.path.dirname(__file__)
+    candidates = (
+        os.path.join(this_dir, "checkpoints"),
+        os.path.join(os.getcwd(), "checkpoints"),
+    )
+    for candidate in candidates:
+        candidate = os.path.abspath(candidate)
+        if os.path.isdir(candidate):
+            return candidate
+    raise FileNotFoundError(f"Could not find SGM checkpoints in {candidates}")
+
+
 def get_configs_path() -> str:
     """
     Get the `configs` directory.
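A hedged usage sketch for the new helper (the import path is inferred from the surrounding file, and the checkpoint filename is illustrative):

    import os
    from sgm.util import get_checkpoints_path  # assumed module location

    ckpt = os.path.join(get_checkpoints_path(), "sd_xl_base_1.0.safetensors")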

View File

@@ -27,7 +27,7 @@ class TestInference:
     @fixture(
         scope="class",
         params=[
-            [ModelArchitecture.SDXL_V1_BASE, ModelArchitecture.SDXL_V1_REFINER],
+            [ModelArchitecture.SDXL_V1_0_BASE, ModelArchitecture.SDXL_V1_0_REFINER],
             [ModelArchitecture.SDXL_V0_9_BASE, ModelArchitecture.SDXL_V0_9_REFINER],
         ],
         ids=["SDXL_V1", "SDXL_V0_9"],
@@ -68,9 +68,7 @@ class TestInference:
         assert output is not None

     @pytest.mark.parametrize("sampler_enum", Sampler)
-    @pytest.mark.parametrize(
-        "use_init_image", [True, False], ids=["img2img", "txt2img"]
-    )
+    @pytest.mark.parametrize("use_init_image", [True, False], ids=["img2img", "txt2img"])
     def test_sdxl_with_refiner(
         self,
         sdxl_pipelines: Tuple[SamplingPipeline, SamplingPipeline],
@@ -81,13 +79,12 @@ class TestInference:
         if use_init_image:
             output = base_pipeline.image_to_image(
                 params=SamplingParams(sampler=sampler_enum.value, steps=10),
-                image=self.create_init_image(
-                    base_pipeline.specs.height, base_pipeline.specs.width
-                ),
+                image=self.create_init_image(base_pipeline.specs.height, base_pipeline.specs.width),
                 prompt="A professional photograph of an astronaut riding a pig",
                 negative_prompt="",
                 samples=1,
                 return_latents=True,
+                noise_strength=0.15,
             )
         else:
             output = base_pipeline.text_to_image(
@@ -96,6 +93,7 @@ class TestInference:
                 negative_prompt="",
                 samples=1,
                 return_latents=True,
+                noise_strength=0.15,
             )

         assert isinstance(output, (tuple, list))
@@ -103,9 +101,9 @@ class TestInference:
         assert samples is not None
         assert samples_z is not None
         refiner_pipeline.refiner(
-            params=SamplingParams(sampler=sampler_enum.value, steps=10),
             image=samples_z,
             prompt="A professional photograph of an astronaut riding a pig",
+            params=SamplingParams(sampler=sampler_enum.value, steps=40, img2img_strength=0.15),
             negative_prompt="",
             samples=1,
         )

View File

@@ -0,0 +1,44 @@
+import pytest
+import torch
+
+from sgm.inference.api import (
+    SamplingPipeline,
+    ModelArchitecture,
+)
+import sgm.inference.helpers as helpers
+
+
+def get_torch_device(model: torch.nn.Module) -> torch.device:
+    param = next(model.parameters(), None)
+    if param is not None:
+        return param.device
+    else:
+        buf = next(model.buffers(), None)
+        if buf is not None:
+            return buf.device
+        else:
+            raise TypeError("Could not determine device of input model")
+
+
+@pytest.mark.inference
+def test_default_loading():
+    pipeline = SamplingPipeline(model_id=ModelArchitecture.SD_2_1)
+    assert get_torch_device(pipeline.model.model).type == "cuda"
+    assert get_torch_device(pipeline.model.conditioner).type == "cuda"
+    with pipeline.device_manager.use(pipeline.model.model):
+        assert get_torch_device(pipeline.model.model).type == "cuda"
+    assert get_torch_device(pipeline.model.model).type == "cuda"
+    with pipeline.device_manager.use(pipeline.model.conditioner):
+        assert get_torch_device(pipeline.model.conditioner).type == "cuda"
+    assert get_torch_device(pipeline.model.conditioner).type == "cuda"
+
+
+@pytest.mark.inference
+def test_model_swapping():
+    pipeline = SamplingPipeline(
+        model_id=ModelArchitecture.SD_2_1,
+        device=helpers.CudaModelManager(device="cuda", swap_device="cpu"),
+    )
+    assert get_torch_device(pipeline.model.model).type == "cpu"
+    assert get_torch_device(pipeline.model.conditioner).type == "cpu"
+    with pipeline.device_manager.use(pipeline.model.model):
+        assert get_torch_device(pipeline.model.model).type == "cuda"
+    assert get_torch_device(pipeline.model.model).type == "cpu"
+    with pipeline.device_manager.use(pipeline.model.conditioner):
+        assert get_torch_device(pipeline.model.conditioner).type == "cuda"
+    assert get_torch_device(pipeline.model.conditioner).type == "cpu"