2 Commits

Author | SHA1 | Message | Date
Stephan Auerhahn | 3167b358c0 | use feature to install pytorch | 2023-08-12 05:59:57 +00:00
Stephan Auerhahn | ab11af1431 | Add devcontainer configs | 2023-08-03 17:44:56 -07:00
14 changed files with 1040 additions and 732 deletions

View File

@@ -0,0 +1,45 @@
{
"name": "Python 3",
"image": "mcr.microsoft.com/devcontainers/python:1-3.10-bullseye",
"features": {
"ghcr.io/devcontainers-contrib/features/hatch:2": {
"version": "latest"
},
"ghcr.io/devcontainers-contrib/features/pipx-package:1": {
"package": "black",
"version": "latest",
"injections": "pylint pytest",
"interpreter": "python3"
},
"ghcr.io/devcontainers-contrib/features/apt-get-packages:1": {
"packages": "libgl1-mesa-glx"
},
"ghcr.io/stability-ai/devcontainer-features/pytorch:1.0.1": {
"version": "2.0.1",
"cudaVersion": "cpu"
}
},
"customizations": {
"vscode": {
"extensions": [
"ms-python.python",
"ms-python.vscode-pylance",
"ms-python.black-formatter"
],
"settings": {
"[python]": {
"diffEditor.ignoreTrimWhitespace": false,
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.formatOnSave": true,
"editor.wordBasedSuggestions": false
},
"python.analysis.typeCheckingMode": "basic",
"black-formatter.args": [
"-l 88"
]
}
}
},
"postCreateCommand": "git config --global --add safe.directory '*'",
"postAttachCommand": "pip install -U -r requirements/pt2.txt;pip install -U -e ."
}
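A quick way to confirm what the pytorch feature installed is a version check inside the container once the postAttachCommand has finished. A minimal sketch, illustrative only:

import torch

# The CPU config pins torch 2.0.1 with cudaVersion "cpu", so no CUDA build is expected.
print(torch.__version__)          # should start with "2.0.1"
print(torch.cuda.is_available())  # expected: False in the CPU-only container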

View File

@@ -0,0 +1,52 @@
{
"name": "Python 3",
"image": "mcr.microsoft.com/devcontainers/python:1-3.8-bullseye",
"features": {
"ghcr.io/devcontainers/features/nvidia-cuda:1": {
"cudaVersion": "11.7"
},
"ghcr.io/devcontainers-contrib/features/hatch:2": {
"version": "latest"
},
"ghcr.io/devcontainers-contrib/features/pipx-package:1": {
"package": "black",
"version": "latest",
"injections": "pylint pytest",
"interpreter": "python3"
},
"ghcr.io/devcontainers-contrib/features/apt-get-packages:1": {
"packages": "libgl1-mesa-glx"
},
"ghcr.io/stability-ai/devcontainer-features/pytorch:1.0.1": {
"version": "1.13.1",
"cudaVersion": "cu117"
}
},
"customizations": {
"vscode": {
"extensions": [
"ms-python.python",
"ms-python.vscode-pylance",
"ms-python.black-formatter"
],
"settings": {
"[python]": {
"diffEditor.ignoreTrimWhitespace": false,
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.formatOnSave": true,
"editor.wordBasedSuggestions": false
},
"python.analysis.typeCheckingMode": "basic",
"black-formatter.args": [
"-l 88"
]
}
}
},
"runArgs": [
"--gpus",
"all"
],
"postCreateCommand": "git config --global --add safe.directory '*'",
"postAttachCommand": "pip install -U -r requirements/pt13.txt;pip install -U -e ."
}

View File

@@ -0,0 +1,52 @@
{
"name": "Python 3",
"image": "mcr.microsoft.com/devcontainers/python:1-3.10-bullseye",
"features": {
"ghcr.io/devcontainers/features/nvidia-cuda:1": {
"cudaVersion": "11.8"
},
"ghcr.io/devcontainers-contrib/features/hatch:2": {
"version": "latest"
},
"ghcr.io/devcontainers-contrib/features/pipx-package:1": {
"package": "black",
"version": "latest",
"injections": "pylint pytest",
"interpreter": "python3"
},
"ghcr.io/devcontainers-contrib/features/apt-get-packages:1": {
"packages": "libgl1-mesa-glx"
},
"ghcr.io/stability-ai/devcontainer-features/pytorch:1.0.1": {
"version": "2.0.1",
"cudaVersion": "cu118"
}
},
"customizations": {
"vscode": {
"extensions": [
"ms-python.python",
"ms-python.vscode-pylance",
"ms-python.black-formatter"
],
"settings": {
"[python]": {
"diffEditor.ignoreTrimWhitespace": false,
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.formatOnSave": true,
"editor.wordBasedSuggestions": false
},
"python.analysis.typeCheckingMode": "basic",
"black-formatter.args": [
"-l 88"
]
}
}
},
"runArgs": [
"--gpus",
"all"
],
"postCreateCommand": "git config --global --add safe.directory '*'",
"postAttachCommand": "pip install -U -r requirements/pt2.txt;pip install -U -e ."
}
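The two GPU configs differ only in the pinned torch/CUDA pair (1.13.1 with cu117, 2.0.1 with cu118). The same sanity check should report a CUDA build, assuming the host honors the --gpus all runArgs and has NVIDIA drivers plus the container toolkit installed (a sketch, not part of the config):

import torch

print(torch.__version__, torch.version.cuda)  # e.g. "2.0.1" with "11.8"
print(torch.cuda.is_available())              # expected: True with GPU passthrough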

View File

@@ -15,7 +15,7 @@ jobs:
steps:
- uses: actions/checkout@v3
- name: "Symlink checkpoints"
run: ln -s $SGM_CHECKPOINTS checkpoints
run: ln -s ${{vars.SGM_CHECKPOINTS_PATH}} checkpoints
- name: "Setup python"
uses: actions/setup-python@v4
with:

View File

@@ -44,5 +44,5 @@ dependencies = [
test-inference = [
"pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 --index-url https://download.pytorch.org/whl/cu118",
"pip install -r requirements/pt2.txt",
"pytest -v tests/inference {args}",
"pytest -v tests/inference/test_inference.py {args}",
]

View File

@@ -1,30 +1,6 @@
import os
import numpy as np
import streamlit as st
import torch
from einops import repeat
from pytorch_lightning import seed_everything
from sgm.inference.api import (
SamplingSpec,
SamplingParams,
ModelArchitecture,
SamplingPipeline,
model_specs,
)
from sgm.inference.helpers import (
get_unique_embedder_keys_from_conditioner,
perform_save_locally,
)
from scripts.demo.streamlit_helpers import (
get_interactive_image,
init_embedder_options,
init_sampling,
init_save_locally,
init_st,
show_samples,
)
from scripts.demo.streamlit_helpers import *
SAVE_PATH = "outputs/demo/txt2img/"
@@ -57,6 +33,63 @@ SD_XL_BASE_RATIOS = {
"3.0": (1728, 576),
}
VERSION2SPECS = {
"SDXL-base-1.0": {
"H": 1024,
"W": 1024,
"C": 4,
"f": 8,
"is_legacy": False,
"config": "configs/inference/sd_xl_base.yaml",
"ckpt": "checkpoints/sd_xl_base_1.0.safetensors",
},
"SDXL-base-0.9": {
"H": 1024,
"W": 1024,
"C": 4,
"f": 8,
"is_legacy": False,
"config": "configs/inference/sd_xl_base.yaml",
"ckpt": "checkpoints/sd_xl_base_0.9.safetensors",
},
"SD-2.1": {
"H": 512,
"W": 512,
"C": 4,
"f": 8,
"is_legacy": True,
"config": "configs/inference/sd_2_1.yaml",
"ckpt": "checkpoints/v2-1_512-ema-pruned.safetensors",
},
"SD-2.1-768": {
"H": 768,
"W": 768,
"C": 4,
"f": 8,
"is_legacy": True,
"config": "configs/inference/sd_2_1_768.yaml",
"ckpt": "checkpoints/v2-1_768-ema-pruned.safetensors",
},
"SDXL-refiner-0.9": {
"H": 1024,
"W": 1024,
"C": 4,
"f": 8,
"is_legacy": True,
"config": "configs/inference/sd_xl_refiner.yaml",
"ckpt": "checkpoints/sd_xl_refiner_0.9.safetensors",
},
"SDXL-refiner-1.0": {
"H": 1024,
"W": 1024,
"C": 4,
"f": 8,
"is_legacy": True,
"config": "configs/inference/sd_xl_refiner.yaml",
"ckpt": "checkpoints/sd_xl_refiner_1.0.safetensors",
},
}
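# Worked example (illustrative, not part of the demo): the entries above fix the
# latent shape used later by do_sample, namely (num_samples, C, H // f, W // f).
_spec = VERSION2SPECS["SDXL-base-1.0"]
assert (1, _spec["C"], _spec["H"] // _spec["f"], _spec["W"] // _spec["f"]) == (
    1,
    4,
    128,
    128,
)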
def load_img(display=True, key=None, device="cuda"):
image = get_interactive_image(key=key)
@@ -78,181 +111,170 @@ def load_img(display=True, key=None, device="cuda"):
def run_txt2img(
state,
model_id: ModelArchitecture,
prompt: str,
negative_prompt: str,
version,
version_dict,
is_legacy=False,
return_latents=False,
filter=None,
stage2strength=None,
):
model: SamplingPipeline = state["model"]
params: SamplingParams = state["params"]
if model_id in sdxl_base_model_list:
width, height = st.selectbox(
"Resolution:", list(SD_XL_BASE_RATIOS.values()), 10
)
if version.startswith("SDXL-base"):
W, H = st.selectbox("Resolution:", list(SD_XL_BASE_RATIOS.values()), 10)
else:
height = int(
st.number_input("H", value=params.height, min_value=64, max_value=2048)
)
width = int(
st.number_input("W", value=params.width, min_value=64, max_value=2048)
)
H = st.number_input("H", value=version_dict["H"], min_value=64, max_value=2048)
W = st.number_input("W", value=version_dict["W"], min_value=64, max_value=2048)
C = version_dict["C"]
F = version_dict["f"]
params = init_embedder_options(
get_unique_embedder_keys_from_conditioner(model.model.conditioner),
params=params,
init_dict = {
"orig_width": W,
"orig_height": H,
"target_width": W,
"target_height": H,
}
value_dict = init_embedder_options(
get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
init_dict,
prompt=prompt,
negative_prompt=negative_prompt,
)
params, num_rows, num_cols = init_sampling(params=params)
sampler, num_rows, num_cols = init_sampling(stage2strength=stage2strength)
num_samples = num_rows * num_cols
params.height = height
params.width = width
if st.button("Sample"):
st.write(f"**Model I:** {version}")
outputs = st.empty()
st.text("Sampling")
out = model.text_to_image(
params=params,
prompt=prompt,
negative_prompt=negative_prompt,
samples=int(num_samples),
out = do_sample(
state["model"],
sampler,
value_dict,
num_samples,
H,
W,
C,
F,
force_uc_zero_embeddings=["txt"] if not is_legacy else [],
return_latents=return_latents,
noise_strength=stage2strength,
filter=state["filter"],
filter=filter,
)
show_samples(out, outputs)
return out
def run_img2img(
state,
prompt: str,
negative_prompt: str,
version_dict,
is_legacy=False,
return_latents=False,
filter=None,
stage2strength=None,
):
model: SamplingPipeline = state["model"]
params: SamplingParams = state["params"]
img = load_img()
if img is None:
return None
params.height, params.width = img.shape[2], img.shape[3]
H, W = img.shape[2], img.shape[3]
params = init_embedder_options(
get_unique_embedder_keys_from_conditioner(model.model.conditioner),
params=params,
init_dict = {
"orig_width": W,
"orig_height": H,
"target_width": W,
"target_height": H,
}
value_dict = init_embedder_options(
get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
init_dict,
prompt=prompt,
negative_prompt=negative_prompt,
)
params.img2img_strength = st.number_input(
strength = st.number_input(
"**Img2Img Strength**", value=0.75, min_value=0.0, max_value=1.0
)
params, num_rows, num_cols = init_sampling(params=params)
sampler, num_rows, num_cols = init_sampling(
img2img_strength=strength,
stage2strength=stage2strength,
)
num_samples = num_rows * num_cols
if st.button("Sample"):
outputs = st.empty()
st.text("Sampling")
out = model.image_to_image(
image=repeat(img, "1 ... -> n ...", n=num_samples),
params=params,
prompt=prompt,
negative_prompt=negative_prompt,
samples=int(num_samples),
out = do_img2img(
repeat(img, "1 ... -> n ...", n=num_samples),
state["model"],
sampler,
value_dict,
num_samples,
force_uc_zero_embeddings=["txt"] if not is_legacy else [],
return_latents=return_latents,
noise_strength=stage2strength,
filter=state["filter"],
filter=filter,
)
show_samples(out, outputs)
return out
def apply_refiner(
input,
state,
num_samples: int,
prompt: str,
negative_prompt: str,
sampler,
num_samples,
prompt,
negative_prompt,
filter=None,
finish_denoising=False,
):
model: SamplingPipeline = state["model"]
params: SamplingParams = state["params"]
init_dict = {
"orig_width": input.shape[3] * 8,
"orig_height": input.shape[2] * 8,
"target_width": input.shape[3] * 8,
"target_height": input.shape[2] * 8,
}
params.orig_width = input.shape[3] * 8
params.orig_height = input.shape[2] * 8
params.width = input.shape[3] * 8
params.height = input.shape[2] * 8
value_dict = init_dict
value_dict["prompt"] = prompt
value_dict["negative_prompt"] = negative_prompt
value_dict["crop_coords_top"] = 0
value_dict["crop_coords_left"] = 0
value_dict["aesthetic_score"] = 6.0
value_dict["negative_aesthetic_score"] = 2.5
st.warning(f"refiner input shape: {input.shape}")
samples = model.refiner(
image=input,
params=params,
prompt=prompt,
negative_prompt=negative_prompt,
samples=num_samples,
return_latents=False,
filter=state["filter"],
samples = do_img2img(
input,
state["model"],
sampler,
value_dict,
num_samples,
skip_encode=True,
filter=filter,
add_noise=not finish_denoising,
)
return samples
sdxl_base_model_list = [
ModelArchitecture.SDXL_V1_0_BASE,
ModelArchitecture.SDXL_V0_9_BASE,
]
sdxl_refiner_model_list = [
ModelArchitecture.SDXL_V1_0_REFINER,
ModelArchitecture.SDXL_V0_9_REFINER,
]
if __name__ == "__main__":
st.title("Stable Diffusion")
version = st.selectbox(
"Model Version",
[member.value for member in ModelArchitecture],
0,
)
version_enum = ModelArchitecture(version)
specs = model_specs[version_enum]
version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
version_dict = VERSION2SPECS[version]
mode = st.radio("Mode", ("txt2img", "img2img"), 0)
st.write("__________________________")
st.write("**Performance Options:**")
use_fp16 = st.checkbox("Use fp16 (Saves VRAM)", True)
enable_swap = st.checkbox("Swap models to CPU (Saves VRAM, uses RAM)", True)
st.write("__________________________")
set_lowvram_mode(st.checkbox("Low vram mode", True))
if version_enum in sdxl_base_model_list:
if version.startswith("SDXL-base"):
add_pipeline = st.checkbox("Load SDXL-refiner?", False)
st.write("__________________________")
else:
add_pipeline = False
seed = int(
st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
)
seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
seed_everything(seed)
save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, str(version)))
state = init_st(
model_specs[version_enum],
load_filter=True,
use_fp16=use_fp16,
enable_swap=enable_swap,
)
save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))
state = init_st(version_dict, load_filter=True)
if state["msg"]:
st.info(state["msg"])
model = state["model"]
is_legacy = specs.is_legacy
is_legacy = version_dict["is_legacy"]
prompt = st.text_input(
"prompt",
@@ -268,58 +290,47 @@ if __name__ == "__main__":
if add_pipeline:
st.write("__________________________")
version2 = ModelArchitecture(
st.selectbox(
"Refiner:",
[member.value for member in sdxl_refiner_model_list],
)
)
version2 = st.selectbox("Refiner:", ["SDXL-refiner-1.0", "SDXL-refiner-0.9"])
st.warning(
f"Running with {version2} as the second stage model. Make sure to provide (V)RAM :) "
)
st.write("**Refiner Options:**")
specs2 = model_specs[version2]
state2 = init_st(
specs2, load_filter=False, use_fp16=use_fp16, enable_swap=enable_swap
)
params2 = state2["params"]
version_dict2 = VERSION2SPECS[version2]
state2 = init_st(version_dict2, load_filter=False)
st.info(state2["msg"])
params2.img2img_strength = st.number_input(
stage2strength = st.number_input(
"**Refinement strength**", value=0.15, min_value=0.0, max_value=1.0
)
params2, *_ = init_sampling(
params=state2["params"],
sampler2, *_ = init_sampling(
key=2,
img2img_strength=stage2strength,
specify_num_samples=False,
)
st.write("__________________________")
finish_denoising = st.checkbox("Finish denoising with refiner.", True)
if finish_denoising:
stage2strength = params2.img2img_strength
else:
if not finish_denoising:
stage2strength = None
else:
state2 = None
params2 = None
stage2strength = None
if mode == "txt2img":
out = run_txt2img(
state=state,
model_id=version_enum,
prompt=prompt,
negative_prompt=negative_prompt,
state,
version,
version_dict,
is_legacy=is_legacy,
return_latents=add_pipeline,
filter=state.get("filter"),
stage2strength=stage2strength,
)
elif mode == "img2img":
out = run_img2img(
state=state,
prompt=prompt,
negative_prompt=negative_prompt,
state,
version_dict,
is_legacy=is_legacy,
return_latents=add_pipeline,
filter=state.get("filter"),
stage2strength=stage2strength,
)
else:
@@ -331,17 +342,17 @@ if __name__ == "__main__":
samples_z = None
if add_pipeline and samples_z is not None:
outputs = st.empty()
st.write("**Running Refinement Stage**")
samples = apply_refiner(
input=samples_z,
state=state2,
num_samples=samples_z.shape[0],
samples_z,
state2,
sampler2,
samples_z.shape[0],
prompt=prompt,
negative_prompt=negative_prompt if is_legacy else "",
filter=state.get("filter"),
finish_denoising=finish_denoising,
)
show_samples(samples, outputs)
if save_locally and samples is not None:
perform_save_locally(save_path, samples)

View File

@@ -1,68 +1,166 @@
import math
import os
from typing import List, Union
import numpy as np
import streamlit as st
import torch
from einops import rearrange, repeat
from imwatermark import WatermarkEncoder
from omegaconf import ListConfig, OmegaConf
from PIL import Image
from safetensors.torch import load_file as load_safetensors
from torch import autocast
from torchvision import transforms
from typing import Optional, Tuple, Dict, Any
from torchvision.utils import make_grid
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
from sgm.inference.api import (
Discretization,
Guider,
Sampler,
SamplingParams,
SamplingSpec,
SamplingPipeline,
Thresholder,
from sgm.modules.diffusionmodules.sampling import (
DPMPP2MSampler,
DPMPP2SAncestralSampler,
EulerAncestralSampler,
EulerEDMSampler,
HeunEDMSampler,
LinearMultistepSampler,
)
from sgm.inference.helpers import embed_watermark, CudaModelManager
from sgm.util import append_dims, instantiate_from_config
class WatermarkEmbedder:
def __init__(self, watermark):
self.watermark = watermark
self.num_bits = len(WATERMARK_BITS)
self.encoder = WatermarkEncoder()
self.encoder.set_watermark("bits", self.watermark)
def __call__(self, image: torch.Tensor):
"""
Adds a predefined watermark to the input image
Args:
image: ([N,] B, C, H, W) in range [0, 1]
Returns:
same as input but watermarked
"""
# watermarking library expects input as cv2 BGR format
squeeze = len(image.shape) == 4
if squeeze:
image = image[None, ...]
n = image.shape[0]
image_np = rearrange(
(255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
).numpy()[:, :, :, ::-1]
# torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
for k in range(image_np.shape[0]):
image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
image = torch.from_numpy(
rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)
).to(image.device)
image = torch.clamp(image / 255, min=0.0, max=1.0)
if squeeze:
image = image[0]
return image
# A fixed 48-bit message that was chosen at random
# WATERMARK_MESSAGE = 0xB3EC907BB19E
WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
embed_watemark = WatermarkEmbedder(WATERMARK_BITS)
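# Minimal usage sketch (illustrative): the embedder takes a [0, 1] image tensor
# of shape ([N,] B, C, H, W) and returns a watermarked tensor of the same shape.
def _watermark_example():
    img = torch.rand(1, 3, 256, 256)  # (B, C, H, W) in [0, 1]
    marked = embed_watemark(img)      # same shape, clamped back to [0, 1]
    return marked.shape == img.shape  # True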
@st.cache_resource()
def init_st(
spec: SamplingSpec,
load_ckpt=True,
load_filter=True,
use_fp16=True,
enable_swap=True,
) -> Dict[str, Any]:
state: Dict[str, Any] = dict()
config = spec.config
ckpt = spec.ckpt
def init_st(version_dict, load_ckpt=True, load_filter=True):
state = dict()
if not "model" in state:
config = version_dict["config"]
ckpt = version_dict["ckpt"]
if enable_swap:
pipeline = SamplingPipeline(
model_spec=spec,
use_fp16=use_fp16,
device=CudaModelManager(device="cuda", swap_device="cpu"),
)
else:
pipeline = SamplingPipeline(model_spec=spec, use_fp16=use_fp16)
config = OmegaConf.load(config)
model, msg = load_model_from_config(config, ckpt if load_ckpt else None)
state["spec"] = spec
state["model"] = pipeline
state["ckpt"] = ckpt if load_ckpt else None
state["config"] = config
state["params"] = spec.default_params
if load_filter:
state["filter"] = DeepFloydDataFiltering(verbose=False)
else:
state["filter"] = None
state["msg"] = msg
state["model"] = model
state["ckpt"] = ckpt if load_ckpt else None
state["config"] = config
if load_filter:
state["filter"] = DeepFloydDataFiltering(verbose=False)
return state
def load_model(model):
model.cuda()
lowvram_mode = False
def set_lowvram_mode(mode):
global lowvram_mode
lowvram_mode = mode
def initial_model_load(model):
global lowvram_mode
if lowvram_mode:
model.model.half()
else:
model.cuda()
return model
def unload_model(model):
global lowvram_mode
if lowvram_mode:
model.cpu()
torch.cuda.empty_cache()
def load_model_from_config(config, ckpt=None, verbose=True):
model = instantiate_from_config(config.model)
if ckpt is not None:
print(f"Loading model from {ckpt}")
if ckpt.endswith("ckpt"):
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
global_step = pl_sd["global_step"]
st.info(f"loaded ckpt from global step {global_step}")
print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
elif ckpt.endswith("safetensors"):
sd = load_safetensors(ckpt)
else:
raise NotImplementedError
msg = None
m, u = model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print("unexpected keys:")
print(u)
else:
msg = None
model = initial_model_load(model)
model.eval()
return model, msg
def get_unique_embedder_keys_from_conditioner(conditioner):
return list(set([x.input_key for x in conditioner.embedders]))
def init_embedder_options(
keys, params: SamplingParams, prompt=None, negative_prompt=None
) -> SamplingParams:
def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
# Hardcoded demo settings; might undergo some changes in the future
value_dict = {}
for key in keys:
if key == "txt":
if prompt is None:
@@ -72,38 +170,46 @@ def init_embedder_options(
if negative_prompt is None:
negative_prompt = st.text_input("Negative prompt", "")
value_dict["prompt"] = prompt
value_dict["negative_prompt"] = negative_prompt
if key == "original_size_as_tuple":
orig_width = st.number_input(
"orig_width",
value=params.orig_width,
value=init_dict["orig_width"],
min_value=16,
)
orig_height = st.number_input(
"orig_height",
value=params.orig_height,
value=init_dict["orig_height"],
min_value=16,
)
params.orig_width = int(orig_width)
params.orig_height = int(orig_height)
value_dict["orig_width"] = orig_width
value_dict["orig_height"] = orig_height
if key == "crop_coords_top_left":
crop_coord_top = st.number_input(
"crop_coords_top", value=params.crop_coords_top, min_value=0
)
crop_coord_left = st.number_input(
"crop_coords_left", value=params.crop_coords_left, min_value=0
)
crop_coord_top = st.number_input("crop_coords_top", value=0, min_value=0)
crop_coord_left = st.number_input("crop_coords_left", value=0, min_value=0)
params.crop_coords_top = int(crop_coord_top)
params.crop_coords_left = int(crop_coord_left)
return params
value_dict["crop_coords_top"] = crop_coord_top
value_dict["crop_coords_left"] = crop_coord_left
if key == "aesthetic_score":
value_dict["aesthetic_score"] = 6.0
value_dict["negative_aesthetic_score"] = 2.5
if key == "target_size_as_tuple":
value_dict["target_width"] = init_dict["target_width"]
value_dict["target_height"] = init_dict["target_height"]
return value_dict
def perform_save_locally(save_path, samples):
os.makedirs(os.path.join(save_path), exist_ok=True)
base_count = len(os.listdir(os.path.join(save_path)))
samples = embed_watermark(samples)
samples = embed_watemark(samples)
for sample in samples:
sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
Image.fromarray(sample.astype(np.uint8)).save(
@@ -122,26 +228,78 @@ def init_save_locally(_dir, init_value: bool = False):
return save_locally, save_path
def show_samples(samples, outputs):
if isinstance(samples, tuple):
samples, _ = samples
grid = embed_watermark(torch.stack([samples]))
grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
outputs.image(grid.cpu().numpy())
class Img2ImgDiscretizationWrapper:
"""
wraps a discretizer, and prunes the sigmas
params:
strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
"""
def __init__(self, discretization, strength: float = 1.0):
self.discretization = discretization
self.strength = strength
assert 0.0 <= self.strength <= 1.0
def __call__(self, *args, **kwargs):
# sigmas start large first, and decrease then
sigmas = self.discretization(*args, **kwargs)
print(f"sigmas after discretization, before pruning img2img: ", sigmas)
sigmas = torch.flip(sigmas, (0,))
sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
print("prune index:", max(int(self.strength * len(sigmas)), 1))
sigmas = torch.flip(sigmas, (0,))
print(f"sigmas after pruning: ", sigmas)
return sigmas
def get_guider(params: SamplingParams, key=1) -> SamplingParams:
params.guider = Guider(
st.sidebar.selectbox(
f"Discretization #{key}", [member.value for member in Guider]
)
class Txt2NoisyDiscretizationWrapper:
"""
wraps a discretizer, and prunes the sigmas
params:
strength: float between 0.0 and 1.0. 0.0 means full sampling (all sigmas are returned)
"""
def __init__(self, discretization, strength: float = 0.0, original_steps=None):
self.discretization = discretization
self.strength = strength
self.original_steps = original_steps
assert 0.0 <= self.strength <= 1.0
def __call__(self, *args, **kwargs):
# sigmas start large first, and decrease then
sigmas = self.discretization(*args, **kwargs)
print(f"sigmas after discretization, before pruning img2img: ", sigmas)
sigmas = torch.flip(sigmas, (0,))
if self.original_steps is None:
steps = len(sigmas)
else:
steps = self.original_steps + 1
prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0)
sigmas = sigmas[prune_index:]
print("prune index:", prune_index)
sigmas = torch.flip(sigmas, (0,))
print(f"sigmas after pruning: ", sigmas)
return sigmas
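# Worked example of the pruning above (illustrative): with 10 descending sigmas
# and strength=0.7, Img2ImgDiscretizationWrapper keeps the 7 smallest sigmas, so
# img2img starts from mid-level noise instead of pure noise.
def _pruning_example():
    def fake_discretization(*args, **kwargs):
        return torch.linspace(10.0, 1.0, 10)  # descending, like real schedules

    wrapped = Img2ImgDiscretizationWrapper(fake_discretization, strength=0.7)
    return len(wrapped())  # 7 == max(int(0.7 * 10), 1)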
def get_guider(key):
guider = st.sidebar.selectbox(
f"Discretization #{key}",
[
"VanillaCFG",
"IdentityGuider",
],
)
if params.guider == Guider.VANILLA:
if guider == "IdentityGuider":
guider_config = {
"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
}
elif guider == "VanillaCFG":
scale = st.number_input(
f"cfg-scale #{key}", value=params.scale, min_value=0.0, max_value=100.0
f"cfg-scale #{key}", value=5.0, min_value=0.0, max_value=100.0
)
params.scale = scale
thresholder = st.sidebar.selectbox(
f"Thresholder #{key}",
[
@@ -150,97 +308,182 @@ def get_guider(params: SamplingParams, key=1) -> SamplingParams:
)
if thresholder == "None":
params.thresholder = Thresholder.NONE
dyn_thresh_config = {
"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
}
else:
raise NotImplementedError
return params
guider_config = {
"target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
"params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
}
else:
raise NotImplementedError
return guider_config
def init_sampling(
params: SamplingParams,
key=1,
img2img_strength=1.0,
specify_num_samples=True,
) -> Tuple[SamplingParams, int, int]:
stage2strength=None,
):
num_rows, num_cols = 1, 1
if specify_num_samples:
num_cols = st.number_input(
f"num cols #{key}", value=2, min_value=1, max_value=10
)
params.steps = int(
st.sidebar.number_input(
f"steps #{key}", value=params.steps, min_value=1, max_value=1000
)
steps = st.sidebar.number_input(
f"steps #{key}", value=40, min_value=1, max_value=1000
)
sampler = st.sidebar.selectbox(
f"Sampler #{key}",
[
"EulerEDMSampler",
"HeunEDMSampler",
"EulerAncestralSampler",
"DPMPP2SAncestralSampler",
"DPMPP2MSampler",
"LinearMultistepSampler",
],
0,
)
discretization = st.sidebar.selectbox(
f"Discretization #{key}",
[
"LegacyDDPMDiscretization",
"EDMDiscretization",
],
)
params.sampler = Sampler(
st.sidebar.selectbox(
f"Sampler #{key}",
[member.value for member in Sampler],
0,
discretization_config = get_discretization(discretization, key=key)
guider_config = get_guider(key=key)
sampler = get_sampler(sampler, steps, discretization_config, guider_config, key=key)
if img2img_strength < 1.0:
st.warning(
f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper"
)
)
params.discretization = Discretization(
st.sidebar.selectbox(
f"Discretization #{key}",
[member.value for member in Discretization],
sampler.discretization = Img2ImgDiscretizationWrapper(
sampler.discretization, strength=img2img_strength
)
)
params = get_discretization(params=params, key=key)
params = get_guider(params=params, key=key)
params = get_sampler(params=params, key=key)
return params, num_rows, num_cols
def get_discretization(params: SamplingParams, key=1) -> SamplingParams:
if params.discretization == Discretization.EDM:
params.sigma_min = st.number_input(f"sigma_min #{key}", value=params.sigma_min)
params.sigma_max = st.number_input(f"sigma_max #{key}", value=params.sigma_max)
params.rho = st.number_input(f"rho #{key}", value=params.rho)
return params
def get_sampler(params: SamplingParams, key=1) -> SamplingParams:
if params.sampler in (Sampler.EULER_EDM, Sampler.HEUN_EDM):
params.s_churn = st.sidebar.number_input(
f"s_churn #{key}", value=params.s_churn, min_value=0.0
if stage2strength is not None:
sampler.discretization = Txt2NoisyDiscretizationWrapper(
sampler.discretization, strength=stage2strength, original_steps=steps
)
params.s_tmin = st.sidebar.number_input(
f"s_tmin #{key}", value=params.s_tmin, min_value=0.0
)
params.s_tmax = st.sidebar.number_input(
f"s_tmax #{key}", value=params.s_tmax, min_value=0.0
)
params.s_noise = st.sidebar.number_input(
f"s_noise #{key}", value=params.s_noise, min_value=0.0
)
elif params.sampler in (Sampler.EULER_ANCESTRAL, Sampler.DPMPP2S_ANCESTRAL):
params.s_noise = st.sidebar.number_input(
"s_noise", value=params.s_noise, min_value=0.0
)
params.eta = st.sidebar.number_input("eta", value=params.eta, min_value=0.0)
elif params.sampler == Sampler.LINEAR_MULTISTEP:
params.order = int(
st.sidebar.number_input("order", value=params.order, min_value=1)
)
return params
return sampler, num_rows, num_cols
def get_interactive_image(key=None) -> Optional[Image.Image]:
def get_discretization(discretization, key=1):
if discretization == "LegacyDDPMDiscretization":
discretization_config = {
"target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
}
elif discretization == "EDMDiscretization":
sigma_min = st.number_input(f"sigma_min #{key}", value=0.03) # 0.0292
sigma_max = st.number_input(f"sigma_max #{key}", value=14.61) # 14.6146
rho = st.number_input(f"rho #{key}", value=3.0)
discretization_config = {
"target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
"params": {
"sigma_min": sigma_min,
"sigma_max": sigma_max,
"rho": rho,
},
}
return discretization_config
def get_sampler(sampler_name, steps, discretization_config, guider_config, key=1):
if sampler_name == "EulerEDMSampler" or sampler_name == "HeunEDMSampler":
s_churn = st.sidebar.number_input(f"s_churn #{key}", value=0.0, min_value=0.0)
s_tmin = st.sidebar.number_input(f"s_tmin #{key}", value=0.0, min_value=0.0)
s_tmax = st.sidebar.number_input(f"s_tmax #{key}", value=999.0, min_value=0.0)
s_noise = st.sidebar.number_input(f"s_noise #{key}", value=1.0, min_value=0.0)
if sampler_name == "EulerEDMSampler":
sampler = EulerEDMSampler(
num_steps=steps,
discretization_config=discretization_config,
guider_config=guider_config,
s_churn=s_churn,
s_tmin=s_tmin,
s_tmax=s_tmax,
s_noise=s_noise,
verbose=True,
)
elif sampler_name == "HeunEDMSampler":
sampler = HeunEDMSampler(
num_steps=steps,
discretization_config=discretization_config,
guider_config=guider_config,
s_churn=s_churn,
s_tmin=s_tmin,
s_tmax=s_tmax,
s_noise=s_noise,
verbose=True,
)
elif (
sampler_name == "EulerAncestralSampler"
or sampler_name == "DPMPP2SAncestralSampler"
):
s_noise = st.sidebar.number_input("s_noise", value=1.0, min_value=0.0)
eta = st.sidebar.number_input("eta", value=1.0, min_value=0.0)
if sampler_name == "EulerAncestralSampler":
sampler = EulerAncestralSampler(
num_steps=steps,
discretization_config=discretization_config,
guider_config=guider_config,
eta=eta,
s_noise=s_noise,
verbose=True,
)
elif sampler_name == "DPMPP2SAncestralSampler":
sampler = DPMPP2SAncestralSampler(
num_steps=steps,
discretization_config=discretization_config,
guider_config=guider_config,
eta=eta,
s_noise=s_noise,
verbose=True,
)
elif sampler_name == "DPMPP2MSampler":
sampler = DPMPP2MSampler(
num_steps=steps,
discretization_config=discretization_config,
guider_config=guider_config,
verbose=True,
)
elif sampler_name == "LinearMultistepSampler":
order = st.sidebar.number_input("order", value=4, min_value=1)
sampler = LinearMultistepSampler(
num_steps=steps,
discretization_config=discretization_config,
guider_config=guider_config,
order=order,
verbose=True,
)
else:
raise ValueError(f"unknown sampler {sampler_name}!")
return sampler
def get_interactive_image(key=None) -> Image.Image:
image = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
if image is not None:
image = Image.open(image)
if not image.mode == "RGB":
image = image.convert("RGB")
return image
return None
def load_img(display=True, key=None) -> Optional[torch.Tensor]:
def load_img(display=True, key=None):
image = get_interactive_image(key=key)
if image is None:
return None
@@ -264,3 +507,238 @@ def get_init_img(batch_size=1, key=None):
init_image = load_img(key=key).cuda()
init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
return init_image
def do_sample(
model,
sampler,
value_dict,
num_samples,
H,
W,
C,
F,
force_uc_zero_embeddings: List = None,
batch2model_input: List = None,
return_latents=False,
filter=None,
):
if force_uc_zero_embeddings is None:
force_uc_zero_embeddings = []
if batch2model_input is None:
batch2model_input = []
st.text("Sampling")
outputs = st.empty()
precision_scope = autocast
with torch.no_grad():
with precision_scope("cuda"):
with model.ema_scope():
num_samples = [num_samples]
load_model(model.conditioner)
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner),
value_dict,
num_samples,
)
for key in batch:
if isinstance(batch[key], torch.Tensor):
print(key, batch[key].shape)
elif isinstance(batch[key], list):
print(key, [len(l) for l in batch[key]])
else:
print(key, batch[key])
c, uc = model.conditioner.get_unconditional_conditioning(
batch,
batch_uc=batch_uc,
force_uc_zero_embeddings=force_uc_zero_embeddings,
)
unload_model(model.conditioner)
for k in c:
if not k == "crossattn":
c[k], uc[k] = map(
lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc)
)
additional_model_inputs = {}
for k in batch2model_input:
additional_model_inputs[k] = batch[k]
shape = (math.prod(num_samples), C, H // F, W // F)
randn = torch.randn(shape).to("cuda")
def denoiser(input, sigma, c):
return model.denoiser(
model.model, input, sigma, c, **additional_model_inputs
)
load_model(model.denoiser)
load_model(model.model)
samples_z = sampler(denoiser, randn, cond=c, uc=uc)
unload_model(model.model)
unload_model(model.denoiser)
load_model(model.first_stage_model)
samples_x = model.decode_first_stage(samples_z)
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
unload_model(model.first_stage_model)
if filter is not None:
samples = filter(samples)
grid = torch.stack([samples])
grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
outputs.image(grid.cpu().numpy())
if return_latents:
return samples, samples_z
return samples
def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
# Hardcoded demo setups; might undergo some changes in the future
batch = {}
batch_uc = {}
for key in keys:
if key == "txt":
batch["txt"] = (
np.repeat([value_dict["prompt"]], repeats=math.prod(N))
.reshape(N)
.tolist()
)
batch_uc["txt"] = (
np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
.reshape(N)
.tolist()
)
elif key == "original_size_as_tuple":
batch["original_size_as_tuple"] = (
torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
.to(device)
.repeat(*N, 1)
)
elif key == "crop_coords_top_left":
batch["crop_coords_top_left"] = (
torch.tensor(
[value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
)
.to(device)
.repeat(*N, 1)
)
elif key == "aesthetic_score":
batch["aesthetic_score"] = (
torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
)
batch_uc["aesthetic_score"] = (
torch.tensor([value_dict["negative_aesthetic_score"]])
.to(device)
.repeat(*N, 1)
)
elif key == "target_size_as_tuple":
batch["target_size_as_tuple"] = (
torch.tensor([value_dict["target_height"], value_dict["target_width"]])
.to(device)
.repeat(*N, 1)
)
else:
batch[key] = value_dict[key]
for key in batch.keys():
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
batch_uc[key] = torch.clone(batch[key])
return batch, batch_uc
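# Illustrative call with made-up values: for a text-only conditioner, get_batch
# repeats the prompt and negative prompt num_samples times.
def _get_batch_example():
    batch, batch_uc = get_batch(
        ["txt"],
        {"prompt": "an astronaut", "negative_prompt": ""},
        [2],
    )
    return batch["txt"], batch_uc["txt"]  # (["an astronaut"] * 2, [""] * 2)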
@torch.no_grad()
def do_img2img(
img,
model,
sampler,
value_dict,
num_samples,
force_uc_zero_embeddings=[],
additional_kwargs={},
offset_noise_level: int = 0.0,
return_latents=False,
skip_encode=False,
filter=None,
add_noise=True,
):
st.text("Sampling")
outputs = st.empty()
precision_scope = autocast
with torch.no_grad():
with precision_scope("cuda"):
with model.ema_scope():
load_model(model.conditioner)
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner),
value_dict,
[num_samples],
)
c, uc = model.conditioner.get_unconditional_conditioning(
batch,
batch_uc=batch_uc,
force_uc_zero_embeddings=force_uc_zero_embeddings,
)
unload_model(model.conditioner)
for k in c:
c[k], uc[k] = map(lambda y: y[k][:num_samples].to("cuda"), (c, uc))
for k in additional_kwargs:
c[k] = uc[k] = additional_kwargs[k]
if skip_encode:
z = img
else:
load_model(model.first_stage_model)
z = model.encode_first_stage(img)
unload_model(model.first_stage_model)
noise = torch.randn_like(z)
sigmas = sampler.discretization(sampler.num_steps).cuda()
sigma = sigmas[0]
st.info(f"all sigmas: {sigmas}")
st.info(f"noising sigma: {sigma}")
if offset_noise_level > 0.0:
noise = noise + offset_noise_level * append_dims(
torch.randn(z.shape[0], device=z.device), z.ndim
)
if add_noise:
noised_z = z + noise * append_dims(sigma, z.ndim).cuda()
noised_z = noised_z / torch.sqrt(
1.0 + sigmas[0] ** 2.0
) # Note: hardcoded to DDPM-like scaling. need to generalize later.
else:
noised_z = z / torch.sqrt(1.0 + sigmas[0] ** 2.0)
def denoiser(x, sigma, c):
return model.denoiser(model.model, x, sigma, c)
load_model(model.denoiser)
load_model(model.model)
samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
unload_model(model.model)
unload_model(model.denoiser)
load_model(model.first_stage_model)
samples_x = model.decode_first_stage(samples_z)
unload_model(model.first_stage_model)
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
if filter is not None:
samples = filter(samples)
grid = embed_watemark(torch.stack([samples]))
grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
outputs.image(grid.cpu().numpy())
if return_latents:
return samples, samples_z
return samples
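The noising step in do_img2img follows DDPM-like scaling, as the inline note says: the latent is mixed with Gaussian noise at the first sigma and rescaled by 1/sqrt(1 + sigma^2) so the result stays roughly unit-variance. A standalone sketch of that arithmetic, with a made-up starting sigma:

import torch

sigma = 14.61                    # a typical first sigma for an SDXL schedule
z = torch.randn(1, 4, 128, 128)  # stand-in latent, roughly unit variance
noise = torch.randn_like(z)
noised_z = (z + noise * sigma) / torch.sqrt(torch.tensor(1.0 + sigma**2))
print(noised_z.std())            # close to 1.0, so variance is preserved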

View File

@@ -1,4 +1,4 @@
from .models import AutoencodingEngine, DiffusionEngine
from .util import get_configs_path, instantiate_from_config
__version__ = "0.1.1"
__version__ = "0.1.0"

View File

@@ -1,14 +1,11 @@
from dataclasses import dataclass, asdict
from enum import Enum
from omegaconf import OmegaConf
import os
import pathlib
from sgm.inference.helpers import (
do_sample,
do_img2img,
DeviceModelManager,
get_model_manager,
Img2ImgDiscretizationWrapper,
Txt2NoisyDiscretizationWrapper,
)
from sgm.modules.diffusionmodules.sampling import (
EulerEDMSampler,
@@ -18,18 +15,17 @@ from sgm.modules.diffusionmodules.sampling import (
DPMPP2MSampler,
LinearMultistepSampler,
)
from sgm.util import load_model_from_config, get_configs_path, get_checkpoints_path
import torch
from typing import Optional, Dict, Any, Union
from sgm.util import load_model_from_config
from typing import Optional
class ModelArchitecture(str, Enum):
SDXL_V1_0_BASE = "stable-diffusion-xl-v1-base"
SDXL_V1_0_REFINER = "stable-diffusion-xl-v1-refiner"
SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base"
SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner"
SD_2_1 = "stable-diffusion-v2-1"
SD_2_1_768 = "stable-diffusion-v2-1-768"
SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base"
SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner"
SDXL_V1_BASE = "stable-diffusion-xl-v1-base"
SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner"
class Sampler(str, Enum):
@@ -57,20 +53,16 @@ class Thresholder(str, Enum):
@dataclass
class SamplingParams:
"""
Parameters for sampling.
"""
width: Optional[int] = None
height: Optional[int] = None
steps: Optional[int] = None
sampler: Sampler = Sampler.EULER_EDM
width: int = 1024
height: int = 1024
steps: int = 50
sampler: Sampler = Sampler.DPMPP2M
discretization: Discretization = Discretization.LEGACY_DDPM
guider: Guider = Guider.VANILLA
thresholder: Thresholder = Thresholder.NONE
scale: float = 5.0
aesthetic_score: float = 6.0
negative_aesthetic_score: float = 2.5
scale: float = 6.0
aesthetic_score: float = 5.0
negative_aesthetic_score: float = 5.0
img2img_strength: float = 1.0
orig_width: int = 1024
orig_height: int = 1024
@@ -97,10 +89,8 @@ class SamplingSpec:
config: str
ckpt: str
is_guided: bool
default_params: SamplingParams
# The defaults here are derived from user preference testing.
model_specs = {
ModelArchitecture.SD_2_1: SamplingSpec(
height=512,
@@ -111,12 +101,6 @@ model_specs = {
config="sd_2_1.yaml",
ckpt="v2-1_512-ema-pruned.safetensors",
is_guided=True,
default_params=SamplingParams(
width=512,
height=512,
steps=40,
scale=7.0,
),
),
ModelArchitecture.SD_2_1_768: SamplingSpec(
height=768,
@@ -127,12 +111,6 @@ model_specs = {
config="sd_2_1_768.yaml",
ckpt="v2-1_768-ema-pruned.safetensors",
is_guided=True,
default_params=SamplingParams(
width=768,
height=768,
steps=40,
scale=7.0,
),
),
ModelArchitecture.SDXL_V0_9_BASE: SamplingSpec(
height=1024,
@@ -143,7 +121,6 @@ model_specs = {
config="sd_xl_base.yaml",
ckpt="sd_xl_base_0.9.safetensors",
is_guided=True,
default_params=SamplingParams(width=1024, height=1024, steps=40, scale=5.0),
),
ModelArchitecture.SDXL_V0_9_REFINER: SamplingSpec(
height=1024,
@@ -154,11 +131,8 @@ model_specs = {
config="sd_xl_refiner.yaml",
ckpt="sd_xl_refiner_0.9.safetensors",
is_guided=True,
default_params=SamplingParams(
width=1024, height=1024, steps=40, scale=5.0, img2img_strength=0.15
),
),
ModelArchitecture.SDXL_V1_0_BASE: SamplingSpec(
ModelArchitecture.SDXL_V1_BASE: SamplingSpec(
height=1024,
width=1024,
channels=4,
@@ -167,9 +141,8 @@ model_specs = {
config="sd_xl_base.yaml",
ckpt="sd_xl_base_1.0.safetensors",
is_guided=True,
default_params=SamplingParams(width=1024, height=1024, steps=40, scale=5.0),
),
ModelArchitecture.SDXL_V1_0_REFINER: SamplingSpec(
ModelArchitecture.SDXL_V1_REFINER: SamplingSpec(
height=1024,
width=1024,
channels=4,
@@ -178,97 +151,34 @@ model_specs = {
config="sd_xl_refiner.yaml",
ckpt="sd_xl_refiner_1.0.safetensors",
is_guided=True,
default_params=SamplingParams(
width=1024, height=1024, steps=40, scale=5.0, img2img_strength=0.15
),
),
}
def wrap_discretization(
discretization, image_strength=None, noise_strength=None, steps=None
):
if isinstance(discretization, Img2ImgDiscretizationWrapper) or isinstance(
discretization, Txt2NoisyDiscretizationWrapper
):
return discretization # Already wrapped
if image_strength is not None and image_strength < 1.0 and image_strength > 0.0:
discretization = Img2ImgDiscretizationWrapper(
discretization, strength=image_strength
)
if (
noise_strength is not None
and noise_strength < 1.0
and noise_strength > 0.0
and steps is not None
):
discretization = Txt2NoisyDiscretizationWrapper(
discretization,
strength=noise_strength,
original_steps=steps,
)
return discretization
class SamplingPipeline:
def __init__(
self,
model_id: Optional[ModelArchitecture] = None,
model_spec: Optional[SamplingSpec] = None,
model_path: Optional[str] = None,
config_path: Optional[str] = None,
use_fp16: bool = True,
device: Optional[Union[DeviceModelManager, str, torch.device]] = None,
model_id: ModelArchitecture,
model_path="checkpoints",
config_path="configs/inference",
device="cuda",
use_fp16=True,
) -> None:
"""
Sampling pipeline for generating images from a model.
@param model_id: Model architecture to use. If not specified, model_spec must be specified.
@param model_spec: Model specification to use. If not specified, model_id must be specified.
@param model_path: Path to model checkpoints folder.
@param config_path: Path to model config folder.
@param use_fp16: Whether to use fp16 for sampling.
@param device: Device manager to use with this pipeline. If a string or torch.device is passed, a device manager will be created based on device type if possible.
"""
if model_id not in model_specs:
raise ValueError(f"Model {model_id} not supported")
self.model_id = model_id
if model_spec is not None:
self.specs = model_spec
elif model_id is not None:
if model_id not in model_specs:
raise ValueError(f"Model {model_id} not supported")
self.specs = model_specs[model_id]
else:
raise ValueError("Either model_id or model_spec should be provided")
self.specs = model_specs[self.model_id]
self.config = str(pathlib.Path(config_path, self.specs.config))
self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt))
self.device = device
self.model = self._load_model(device=device, use_fp16=use_fp16)
if model_path is None:
model_path = get_checkpoints_path()
if config_path is None:
config_path = get_configs_path()
self.config = os.path.join(config_path, "inference", self.specs.config)
self.ckpt = os.path.join(model_path, self.specs.ckpt)
if not os.path.exists(self.config):
raise ValueError(
f"Config {self.config} not found, check model spec or config_path"
)
if not os.path.exists(self.ckpt):
raise ValueError(
f"Checkpoint {self.ckpt} not found, check model spec or config_path"
)
self.device_manager = get_model_manager(device)
self.model = self._load_model(
device_manager=self.device_manager, use_fp16=use_fp16
)
def _load_model(self, device_manager: DeviceModelManager, use_fp16=True):
def _load_model(self, device="cuda", use_fp16=True):
config = OmegaConf.load(self.config)
model = load_model_from_config(config, self.ckpt)
if model is None:
raise ValueError(f"Model {self.model_id} could not be loaded")
device_manager.load(model)
model.to(device)
if use_fp16:
model.conditioner.half()
model.model.half()
@@ -276,34 +186,13 @@ class SamplingPipeline:
def text_to_image(
self,
params: SamplingParams,
prompt: str,
params: Optional[SamplingParams] = None,
negative_prompt: str = "",
samples: int = 1,
return_latents: bool = False,
noise_strength: Optional[float] = None,
filter=None,
):
if params is None:
params = self.specs.default_params
else:
# Set defaults if optional params are not specified
if params.width is None:
params.width = self.specs.default_params.width
if params.height is None:
params.height = self.specs.default_params.height
if params.steps is None:
params.steps = self.specs.default_params.steps
sampler = get_sampler_config(params)
sampler.discretization = wrap_discretization(
sampler.discretization,
image_strength=None,
noise_strength=noise_strength,
steps=params.steps,
)
value_dict = asdict(params)
value_dict["prompt"] = prompt
value_dict["negative_prompt"] = negative_prompt
@@ -320,40 +209,31 @@ class SamplingPipeline:
self.specs.factor,
force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
return_latents=return_latents,
filter=filter,
device=self.device_manager,
filter=None,
)
def image_to_image(
self,
image: torch.Tensor,
params: SamplingParams,
image,
prompt: str,
params: Optional[SamplingParams] = None,
negative_prompt: str = "",
samples: int = 1,
return_latents: bool = False,
noise_strength: Optional[float] = None,
filter=None,
):
if params is None:
params = self.specs.default_params
sampler = get_sampler_config(params)
sampler.discretization = wrap_discretization(
sampler.discretization,
image_strength=params.img2img_strength,
noise_strength=noise_strength,
steps=params.steps,
)
if params.img2img_strength < 1.0:
sampler.discretization = Img2ImgDiscretizationWrapper(
sampler.discretization,
strength=params.img2img_strength,
)
height, width = image.shape[2], image.shape[3]
value_dict = asdict(params)
value_dict["prompt"] = prompt
value_dict["negative_prompt"] = negative_prompt
value_dict["target_width"] = width
value_dict["target_height"] = height
value_dict["orig_width"] = width
value_dict["orig_height"] = height
return do_img2img(
image,
self.model,
@@ -362,24 +242,18 @@ class SamplingPipeline:
samples,
force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
return_latents=return_latents,
filter=filter,
device=self.device_manager,
filter=None,
)
def refiner(
self,
image: torch.Tensor,
params: SamplingParams,
image,
prompt: str,
negative_prompt: str = "",
params: Optional[SamplingParams] = None,
negative_prompt: Optional[str] = None,
samples: int = 1,
return_latents: bool = False,
filter: Any = None,
add_noise: bool = False,
):
if params is None:
params = self.specs.default_params
sampler = get_sampler_config(params)
value_dict = {
"orig_width": image.shape[3] * 8,
@@ -394,10 +268,6 @@ class SamplingPipeline:
"negative_aesthetic_score": 2.5,
}
sampler.discretization = wrap_discretization(
sampler.discretization, image_strength=params.img2img_strength
)
return do_img2img(
image,
self.model,
@@ -406,14 +276,11 @@ class SamplingPipeline:
samples,
skip_encode=True,
return_latents=return_latents,
filter=filter,
add_noise=add_noise,
device=self.device_manager,
filter=None,
)
def get_guider_config(params: SamplingParams) -> Dict[str, Any]:
guider_config: Dict[str, Any]
def get_guider_config(params: SamplingParams):
if params.guider == Guider.IDENTITY:
guider_config = {
"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
@@ -439,8 +306,7 @@ def get_guider_config(params: SamplingParams) -> Dict[str, Any]:
return guider_config
def get_discretization_config(params: SamplingParams) -> Dict[str, Any]:
discretization_config: Dict[str, Any]
def get_discretization_config(params: SamplingParams):
if params.discretization == Discretization.LEGACY_DDPM:
discretization_config = {
"target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",

View File

@@ -1,4 +1,3 @@
import contextlib
import os
from typing import Union, List, Optional
@@ -9,6 +8,7 @@ from PIL import Image
from einops import rearrange
from imwatermark import WatermarkEncoder
from omegaconf import ListConfig
from torch import autocast
from sgm.util import append_dims
@@ -58,73 +58,6 @@ WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
class DeviceModelManager(object):
"""
Default model loading class, should work for all device classes.
"""
def __init__(
self,
device: Union[torch.device, str],
swap_device: Optional[Union[torch.device, str]] = None,
) -> None:
"""
Args:
device (Union[torch.device, str]): The device to use for the model.
"""
self.device = torch.device(device)
self.swap_device = (
torch.device(swap_device) if swap_device is not None else self.device
)
def load(self, model: torch.nn.Module) -> None:
"""
Loads a model to the (swap) device.
"""
model.to(self.swap_device)
def autocast(self):
"""
Context manager that enables autocast for the device if supported.
"""
if self.device.type not in ("cuda", "cpu"):
return contextlib.nullcontext()
return torch.autocast(self.device.type)
@contextlib.contextmanager
def use(self, model: torch.nn.Module):
"""
Context manager that ensures a model is on the correct device during use.
The default model loader does not perform any swapping, so the model will
stay on device.
"""
try:
model.to(self.device)
yield
finally:
if self.device != self.swap_device:
model.to(self.swap_device)
class CudaModelManager(DeviceModelManager):
"""
Device manager that loads a model to a CUDA device, optionally swapping to CPU when not in use.
"""
@contextlib.contextmanager
def use(self, model: Union[torch.nn.Module, torch.Tensor]):
"""
Context manager that ensures a model is on the correct device during use.
If a swap device was provided, this will move the model to it after use and clear cache.
"""
model.to(self.device)
yield
if self.device != self.swap_device:
model.to(self.swap_device)
if torch.cuda.is_available():
torch.cuda.empty_cache()
def get_unique_embedder_keys_from_conditioner(conditioner):
return list({x.input_key for x in conditioner.embedders})
@@ -141,20 +74,6 @@ def perform_save_locally(save_path, samples):
base_count += 1
def get_model_manager(
device: Optional[Union[DeviceModelManager, str, torch.device]]
) -> DeviceModelManager:
if isinstance(device, DeviceModelManager):
return device
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device)
if device.type == "cuda":
return CudaModelManager(device=device)
else:
return DeviceModelManager(device=device)
class Img2ImgDiscretizationWrapper:
"""
wraps a discretizer, and prunes the sigmas
@@ -179,36 +98,6 @@ class Img2ImgDiscretizationWrapper:
return sigmas
class Txt2NoisyDiscretizationWrapper:
"""
wraps a discretizer, and prunes the sigmas
params:
strength: float between 0.0 and 1.0. 0.0 means full sampling (all sigmas are returned)
"""
def __init__(self, discretization, strength: float = 0.0, original_steps=None):
self.discretization = discretization
self.strength = strength
self.original_steps = original_steps
assert 0.0 <= self.strength <= 1.0
def __call__(self, *args, **kwargs):
# sigmas start large first, and decrease then
sigmas = self.discretization(*args, **kwargs)
print(f"sigmas after discretization, before pruning img2img: ", sigmas)
sigmas = torch.flip(sigmas, (0,))
if self.original_steps is None:
steps = len(sigmas)
else:
steps = self.original_steps + 1
prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0)
sigmas = sigmas[prune_index:]
print("prune index:", prune_index)
sigmas = torch.flip(sigmas, (0,))
print(f"sigmas after pruning: ", sigmas)
return sigmas
def do_sample(
model,
sampler,
@@ -222,45 +111,39 @@ def do_sample(
batch2model_input: Optional[List] = None,
return_latents=False,
filter=None,
device: Optional[Union[DeviceModelManager, str, torch.device]] = None,
device="cuda",
):
if force_uc_zero_embeddings is None:
force_uc_zero_embeddings = []
if batch2model_input is None:
batch2model_input = []
device_manager = get_model_manager(device=device)
with torch.no_grad():
with device_manager.autocast():
with autocast(device) as precision_scope:
with model.ema_scope():
num_samples = [num_samples]
with device_manager.use(model.conditioner):
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner),
value_dict,
num_samples,
)
for key in batch:
if isinstance(batch[key], torch.Tensor):
print(key, batch[key].shape)
elif isinstance(batch[key], list):
print(key, [len(l) for l in batch[key]])
else:
print(key, batch[key])
c, uc = model.conditioner.get_unconditional_conditioning(
batch,
batch_uc=batch_uc,
force_uc_zero_embeddings=force_uc_zero_embeddings,
)
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner),
value_dict,
num_samples,
)
for key in batch:
if isinstance(batch[key], torch.Tensor):
print(key, batch[key].shape)
elif isinstance(batch[key], list):
print(key, [len(l) for l in batch[key]])
else:
print(key, batch[key])
c, uc = model.conditioner.get_unconditional_conditioning(
batch,
batch_uc=batch_uc,
force_uc_zero_embeddings=force_uc_zero_embeddings,
)
for k in c:
if not k == "crossattn":
c[k], uc[k] = map(
lambda y: y[k][: math.prod(num_samples)].to(
device_manager.device
),
(c, uc),
lambda y: y[k][: math.prod(num_samples)].to(device), (c, uc)
)
additional_model_inputs = {}
@@ -268,20 +151,16 @@ def do_sample(
additional_model_inputs[k] = batch[k]
shape = (math.prod(num_samples), C, H // F, W // F)
randn = torch.randn(shape).to(device_manager.device)
randn = torch.randn(shape).to(device)
def denoiser(input, sigma, c):
return model.denoiser(
model.model, input, sigma, c, **additional_model_inputs
)
with device_manager.use(model.denoiser):
with device_manager.use(model.model):
samples_z = sampler(denoiser, randn, cond=c, uc=uc)
with device_manager.use(model.first_stage_model):
samples_x = model.decode_first_stage(samples_z)
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
samples_z = sampler(denoiser, randn, cond=c, uc=uc)
samples_x = model.decode_first_stage(samples_z)
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
if filter is not None:
samples = filter(samples)
@@ -373,40 +252,32 @@ def do_img2img(
return_latents=False,
skip_encode=False,
filter=None,
add_noise=True,
device: Optional[Union[DeviceModelManager, str, torch.device]] = None,
device="cuda",
):
device_manager = get_model_manager(device)
with torch.no_grad():
with device_manager.autocast():
with autocast(device) as precision_scope:
with model.ema_scope():
with device_manager.use(model.conditioner):
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner),
value_dict,
[num_samples],
)
c, uc = model.conditioner.get_unconditional_conditioning(
batch,
batch_uc=batch_uc,
force_uc_zero_embeddings=force_uc_zero_embeddings,
)
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner),
value_dict,
[num_samples],
)
c, uc = model.conditioner.get_unconditional_conditioning(
batch,
batch_uc=batch_uc,
force_uc_zero_embeddings=force_uc_zero_embeddings,
)
for k in c:
c[k], uc[k] = map(
lambda y: y[k][:num_samples].to(device_manager.device), (c, uc)
)
c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc))
for k in additional_kwargs:
c[k] = uc[k] = additional_kwargs[k]
if skip_encode:
z = img
else:
with device_manager.use(model.first_stage_model):
z = model.encode_first_stage(img)
z = model.encode_first_stage(img)
noise = torch.randn_like(z)
sigmas = sampler.discretization(sampler.num_steps)
sigma = sigmas[0].to(z.device)
@@ -414,24 +285,17 @@ def do_img2img(
noise = noise + offset_noise_level * append_dims(
torch.randn(z.shape[0], device=z.device), z.ndim
)
if add_noise:
noised_z = z + noise * append_dims(sigma, z.ndim).cuda()
noised_z = noised_z / torch.sqrt(
1.0 + sigmas[0] ** 2.0
) # Note: hardcoded to DDPM-like scaling. need to generalize later.
else:
noised_z = z / torch.sqrt(1.0 + sigmas[0] ** 2.0)
noised_z = z + noise * append_dims(sigma, z.ndim)
noised_z = noised_z / torch.sqrt(
1.0 + sigmas[0] ** 2.0
) # Note: hardcoded to DDPM-like scaling. need to generalize later.
def denoiser(x, sigma, c):
return model.denoiser(model.model, x, sigma, c)
with device_manager.use(model.denoiser):
with device_manager.use(model.model):
samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
with device_manager.use(model.first_stage_model):
samples_x = model.decode_first_stage(samples_z)
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
samples_x = model.decode_first_stage(samples_z)
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
if filter is not None:
samples = filter(samples)

View File

@@ -26,7 +26,7 @@ class Discretization:
class EDMDiscretization(Discretization):
def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0):
def __init__(self, sigma_min=0.02, sigma_max=80.0, rho=7.0):
self.sigma_min = sigma_min
self.sigma_max = sigma_max
self.rho = rho
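The three parameters feed a Karras-style rho-schedule; a sketch of that formula for intuition (an assumption about the implementation, not a copy of it):

import torch

def karras_sigmas(n, sigma_min=0.02, sigma_max=80.0, rho=7.0):
    # Interpolate between sigma_max and sigma_min in 1/rho-power space.
    ramp = torch.linspace(0, 1, n)
    min_inv, max_inv = sigma_min ** (1 / rho), sigma_max ** (1 / rho)
    return (max_inv + ramp * (min_inv - max_inv)) ** rho

print(karras_sigmas(5))  # descends from 80.0 down to 0.02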

View File

@@ -230,24 +230,6 @@ def load_model_from_config(config, ckpt, verbose=True, freeze=True):
return model
def get_checkpoints_path() -> str:
"""
Get the `checkpoints` directory.
This could be in the root of the repository for a working copy,
or in the cwd for other use cases.
"""
this_dir = os.path.dirname(__file__)
candidates = (
os.path.join(this_dir, "checkpoints"),
os.path.join(os.getcwd(), "checkpoints"),
)
for candidate in candidates:
candidate = os.path.abspath(candidate)
if os.path.isdir(candidate):
return candidate
raise FileNotFoundError(f"Could not find SGM checkpoints in {candidates}")
def get_configs_path() -> str:
"""
Get the `configs` directory.

View File

@@ -27,7 +27,7 @@ class TestInference:
@fixture(
scope="class",
params=[
[ModelArchitecture.SDXL_V1_0_BASE, ModelArchitecture.SDXL_V1_0_REFINER],
[ModelArchitecture.SDXL_V1_BASE, ModelArchitecture.SDXL_V1_REFINER],
[ModelArchitecture.SDXL_V0_9_BASE, ModelArchitecture.SDXL_V0_9_REFINER],
],
ids=["SDXL_V1", "SDXL_V0_9"],
@@ -68,7 +68,9 @@ class TestInference:
assert output is not None
@pytest.mark.parametrize("sampler_enum", Sampler)
@pytest.mark.parametrize("use_init_image", [True, False], ids=["img2img", "txt2img"])
@pytest.mark.parametrize(
"use_init_image", [True, False], ids=["img2img", "txt2img"]
)
def test_sdxl_with_refiner(
self,
sdxl_pipelines: Tuple[SamplingPipeline, SamplingPipeline],
@@ -79,12 +81,13 @@ class TestInference:
if use_init_image:
output = base_pipeline.image_to_image(
params=SamplingParams(sampler=sampler_enum.value, steps=10),
image=self.create_init_image(base_pipeline.specs.height, base_pipeline.specs.width),
image=self.create_init_image(
base_pipeline.specs.height, base_pipeline.specs.width
),
prompt="A professional photograph of an astronaut riding a pig",
negative_prompt="",
samples=1,
return_latents=True,
noise_strength=0.15,
)
else:
output = base_pipeline.text_to_image(
@@ -93,7 +96,6 @@ class TestInference:
negative_prompt="",
samples=1,
return_latents=True,
noise_strength=0.15,
)
assert isinstance(output, (tuple, list))
@@ -101,9 +103,9 @@ class TestInference:
assert samples is not None
assert samples_z is not None
refiner_pipeline.refiner(
params=SamplingParams(sampler=sampler_enum.value, steps=10),
image=samples_z,
prompt="A professional photograph of an astronaut riding a pig",
params=SamplingParams(sampler=sampler_enum.value, steps=40, img2img_strength=0.15),
negative_prompt="",
samples=1,
)

View File

@@ -1,44 +0,0 @@
import pytest
import torch
from sgm.inference.api import (
SamplingPipeline,
ModelArchitecture,
)
import sgm.inference.helpers as helpers
def get_torch_device(model: torch.nn.Module) -> torch.device:
param = next(model.parameters(), None)
if param is not None:
return param.device
else:
buf = next(model.buffers(), None)
if buf is not None:
return buf.device
else:
raise TypeError("Could not determine device of input model")
@pytest.mark.inference
def test_default_loading():
pipeline = SamplingPipeline(model_id=ModelArchitecture.SD_2_1)
assert get_torch_device(pipeline.model.model).type == "cuda"
assert get_torch_device(pipeline.model.conditioner).type == "cuda"
with pipeline.device_manager.use(pipeline.model.model):
assert get_torch_device(pipeline.model.model).type == "cuda"
assert get_torch_device(pipeline.model.model).type == "cuda"
with pipeline.device_manager.use(pipeline.model.conditioner):
assert get_torch_device(pipeline.model.conditioner).type == "cuda"
assert get_torch_device(pipeline.model.conditioner).type == "cuda"
@pytest.mark.inference
def test_model_swapping():
pipeline = SamplingPipeline(model_id=ModelArchitecture.SD_2_1, device=helpers.CudaModelManager(device="cuda", swap_device="cpu"))
assert get_torch_device(pipeline.model.model).type == "cpu"
assert get_torch_device(pipeline.model.conditioner).type == "cpu"
with pipeline.device_manager.use(pipeline.model.model):
assert get_torch_device(pipeline.model.model).type == "cuda"
assert get_torch_device(pipeline.model.model).type == "cpu"
with pipeline.device_manager.use(pipeline.model.conditioner):
assert get_torch_device(pipeline.model.conditioner).type == "cuda"
assert get_torch_device(pipeline.model.conditioner).type == "cpu"