2 Commits

Author            SHA1        Message                          Date
Stephan Auerhahn  3167b358c0  use feature to install pytorch   2023-08-12 05:59:57 +00:00
Stephan Auerhahn  ab11af1431  Add devcontainer configs         2023-08-03 17:44:56 -07:00

14 changed files with 1040 additions and 732 deletions

View File

@@ -0,0 +1,45 @@
{
"name": "Python 3",
"image": "mcr.microsoft.com/devcontainers/python:1-3.10-bullseye",
"features": {
"ghcr.io/devcontainers-contrib/features/hatch:2": {
"version": "latest"
},
"ghcr.io/devcontainers-contrib/features/pipx-package:1": {
"package": "black",
"version": "latest",
"injections": "pylint pytest",
"interpreter": "python3"
},
"ghcr.io/devcontainers-contrib/features/apt-get-packages:1": {
"packages": "libgl1-mesa-glx"
},
"ghcr.io/stability-ai/devcontainer-features/pytorch:1.0.1": {
"version": "2.0.1",
"cudaVersion": "cpu"
}
},
"customizations": {
"vscode": {
"extensions": [
"ms-python.python",
"ms-python.vscode-pylance",
"ms-python.black-formatter"
],
"settings": {
"[python]": {
"diffEditor.ignoreTrimWhitespace": false,
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.formatOnSave": true,
"editor.wordBasedSuggestions": false
},
"python.analysis.typeCheckingMode": "basic",
"black-formatter.args": [
"-l 88"
]
}
}
},
"postCreateCommand": "git config --global --add safe.directory '*'",
"postAttachCommand": "pip install -U -r requirements/pt2.txt;pip install -U -e ."
}

View File

@@ -0,0 +1,52 @@
{
"name": "Python 3",
"image": "mcr.microsoft.com/devcontainers/python:1-3.8-bullseye",
"features": {
"ghcr.io/devcontainers/features/nvidia-cuda:1": {
"cudaVersion": "11.7"
},
"ghcr.io/devcontainers-contrib/features/hatch:2": {
"version": "latest"
},
"ghcr.io/devcontainers-contrib/features/pipx-package:1": {
"package": "black",
"version": "latest",
"injections": "pylint pytest",
"interpreter": "python3"
},
"ghcr.io/devcontainers-contrib/features/apt-get-packages:1": {
"packages": "libgl1-mesa-glx"
},
"ghcr.io/stability-ai/devcontainer-features/pytorch:1.0.1": {
"version": "1.13.1",
"cudaVersion": "cu117"
}
},
"customizations": {
"vscode": {
"extensions": [
"ms-python.python",
"ms-python.vscode-pylance",
"ms-python.black-formatter"
],
"settings": {
"[python]": {
"diffEditor.ignoreTrimWhitespace": false,
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.formatOnSave": true,
"editor.wordBasedSuggestions": false
},
"python.analysis.typeCheckingMode": "basic",
"black-formatter.args": [
"-l 88"
]
}
}
},
"runArgs": [
"--gpus",
"all"
],
"postCreateCommand": "git config --global --add safe.directory '*'",
"postAttachCommand": "pip install -U -r requirements/pt13.txt;pip install -U -e ."
}

View File

@@ -0,0 +1,52 @@
{
"name": "Python 3",
"image": "mcr.microsoft.com/devcontainers/python:1-3.10-bullseye",
"features": {
"ghcr.io/devcontainers/features/nvidia-cuda:1": {
"cudaVersion": "11.8"
},
"ghcr.io/devcontainers-contrib/features/hatch:2": {
"version": "latest"
},
"ghcr.io/devcontainers-contrib/features/pipx-package:1": {
"package": "black",
"version": "latest",
"injections": "pylint pytest",
"interpreter": "python3"
},
"ghcr.io/devcontainers-contrib/features/apt-get-packages:1": {
"packages": "libgl1-mesa-glx"
},
"ghcr.io/stability-ai/devcontainer-features/pytorch:1.0.1": {
"version": "2.0.1",
"cudaVersion": "cu118"
}
},
"customizations": {
"vscode": {
"extensions": [
"ms-python.python",
"ms-python.vscode-pylance",
"ms-python.black-formatter"
],
"settings": {
"[python]": {
"diffEditor.ignoreTrimWhitespace": false,
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.formatOnSave": true,
"editor.wordBasedSuggestions": false
},
"python.analysis.typeCheckingMode": "basic",
"black-formatter.args": [
"-l 88"
]
}
}
},
"runArgs": [
"--gpus",
"all"
],
"postCreateCommand": "git config --global --add safe.directory '*'",
"postAttachCommand": "pip install -U -r requirements/pt2.txt;pip install -U -e ."
}
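
All three devcontainer variants install PyTorch through the ghcr.io/stability-ai/devcontainer-features/pytorch feature rather than a bare pip install: the first config pairs torch 2.0.1 with cudaVersion "cpu", the Python 3.8 config pins torch 1.13.1 against CUDA 11.7, and the Python 3.10 config above pins torch 2.0.1 against CUDA 11.8. A minimal smoke test to run inside a freshly built container might look like the following; the expected output strings are assumptions based on the feature options above, not part of this commit:

import torch

# "2.0.1+cu118" in the CUDA 11.8 config, "2.0.1+cpu" in the CPU-only config,
# "1.13.1+cu117" in the Python 3.8 config
print(torch.__version__)
# "11.8" / "11.7" for the CUDA builds, None for the cpu build
print(torch.version.cuda)
# True only in the GPU configs, and only when "--gpus all" (runArgs above)
# actually passes a GPU through to the container
print(torch.cuda.is_available())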

View File

@@ -15,7 +15,7 @@ jobs:
   steps:
     - uses: actions/checkout@v3
     - name: "Symlink checkpoints"
-      run: ln -s $SGM_CHECKPOINTS checkpoints
+      run: ln -s ${{vars.SGM_CHECKPOINTS_PATH}} checkpoints
     - name: "Setup python"
       uses: actions/setup-python@v4
       with:

View File

@@ -44,5 +44,5 @@ dependencies = [
 test-inference = [
     "pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 --index-url https://download.pytorch.org/whl/cu118",
     "pip install -r requirements/pt2.txt",
-    "pytest -v tests/inference {args}",
+    "pytest -v tests/inference/test_inference.py {args}",
 ]

View File

@@ -1,30 +1,6 @@
-import os
-import numpy as np
-import streamlit as st
-import torch
-from einops import repeat
 from pytorch_lightning import seed_everything
-from sgm.inference.api import (
-    SamplingSpec,
-    SamplingParams,
-    ModelArchitecture,
-    SamplingPipeline,
-    model_specs,
-)
-from sgm.inference.helpers import (
-    get_unique_embedder_keys_from_conditioner,
-    perform_save_locally,
-)
-from scripts.demo.streamlit_helpers import (
-    get_interactive_image,
-    init_embedder_options,
-    init_sampling,
-    init_save_locally,
-    init_st,
-    show_samples,
-)
+from scripts.demo.streamlit_helpers import *

 SAVE_PATH = "outputs/demo/txt2img/"

@@ -57,6 +33,63 @@ SD_XL_BASE_RATIOS = {
     "3.0": (1728, 576),
 }

+VERSION2SPECS = {
+    "SDXL-base-1.0": {
+        "H": 1024,
+        "W": 1024,
+        "C": 4,
+        "f": 8,
+        "is_legacy": False,
+        "config": "configs/inference/sd_xl_base.yaml",
+        "ckpt": "checkpoints/sd_xl_base_1.0.safetensors",
+    },
+    "SDXL-base-0.9": {
+        "H": 1024,
+        "W": 1024,
+        "C": 4,
+        "f": 8,
+        "is_legacy": False,
+        "config": "configs/inference/sd_xl_base.yaml",
+        "ckpt": "checkpoints/sd_xl_base_0.9.safetensors",
+    },
+    "SD-2.1": {
+        "H": 512,
+        "W": 512,
+        "C": 4,
+        "f": 8,
+        "is_legacy": True,
+        "config": "configs/inference/sd_2_1.yaml",
+        "ckpt": "checkpoints/v2-1_512-ema-pruned.safetensors",
+    },
+    "SD-2.1-768": {
+        "H": 768,
+        "W": 768,
+        "C": 4,
+        "f": 8,
+        "is_legacy": True,
+        "config": "configs/inference/sd_2_1_768.yaml",
+        "ckpt": "checkpoints/v2-1_768-ema-pruned.safetensors",
+    },
+    "SDXL-refiner-0.9": {
+        "H": 1024,
+        "W": 1024,
+        "C": 4,
+        "f": 8,
+        "is_legacy": True,
+        "config": "configs/inference/sd_xl_refiner.yaml",
+        "ckpt": "checkpoints/sd_xl_refiner_0.9.safetensors",
+    },
+    "SDXL-refiner-1.0": {
+        "H": 1024,
+        "W": 1024,
+        "C": 4,
+        "f": 8,
+        "is_legacy": True,
+        "config": "configs/inference/sd_xl_refiner.yaml",
+        "ckpt": "checkpoints/sd_xl_refiner_1.0.safetensors",
+    },
+}
+

 def load_img(display=True, key=None, device="cuda"):
     image = get_interactive_image(key=key)

@@ -78,181 +111,170 @@ def load_img(display=True, key=None, device="cuda"):
 def run_txt2img(
     state,
-    model_id: ModelArchitecture,
-    prompt: str,
-    negative_prompt: str,
+    version,
+    version_dict,
+    is_legacy=False,
     return_latents=False,
+    filter=None,
     stage2strength=None,
 ):
-    model: SamplingPipeline = state["model"]
-    params: SamplingParams = state["params"]
-    if model_id in sdxl_base_model_list:
-        width, height = st.selectbox(
-            "Resolution:", list(SD_XL_BASE_RATIOS.values()), 10
-        )
+    if version.startswith("SDXL-base"):
+        W, H = st.selectbox("Resolution:", list(SD_XL_BASE_RATIOS.values()), 10)
     else:
-        height = int(
-            st.number_input("H", value=params.height, min_value=64, max_value=2048)
-        )
-        width = int(
-            st.number_input("W", value=params.width, min_value=64, max_value=2048)
-        )
+        H = st.number_input("H", value=version_dict["H"], min_value=64, max_value=2048)
+        W = st.number_input("W", value=version_dict["W"], min_value=64, max_value=2048)
+    C = version_dict["C"]
+    F = version_dict["f"]

-    params = init_embedder_options(
-        get_unique_embedder_keys_from_conditioner(model.model.conditioner),
-        params=params,
+    init_dict = {
+        "orig_width": W,
+        "orig_height": H,
+        "target_width": W,
+        "target_height": H,
+    }
+    value_dict = init_embedder_options(
+        get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
+        init_dict,
         prompt=prompt,
         negative_prompt=negative_prompt,
     )
-    params, num_rows, num_cols = init_sampling(params=params)
+    sampler, num_rows, num_cols = init_sampling(stage2strength=stage2strength)
     num_samples = num_rows * num_cols
-    params.height = height
-    params.width = width

     if st.button("Sample"):
         st.write(f"**Model I:** {version}")
-        outputs = st.empty()
-        st.text("Sampling")
-        out = model.text_to_image(
-            params=params,
-            prompt=prompt,
-            negative_prompt=negative_prompt,
-            samples=int(num_samples),
+        out = do_sample(
+            state["model"],
+            sampler,
+            value_dict,
+            num_samples,
+            H,
+            W,
+            C,
+            F,
+            force_uc_zero_embeddings=["txt"] if not is_legacy else [],
             return_latents=return_latents,
-            noise_strength=stage2strength,
-            filter=state["filter"],
+            filter=filter,
         )
-        show_samples(out, outputs)
         return out


 def run_img2img(
     state,
-    prompt: str,
-    negative_prompt: str,
+    version_dict,
+    is_legacy=False,
     return_latents=False,
+    filter=None,
     stage2strength=None,
 ):
-    model: SamplingPipeline = state["model"]
-    params: SamplingParams = state["params"]
     img = load_img()
     if img is None:
         return None
-    params.height, params.width = img.shape[2], img.shape[3]
+    H, W = img.shape[2], img.shape[3]

-    params = init_embedder_options(
-        get_unique_embedder_keys_from_conditioner(model.model.conditioner),
-        params=params,
+    init_dict = {
+        "orig_width": W,
+        "orig_height": H,
+        "target_width": W,
+        "target_height": H,
+    }
+    value_dict = init_embedder_options(
+        get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
+        init_dict,
         prompt=prompt,
         negative_prompt=negative_prompt,
     )
-    params.img2img_strength = st.number_input(
+    strength = st.number_input(
         "**Img2Img Strength**", value=0.75, min_value=0.0, max_value=1.0
     )
-    params, num_rows, num_cols = init_sampling(params=params)
+    sampler, num_rows, num_cols = init_sampling(
+        img2img_strength=strength,
+        stage2strength=stage2strength,
+    )
     num_samples = num_rows * num_cols

     if st.button("Sample"):
-        outputs = st.empty()
-        st.text("Sampling")
-        out = model.image_to_image(
-            image=repeat(img, "1 ... -> n ...", n=num_samples),
-            params=params,
-            prompt=prompt,
-            negative_prompt=negative_prompt,
-            samples=int(num_samples),
+        out = do_img2img(
+            repeat(img, "1 ... -> n ...", n=num_samples),
+            state["model"],
+            sampler,
+            value_dict,
+            num_samples,
+            force_uc_zero_embeddings=["txt"] if not is_legacy else [],
             return_latents=return_latents,
-            noise_strength=stage2strength,
-            filter=state["filter"],
+            filter=filter,
         )
-        show_samples(out, outputs)
         return out


 def apply_refiner(
     input,
     state,
-    num_samples: int,
-    prompt: str,
-    negative_prompt: str,
+    sampler,
+    num_samples,
+    prompt,
+    negative_prompt,
+    filter=None,
     finish_denoising=False,
 ):
-    model: SamplingPipeline = state["model"]
-    params: SamplingParams = state["params"]
+    init_dict = {
+        "orig_width": input.shape[3] * 8,
+        "orig_height": input.shape[2] * 8,
+        "target_width": input.shape[3] * 8,
+        "target_height": input.shape[2] * 8,
+    }

-    params.orig_width = input.shape[3] * 8
-    params.orig_height = input.shape[2] * 8
-    params.width = input.shape[3] * 8
-    params.height = input.shape[2] * 8
+    value_dict = init_dict
+    value_dict["prompt"] = prompt
+    value_dict["negative_prompt"] = negative_prompt
+
+    value_dict["crop_coords_top"] = 0
+    value_dict["crop_coords_left"] = 0
+
+    value_dict["aesthetic_score"] = 6.0
+    value_dict["negative_aesthetic_score"] = 2.5

     st.warning(f"refiner input shape: {input.shape}")

-    samples = model.refiner(
-        image=input,
-        params=params,
-        prompt=prompt,
-        negative_prompt=negative_prompt,
-        samples=num_samples,
-        return_latents=False,
-        filter=state["filter"],
+    samples = do_img2img(
+        input,
+        state["model"],
+        sampler,
+        value_dict,
+        num_samples,
+        skip_encode=True,
+        filter=filter,
         add_noise=not finish_denoising,
     )

     return samples


-sdxl_base_model_list = [
-    ModelArchitecture.SDXL_V1_0_BASE,
-    ModelArchitecture.SDXL_V0_9_BASE,
-]
-sdxl_refiner_model_list = [
-    ModelArchitecture.SDXL_V1_0_REFINER,
-    ModelArchitecture.SDXL_V0_9_REFINER,
-]
-
 if __name__ == "__main__":
     st.title("Stable Diffusion")
-    version = st.selectbox(
-        "Model Version",
-        [member.value for member in ModelArchitecture],
-        0,
-    )
-    version_enum = ModelArchitecture(version)
-    specs = model_specs[version_enum]
+    version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
+    version_dict = VERSION2SPECS[version]
     mode = st.radio("Mode", ("txt2img", "img2img"), 0)
     st.write("__________________________")

-    st.write("**Performance Options:**")
-    use_fp16 = st.checkbox("Use fp16 (Saves VRAM)", True)
-    enable_swap = st.checkbox("Swap models to CPU (Saves VRAM, uses RAM)", True)
-    st.write("__________________________")
+    set_lowvram_mode(st.checkbox("Low vram mode", True))

-    if version_enum in sdxl_base_model_list:
+    if version.startswith("SDXL-base"):
         add_pipeline = st.checkbox("Load SDXL-refiner?", False)
         st.write("__________________________")
     else:
         add_pipeline = False

-    seed = int(
-        st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
-    )
+    seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
     seed_everything(seed)

-    save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, str(version)))
+    save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))

-    state = init_st(
-        model_specs[version_enum],
-        load_filter=True,
-        use_fp16=use_fp16,
-        enable_swap=enable_swap,
-    )
+    state = init_st(version_dict, load_filter=True)
+    if state["msg"]:
+        st.info(state["msg"])
     model = state["model"]

-    is_legacy = specs.is_legacy
+    is_legacy = version_dict["is_legacy"]

     prompt = st.text_input(
         "prompt",

@@ -268,58 +290,47 @@ if __name__ == "__main__":
     if add_pipeline:
         st.write("__________________________")

-        version2 = ModelArchitecture(
-            st.selectbox(
-                "Refiner:",
-                [member.value for member in sdxl_refiner_model_list],
-            )
-        )
+        version2 = st.selectbox("Refiner:", ["SDXL-refiner-1.0", "SDXL-refiner-0.9"])
         st.warning(
             f"Running with {version2} as the second stage model. Make sure to provide (V)RAM :) "
         )
         st.write("**Refiner Options:**")

-        specs2 = model_specs[version2]
-        state2 = init_st(
-            specs2, load_filter=False, use_fp16=use_fp16, enable_swap=enable_swap
-        )
-        params2 = state2["params"]
+        version_dict2 = VERSION2SPECS[version2]
+        state2 = init_st(version_dict2, load_filter=False)
+        st.info(state2["msg"])

-        params2.img2img_strength = st.number_input(
+        stage2strength = st.number_input(
             "**Refinement strength**", value=0.15, min_value=0.0, max_value=1.0
         )

-        params2, *_ = init_sampling(
-            params=state2["params"],
+        sampler2, *_ = init_sampling(
             key=2,
+            img2img_strength=stage2strength,
             specify_num_samples=False,
         )
         st.write("__________________________")

         finish_denoising = st.checkbox("Finish denoising with refiner.", True)
-        if finish_denoising:
-            stage2strength = params2.img2img_strength
-        else:
+        if not finish_denoising:
             stage2strength = None
-    else:
-        state2 = None
-        params2 = None
-        stage2strength = None

     if mode == "txt2img":
         out = run_txt2img(
-            state=state,
-            model_id=version_enum,
-            prompt=prompt,
-            negative_prompt=negative_prompt,
+            state,
+            version,
+            version_dict,
+            is_legacy=is_legacy,
             return_latents=add_pipeline,
+            filter=state.get("filter"),
             stage2strength=stage2strength,
         )
     elif mode == "img2img":
         out = run_img2img(
-            state=state,
-            prompt=prompt,
-            negative_prompt=negative_prompt,
+            state,
+            version_dict,
+            is_legacy=is_legacy,
             return_latents=add_pipeline,
+            filter=state.get("filter"),
             stage2strength=stage2strength,
         )
     else:

@@ -331,17 +342,17 @@ if __name__ == "__main__":
         samples_z = None

     if add_pipeline and samples_z is not None:
-        outputs = st.empty()
         st.write("**Running Refinement Stage**")
         samples = apply_refiner(
-            input=samples_z,
-            state=state2,
-            num_samples=samples_z.shape[0],
+            samples_z,
+            state2,
+            sampler2,
+            samples_z.shape[0],
             prompt=prompt,
             negative_prompt=negative_prompt if is_legacy else "",
+            filter=state.get("filter"),
             finish_denoising=finish_denoising,
         )
-        show_samples(samples, outputs)

     if save_locally and samples is not None:
         perform_save_locally(save_path, samples)

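The diff above replaces the enum-driven model selection with the plain VERSION2SPECS dictionary. The H/W/C/f fields determine the latent tensor the sampler works in: do_sample (shown in the streamlit_helpers diff below) builds its noise tensor as (num_samples, C, H // F, W // F). A small self-contained sketch of that relationship; the dict entry is copied from the diff, the surrounding code is illustrative only:

spec = {
    "H": 1024,
    "W": 1024,
    "C": 4,
    "f": 8,
    "is_legacy": False,
    "config": "configs/inference/sd_xl_base.yaml",
    "ckpt": "checkpoints/sd_xl_base_1.0.safetensors",
}

# An SDXL-base sample therefore starts from a 4 x 128 x 128 latent:
shape = (spec["C"], spec["H"] // spec["f"], spec["W"] // spec["f"])
print(shape)  # (4, 128, 128)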
View File

@@ -1,68 +1,166 @@
+import math
 import os
+from typing import List, Union
+
 import numpy as np
 import streamlit as st
 import torch
 from einops import rearrange, repeat
+from imwatermark import WatermarkEncoder
+from omegaconf import ListConfig, OmegaConf
 from PIL import Image
+from safetensors.torch import load_file as load_safetensors
+from torch import autocast
 from torchvision import transforms
-from typing import Optional, Tuple, Dict, Any
+from torchvision.utils import make_grid

 from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
-from sgm.inference.api import (
-    Discretization,
-    Guider,
-    Sampler,
-    SamplingParams,
-    SamplingSpec,
-    SamplingPipeline,
-    Thresholder,
-)
-from sgm.inference.helpers import embed_watermark, CudaModelManager
+from sgm.modules.diffusionmodules.sampling import (
+    DPMPP2MSampler,
+    DPMPP2SAncestralSampler,
+    EulerAncestralSampler,
+    EulerEDMSampler,
+    HeunEDMSampler,
+    LinearMultistepSampler,
+)
+from sgm.util import append_dims, instantiate_from_config


+class WatermarkEmbedder:
+    def __init__(self, watermark):
+        self.watermark = watermark
+        self.num_bits = len(WATERMARK_BITS)
+        self.encoder = WatermarkEncoder()
+        self.encoder.set_watermark("bits", self.watermark)
+
+    def __call__(self, image: torch.Tensor):
+        """
+        Adds a predefined watermark to the input image
+
+        Args:
+            image: ([N,] B, C, H, W) in range [0, 1]
+
+        Returns:
+            same as input but watermarked
+        """
+        # watermarking libary expects input as cv2 BGR format
+        squeeze = len(image.shape) == 4
+        if squeeze:
+            image = image[None, ...]
+        n = image.shape[0]
+        image_np = rearrange(
+            (255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
+        ).numpy()[:, :, :, ::-1]
+        # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
+        for k in range(image_np.shape[0]):
+            image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
+        image = torch.from_numpy(
+            rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)
+        ).to(image.device)
+        image = torch.clamp(image / 255, min=0.0, max=1.0)
+        if squeeze:
+            image = image[0]
+        return image
+
+
+# A fixed 48-bit message that was choosen at random
+# WATERMARK_MESSAGE = 0xB3EC907BB19E
+WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
+# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
+WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
+embed_watemark = WatermarkEmbedder(WATERMARK_BITS)
+
+
 @st.cache_resource()
-def init_st(
-    spec: SamplingSpec,
-    load_ckpt=True,
-    load_filter=True,
-    use_fp16=True,
-    enable_swap=True,
-) -> Dict[str, Any]:
-    state: Dict[str, Any] = dict()
-    config = spec.config
-    ckpt = spec.ckpt
-
-    if enable_swap:
-        pipeline = SamplingPipeline(
-            model_spec=spec,
-            use_fp16=use_fp16,
-            device=CudaModelManager(device="cuda", swap_device="cpu"),
-        )
-    else:
-        pipeline = SamplingPipeline(model_spec=spec, use_fp16=use_fp16)
-
-    state["spec"] = spec
-    state["model"] = pipeline
-    state["ckpt"] = ckpt if load_ckpt else None
-    state["config"] = config
-    state["params"] = spec.default_params
-    if load_filter:
-        state["filter"] = DeepFloydDataFiltering(verbose=False)
-    else:
-        state["filter"] = None
+def init_st(version_dict, load_ckpt=True, load_filter=True):
+    state = dict()
+    if not "model" in state:
+        config = version_dict["config"]
+        ckpt = version_dict["ckpt"]
+
+        config = OmegaConf.load(config)
+        model, msg = load_model_from_config(config, ckpt if load_ckpt else None)
+
+        state["msg"] = msg
+        state["model"] = model
+        state["ckpt"] = ckpt if load_ckpt else None
+        state["config"] = config
+        if load_filter:
+            state["filter"] = DeepFloydDataFiltering(verbose=False)
     return state


+def load_model(model):
+    model.cuda()
+
+
+lowvram_mode = False
+
+
+def set_lowvram_mode(mode):
+    global lowvram_mode
+    lowvram_mode = mode
+
+
+def initial_model_load(model):
+    global lowvram_mode
+    if lowvram_mode:
+        model.model.half()
+    else:
+        model.cuda()
+    return model
+
+
+def unload_model(model):
+    global lowvram_mode
+    if lowvram_mode:
+        model.cpu()
+        torch.cuda.empty_cache()
+
+
+def load_model_from_config(config, ckpt=None, verbose=True):
+    model = instantiate_from_config(config.model)
+
+    if ckpt is not None:
+        print(f"Loading model from {ckpt}")
+        if ckpt.endswith("ckpt"):
+            pl_sd = torch.load(ckpt, map_location="cpu")
+            if "global_step" in pl_sd:
+                global_step = pl_sd["global_step"]
+                st.info(f"loaded ckpt from global step {global_step}")
+                print(f"Global Step: {pl_sd['global_step']}")
+            sd = pl_sd["state_dict"]
+        elif ckpt.endswith("safetensors"):
+            sd = load_safetensors(ckpt)
+        else:
+            raise NotImplementedError
+
+        msg = None
+
+        m, u = model.load_state_dict(sd, strict=False)
+
+        if len(m) > 0 and verbose:
+            print("missing keys:")
+            print(m)
+        if len(u) > 0 and verbose:
+            print("unexpected keys:")
+            print(u)
+    else:
+        msg = None
+
+    model = initial_model_load(model)
+    model.eval()
+    return model, msg
+
+
 def get_unique_embedder_keys_from_conditioner(conditioner):
     return list(set([x.input_key for x in conditioner.embedders]))


-def init_embedder_options(
-    keys, params: SamplingParams, prompt=None, negative_prompt=None
-) -> SamplingParams:
+def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
+    # Hardcoded demo settings; might undergo some changes in the future
+
+    value_dict = {}
     for key in keys:
         if key == "txt":
             if prompt is None:

@@ -72,38 +170,46 @@ def init_embedder_options(
             if negative_prompt is None:
                 negative_prompt = st.text_input("Negative prompt", "")

+            value_dict["prompt"] = prompt
+            value_dict["negative_prompt"] = negative_prompt
+
         if key == "original_size_as_tuple":
             orig_width = st.number_input(
                 "orig_width",
-                value=params.orig_width,
+                value=init_dict["orig_width"],
                 min_value=16,
             )
             orig_height = st.number_input(
                 "orig_height",
-                value=params.orig_height,
+                value=init_dict["orig_height"],
                 min_value=16,
             )

-            params.orig_width = int(orig_width)
-            params.orig_height = int(orig_height)
+            value_dict["orig_width"] = orig_width
+            value_dict["orig_height"] = orig_height

         if key == "crop_coords_top_left":
-            crop_coord_top = st.number_input(
-                "crop_coords_top", value=params.crop_coords_top, min_value=0
-            )
-            crop_coord_left = st.number_input(
-                "crop_coords_left", value=params.crop_coords_left, min_value=0
-            )
+            crop_coord_top = st.number_input("crop_coords_top", value=0, min_value=0)
+            crop_coord_left = st.number_input("crop_coords_left", value=0, min_value=0)

-            params.crop_coords_top = int(crop_coord_top)
-            params.crop_coords_left = int(crop_coord_left)
+            value_dict["crop_coords_top"] = crop_coord_top
+            value_dict["crop_coords_left"] = crop_coord_left

-    return params
+        if key == "aesthetic_score":
+            value_dict["aesthetic_score"] = 6.0
+            value_dict["negative_aesthetic_score"] = 2.5
+
+        if key == "target_size_as_tuple":
+            value_dict["target_width"] = init_dict["target_width"]
+            value_dict["target_height"] = init_dict["target_height"]
+
+    return value_dict


 def perform_save_locally(save_path, samples):
     os.makedirs(os.path.join(save_path), exist_ok=True)
     base_count = len(os.listdir(os.path.join(save_path)))
-    samples = embed_watermark(samples)
+    samples = embed_watemark(samples)
     for sample in samples:
         sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
         Image.fromarray(sample.astype(np.uint8)).save(

@@ -122,26 +228,78 @@ def init_save_locally(_dir, init_value: bool = False):
     return save_locally, save_path


-def show_samples(samples, outputs):
-    if isinstance(samples, tuple):
-        samples, _ = samples
-    grid = embed_watermark(torch.stack([samples]))
-    grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
-    outputs.image(grid.cpu().numpy())
+class Img2ImgDiscretizationWrapper:
+    """
+    wraps a discretizer, and prunes the sigmas
+    params:
+        strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
+    """
+
+    def __init__(self, discretization, strength: float = 1.0):
+        self.discretization = discretization
+        self.strength = strength
+        assert 0.0 <= self.strength <= 1.0
+
+    def __call__(self, *args, **kwargs):
+        # sigmas start large first, and decrease then
+        sigmas = self.discretization(*args, **kwargs)
+        print(f"sigmas after discretization, before pruning img2img: ", sigmas)
+        sigmas = torch.flip(sigmas, (0,))
+        sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
+        print("prune index:", max(int(self.strength * len(sigmas)), 1))
+        sigmas = torch.flip(sigmas, (0,))
+        print(f"sigmas after pruning: ", sigmas)
+        return sigmas


-def get_guider(params: SamplingParams, key=1) -> SamplingParams:
-    params.guider = Guider(
-        st.sidebar.selectbox(
-            f"Discretization #{key}", [member.value for member in Guider]
-        )
-    )
+class Txt2NoisyDiscretizationWrapper:
+    """
+    wraps a discretizer, and prunes the sigmas
+    params:
+        strength: float between 0.0 and 1.0. 0.0 means full sampling (all sigmas are returned)
+    """
+
+    def __init__(self, discretization, strength: float = 0.0, original_steps=None):
+        self.discretization = discretization
+        self.strength = strength
+        self.original_steps = original_steps
+        assert 0.0 <= self.strength <= 1.0
+
+    def __call__(self, *args, **kwargs):
+        # sigmas start large first, and decrease then
+        sigmas = self.discretization(*args, **kwargs)
+        print(f"sigmas after discretization, before pruning img2img: ", sigmas)
+        sigmas = torch.flip(sigmas, (0,))
+        if self.original_steps is None:
+            steps = len(sigmas)
+        else:
+            steps = self.original_steps + 1
+        prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0)
+        sigmas = sigmas[prune_index:]
+        print("prune index:", prune_index)
+        sigmas = torch.flip(sigmas, (0,))
+        print(f"sigmas after pruning: ", sigmas)
+        return sigmas
+
+
+def get_guider(key):
+    guider = st.sidebar.selectbox(
+        f"Discretization #{key}",
+        [
+            "VanillaCFG",
+            "IdentityGuider",
+        ],
+    )

-    if params.guider == Guider.VANILLA:
+    if guider == "IdentityGuider":
+        guider_config = {
+            "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
+        }
+    elif guider == "VanillaCFG":
         scale = st.number_input(
-            f"cfg-scale #{key}", value=params.scale, min_value=0.0, max_value=100.0
+            f"cfg-scale #{key}", value=5.0, min_value=0.0, max_value=100.0
         )
-        params.scale = scale

         thresholder = st.sidebar.selectbox(
             f"Thresholder #{key}",
             [

@@ -150,97 +308,182 @@ def get_guider(params: SamplingParams, key=1) -> SamplingParams:
         )

         if thresholder == "None":
-            params.thresholder = Thresholder.NONE
+            dyn_thresh_config = {
+                "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
+            }
         else:
             raise NotImplementedError

-    return params
+        guider_config = {
+            "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
+            "params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
+        }
+    else:
+        raise NotImplementedError
+    return guider_config


 def init_sampling(
-    params: SamplingParams,
     key=1,
+    img2img_strength=1.0,
     specify_num_samples=True,
-) -> Tuple[SamplingParams, int, int]:
+    stage2strength=None,
+):
     num_rows, num_cols = 1, 1
     if specify_num_samples:
         num_cols = st.number_input(
             f"num cols #{key}", value=2, min_value=1, max_value=10
         )

-    params.steps = int(
-        st.sidebar.number_input(
-            f"steps #{key}", value=params.steps, min_value=1, max_value=1000
-        )
-    )
-    params.sampler = Sampler(
-        st.sidebar.selectbox(
-            f"Sampler #{key}",
-            [member.value for member in Sampler],
-            0,
-        )
-    )
-    params.discretization = Discretization(
-        st.sidebar.selectbox(
-            f"Discretization #{key}",
-            [member.value for member in Discretization],
-        )
-    )
-
-    params = get_discretization(params=params, key=key)
-    params = get_guider(params=params, key=key)
-    params = get_sampler(params=params, key=key)
-    return params, num_rows, num_cols
-
-
-def get_discretization(params: SamplingParams, key=1) -> SamplingParams:
-    if params.discretization == Discretization.EDM:
-        params.sigma_min = st.number_input(f"sigma_min #{key}", value=params.sigma_min)
-        params.sigma_max = st.number_input(f"sigma_max #{key}", value=params.sigma_max)
-        params.rho = st.number_input(f"rho #{key}", value=params.rho)
-    return params
-
-
-def get_sampler(params: SamplingParams, key=1) -> SamplingParams:
-    if params.sampler in (Sampler.EULER_EDM, Sampler.HEUN_EDM):
-        params.s_churn = st.sidebar.number_input(
-            f"s_churn #{key}", value=params.s_churn, min_value=0.0
-        )
-        params.s_tmin = st.sidebar.number_input(
-            f"s_tmin #{key}", value=params.s_tmin, min_value=0.0
-        )
-        params.s_tmax = st.sidebar.number_input(
-            f"s_tmax #{key}", value=params.s_tmax, min_value=0.0
-        )
-        params.s_noise = st.sidebar.number_input(
-            f"s_noise #{key}", value=params.s_noise, min_value=0.0
-        )
-    elif params.sampler in (Sampler.EULER_ANCESTRAL, Sampler.DPMPP2S_ANCESTRAL):
-        params.s_noise = st.sidebar.number_input(
-            "s_noise", value=params.s_noise, min_value=0.0
-        )
-        params.eta = st.sidebar.number_input("eta", value=params.eta, min_value=0.0)
-    elif params.sampler == Sampler.LINEAR_MULTISTEP:
-        params.order = int(
-            st.sidebar.number_input("order", value=params.order, min_value=1)
-        )
-    return params
-
-
-def get_interactive_image(key=None) -> Optional[Image.Image]:
+    steps = st.sidebar.number_input(
+        f"steps #{key}", value=40, min_value=1, max_value=1000
+    )
+    sampler = st.sidebar.selectbox(
+        f"Sampler #{key}",
+        [
+            "EulerEDMSampler",
+            "HeunEDMSampler",
+            "EulerAncestralSampler",
+            "DPMPP2SAncestralSampler",
+            "DPMPP2MSampler",
+            "LinearMultistepSampler",
+        ],
+        0,
+    )
+    discretization = st.sidebar.selectbox(
+        f"Discretization #{key}",
+        [
+            "LegacyDDPMDiscretization",
+            "EDMDiscretization",
+        ],
+    )
+
+    discretization_config = get_discretization(discretization, key=key)
+
+    guider_config = get_guider(key=key)
+
+    sampler = get_sampler(sampler, steps, discretization_config, guider_config, key=key)
+    if img2img_strength < 1.0:
+        st.warning(
+            f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper"
+        )
+        sampler.discretization = Img2ImgDiscretizationWrapper(
+            sampler.discretization, strength=img2img_strength
+        )
+    if stage2strength is not None:
+        sampler.discretization = Txt2NoisyDiscretizationWrapper(
+            sampler.discretization, strength=stage2strength, original_steps=steps
+        )
+    return sampler, num_rows, num_cols
+
+
+def get_discretization(discretization, key=1):
+    if discretization == "LegacyDDPMDiscretization":
+        discretization_config = {
+            "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
+        }
+    elif discretization == "EDMDiscretization":
+        sigma_min = st.number_input(f"sigma_min #{key}", value=0.03)  # 0.0292
+        sigma_max = st.number_input(f"sigma_max #{key}", value=14.61)  # 14.6146
+        rho = st.number_input(f"rho #{key}", value=3.0)
+        discretization_config = {
+            "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
+            "params": {
+                "sigma_min": sigma_min,
+                "sigma_max": sigma_max,
+                "rho": rho,
+            },
+        }
+    return discretization_config
+
+
+def get_sampler(sampler_name, steps, discretization_config, guider_config, key=1):
+    if sampler_name == "EulerEDMSampler" or sampler_name == "HeunEDMSampler":
+        s_churn = st.sidebar.number_input(f"s_churn #{key}", value=0.0, min_value=0.0)
+        s_tmin = st.sidebar.number_input(f"s_tmin #{key}", value=0.0, min_value=0.0)
+        s_tmax = st.sidebar.number_input(f"s_tmax #{key}", value=999.0, min_value=0.0)
+        s_noise = st.sidebar.number_input(f"s_noise #{key}", value=1.0, min_value=0.0)
+
+        if sampler_name == "EulerEDMSampler":
+            sampler = EulerEDMSampler(
+                num_steps=steps,
+                discretization_config=discretization_config,
+                guider_config=guider_config,
+                s_churn=s_churn,
+                s_tmin=s_tmin,
+                s_tmax=s_tmax,
+                s_noise=s_noise,
+                verbose=True,
+            )
+        elif sampler_name == "HeunEDMSampler":
+            sampler = HeunEDMSampler(
+                num_steps=steps,
+                discretization_config=discretization_config,
+                guider_config=guider_config,
+                s_churn=s_churn,
+                s_tmin=s_tmin,
+                s_tmax=s_tmax,
+                s_noise=s_noise,
+                verbose=True,
+            )
+    elif (
+        sampler_name == "EulerAncestralSampler"
+        or sampler_name == "DPMPP2SAncestralSampler"
+    ):
+        s_noise = st.sidebar.number_input("s_noise", value=1.0, min_value=0.0)
+        eta = st.sidebar.number_input("eta", value=1.0, min_value=0.0)
+
+        if sampler_name == "EulerAncestralSampler":
+            sampler = EulerAncestralSampler(
+                num_steps=steps,
+                discretization_config=discretization_config,
+                guider_config=guider_config,
+                eta=eta,
+                s_noise=s_noise,
+                verbose=True,
+            )
+        elif sampler_name == "DPMPP2SAncestralSampler":
+            sampler = DPMPP2SAncestralSampler(
+                num_steps=steps,
+                discretization_config=discretization_config,
+                guider_config=guider_config,
+                eta=eta,
+                s_noise=s_noise,
+                verbose=True,
+            )
+    elif sampler_name == "DPMPP2MSampler":
+        sampler = DPMPP2MSampler(
+            num_steps=steps,
+            discretization_config=discretization_config,
+            guider_config=guider_config,
+            verbose=True,
+        )
+    elif sampler_name == "LinearMultistepSampler":
+        order = st.sidebar.number_input("order", value=4, min_value=1)
+        sampler = LinearMultistepSampler(
+            num_steps=steps,
+            discretization_config=discretization_config,
+            guider_config=guider_config,
+            order=order,
+            verbose=True,
+        )
+    else:
+        raise ValueError(f"unknown sampler {sampler_name}!")
+
+    return sampler
+
+
+def get_interactive_image(key=None) -> Image.Image:
     image = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
     if image is not None:
         image = Image.open(image)
         if not image.mode == "RGB":
             image = image.convert("RGB")
         return image
-    return None


-def load_img(display=True, key=None) -> Optional[torch.Tensor]:
+def load_img(display=True, key=None):
     image = get_interactive_image(key=key)
     if image is None:
         return None

@@ -264,3 +507,238 @@ def get_init_img(batch_size=1, key=None):
     init_image = load_img(key=key).cuda()
     init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
     return init_image
+
+
+def do_sample(
+    model,
+    sampler,
+    value_dict,
+    num_samples,
+    H,
+    W,
+    C,
+    F,
+    force_uc_zero_embeddings: List = None,
+    batch2model_input: List = None,
+    return_latents=False,
+    filter=None,
+):
+    if force_uc_zero_embeddings is None:
+        force_uc_zero_embeddings = []
+    if batch2model_input is None:
+        batch2model_input = []
+
+    st.text("Sampling")
+
+    outputs = st.empty()
+    precision_scope = autocast
+    with torch.no_grad():
+        with precision_scope("cuda"):
+            with model.ema_scope():
+                num_samples = [num_samples]
+                load_model(model.conditioner)
+                batch, batch_uc = get_batch(
+                    get_unique_embedder_keys_from_conditioner(model.conditioner),
+                    value_dict,
+                    num_samples,
+                )
+                for key in batch:
+                    if isinstance(batch[key], torch.Tensor):
+                        print(key, batch[key].shape)
+                    elif isinstance(batch[key], list):
+                        print(key, [len(l) for l in batch[key]])
+                    else:
+                        print(key, batch[key])
+                c, uc = model.conditioner.get_unconditional_conditioning(
+                    batch,
+                    batch_uc=batch_uc,
+                    force_uc_zero_embeddings=force_uc_zero_embeddings,
+                )
+                unload_model(model.conditioner)
+
+                for k in c:
+                    if not k == "crossattn":
+                        c[k], uc[k] = map(
+                            lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc)
+                        )
+
+                additional_model_inputs = {}
+                for k in batch2model_input:
+                    additional_model_inputs[k] = batch[k]
+
+                shape = (math.prod(num_samples), C, H // F, W // F)
+                randn = torch.randn(shape).to("cuda")
+
+                def denoiser(input, sigma, c):
+                    return model.denoiser(
+                        model.model, input, sigma, c, **additional_model_inputs
+                    )
+
+                load_model(model.denoiser)
+                load_model(model.model)
+                samples_z = sampler(denoiser, randn, cond=c, uc=uc)
+                unload_model(model.model)
+                unload_model(model.denoiser)
+
+                load_model(model.first_stage_model)
+                samples_x = model.decode_first_stage(samples_z)
+                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
+                unload_model(model.first_stage_model)
+
+                if filter is not None:
+                    samples = filter(samples)
+
+                grid = torch.stack([samples])
+                grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
+                outputs.image(grid.cpu().numpy())
+
+                if return_latents:
+                    return samples, samples_z
+                return samples
+
+
+def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
+    # Hardcoded demo setups; might undergo some changes in the future
+    batch = {}
+    batch_uc = {}
+
+    for key in keys:
+        if key == "txt":
+            batch["txt"] = (
+                np.repeat([value_dict["prompt"]], repeats=math.prod(N))
+                .reshape(N)
+                .tolist()
+            )
+            batch_uc["txt"] = (
+                np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
+                .reshape(N)
+                .tolist()
+            )
+        elif key == "original_size_as_tuple":
+            batch["original_size_as_tuple"] = (
+                torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
+                .to(device)
+                .repeat(*N, 1)
+            )
+        elif key == "crop_coords_top_left":
+            batch["crop_coords_top_left"] = (
+                torch.tensor(
+                    [value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
+                )
+                .to(device)
+                .repeat(*N, 1)
+            )
+        elif key == "aesthetic_score":
+            batch["aesthetic_score"] = (
+                torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
+            )
+            batch_uc["aesthetic_score"] = (
+                torch.tensor([value_dict["negative_aesthetic_score"]])
+                .to(device)
+                .repeat(*N, 1)
+            )
+        elif key == "target_size_as_tuple":
+            batch["target_size_as_tuple"] = (
+                torch.tensor([value_dict["target_height"], value_dict["target_width"]])
+                .to(device)
+                .repeat(*N, 1)
+            )
+        else:
+            batch[key] = value_dict[key]
+
+    for key in batch.keys():
+        if key not in batch_uc and isinstance(batch[key], torch.Tensor):
+            batch_uc[key] = torch.clone(batch[key])
+    return batch, batch_uc
+
+
+@torch.no_grad()
+def do_img2img(
+    img,
+    model,
+    sampler,
+    value_dict,
+    num_samples,
+    force_uc_zero_embeddings=[],
+    additional_kwargs={},
+    offset_noise_level: int = 0.0,
+    return_latents=False,
+    skip_encode=False,
+    filter=None,
+    add_noise=True,
+):
+    st.text("Sampling")
+
+    outputs = st.empty()
+    precision_scope = autocast
+    with torch.no_grad():
+        with precision_scope("cuda"):
+            with model.ema_scope():
+                load_model(model.conditioner)
+                batch, batch_uc = get_batch(
+                    get_unique_embedder_keys_from_conditioner(model.conditioner),
+                    value_dict,
+                    [num_samples],
+                )
+                c, uc = model.conditioner.get_unconditional_conditioning(
+                    batch,
+                    batch_uc=batch_uc,
+                    force_uc_zero_embeddings=force_uc_zero_embeddings,
+                )
+                unload_model(model.conditioner)
+                for k in c:
+                    c[k], uc[k] = map(lambda y: y[k][:num_samples].to("cuda"), (c, uc))
+
+                for k in additional_kwargs:
+                    c[k] = uc[k] = additional_kwargs[k]
+                if skip_encode:
+                    z = img
+                else:
+                    load_model(model.first_stage_model)
+                    z = model.encode_first_stage(img)
+                    unload_model(model.first_stage_model)
+
+                noise = torch.randn_like(z)
+                sigmas = sampler.discretization(sampler.num_steps).cuda()
+                sigma = sigmas[0]
+                st.info(f"all sigmas: {sigmas}")
+                st.info(f"noising sigma: {sigma}")
+
+                if offset_noise_level > 0.0:
+                    noise = noise + offset_noise_level * append_dims(
+                        torch.randn(z.shape[0], device=z.device), z.ndim
+                    )
+                if add_noise:
+                    noised_z = z + noise * append_dims(sigma, z.ndim).cuda()
+                    noised_z = noised_z / torch.sqrt(
+                        1.0 + sigmas[0] ** 2.0
+                    )  # Note: hardcoded to DDPM-like scaling. need to generalize later.
+                else:
+                    noised_z = z / torch.sqrt(1.0 + sigmas[0] ** 2.0)
+
+                def denoiser(x, sigma, c):
+                    return model.denoiser(model.model, x, sigma, c)
+
+                load_model(model.denoiser)
+                load_model(model.model)
+                samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
+                unload_model(model.model)
+                unload_model(model.denoiser)
+
+                load_model(model.first_stage_model)
+                samples_x = model.decode_first_stage(samples_z)
+                unload_model(model.first_stage_model)
+                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
+
+                if filter is not None:
+                    samples = filter(samples)
+
+                grid = embed_watemark(torch.stack([samples]))
+                grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
+                outputs.image(grid.cpu().numpy())
+                if return_latents:
+                    return samples, samples_z
+                return samples

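The two discretization wrappers added above implement the demo's img2img start point and the base-to-refiner handoff by pruning the sigma schedule. A self-contained sketch of the Img2ImgDiscretizationWrapper behavior, using a toy linear schedule in place of a real discretization (everything except the wrapper logic itself is illustrative):

import torch

class Img2ImgDiscretizationWrapper:
    # same pruning logic as in the diff above, minus the debug prints
    def __init__(self, discretization, strength: float = 1.0):
        self.discretization = discretization
        self.strength = strength
        assert 0.0 <= self.strength <= 1.0

    def __call__(self, *args, **kwargs):
        sigmas = self.discretization(*args, **kwargs)  # descending
        sigmas = torch.flip(sigmas, (0,))              # ascending
        sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
        return torch.flip(sigmas, (0,))                # descending again

def toy_discretization(n):
    return torch.linspace(14.6, 0.03, n)  # stand-in for a real schedule

wrapped = Img2ImgDiscretizationWrapper(toy_discretization, strength=0.5)
print(len(toy_discretization(10)))  # 10 sigmas
print(len(wrapped(10)))             # 5 sigmas: only the low-noise tail remains,
                                    # so img2img starts partway into denoising

Txt2NoisyDiscretizationWrapper prunes from the other end for the opposite reason: when the refiner is set to finish denoising, the base model should stop early and leave residual noise for the second stage.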
View File

@@ -1,4 +1,4 @@
 from .models import AutoencodingEngine, DiffusionEngine
 from .util import get_configs_path, instantiate_from_config

-__version__ = "0.1.1"
+__version__ = "0.1.0"

View File

@@ -1,14 +1,11 @@
 from dataclasses import dataclass, asdict
 from enum import Enum
 from omegaconf import OmegaConf
-import os
+import pathlib
 from sgm.inference.helpers import (
     do_sample,
     do_img2img,
-    DeviceModelManager,
-    get_model_manager,
     Img2ImgDiscretizationWrapper,
-    Txt2NoisyDiscretizationWrapper,
 )
 from sgm.modules.diffusionmodules.sampling import (
     EulerEDMSampler,

@@ -18,18 +15,17 @@ from sgm.modules.diffusionmodules.sampling import (
     DPMPP2MSampler,
     LinearMultistepSampler,
 )
-from sgm.util import load_model_from_config, get_configs_path, get_checkpoints_path
-import torch
-from typing import Optional, Dict, Any, Union
+from sgm.util import load_model_from_config
+from typing import Optional


 class ModelArchitecture(str, Enum):
-    SDXL_V1_0_BASE = "stable-diffusion-xl-v1-base"
-    SDXL_V1_0_REFINER = "stable-diffusion-xl-v1-refiner"
-    SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base"
-    SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner"
     SD_2_1 = "stable-diffusion-v2-1"
     SD_2_1_768 = "stable-diffusion-v2-1-768"
+    SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base"
+    SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner"
+    SDXL_V1_BASE = "stable-diffusion-xl-v1-base"
+    SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner"


 class Sampler(str, Enum):

@@ -57,20 +53,16 @@ class Thresholder(str, Enum):
 @dataclass
 class SamplingParams:
-    """
-    Parameters for sampling.
-    """
-
-    width: Optional[int] = None
-    height: Optional[int] = None
-    steps: Optional[int] = None
-    sampler: Sampler = Sampler.EULER_EDM
+    width: int = 1024
+    height: int = 1024
+    steps: int = 50
+    sampler: Sampler = Sampler.DPMPP2M
     discretization: Discretization = Discretization.LEGACY_DDPM
     guider: Guider = Guider.VANILLA
     thresholder: Thresholder = Thresholder.NONE
-    scale: float = 5.0
-    aesthetic_score: float = 6.0
-    negative_aesthetic_score: float = 2.5
+    scale: float = 6.0
+    aesthetic_score: float = 5.0
+    negative_aesthetic_score: float = 5.0
     img2img_strength: float = 1.0
     orig_width: int = 1024
     orig_height: int = 1024

@@ -97,10 +89,8 @@ class SamplingSpec:
     config: str
     ckpt: str
     is_guided: bool
-    default_params: SamplingParams


-# The defaults here are derived from user preference testing.
 model_specs = {
     ModelArchitecture.SD_2_1: SamplingSpec(
         height=512,

@@ -111,12 +101,6 @@ model_specs = {
         config="sd_2_1.yaml",
         ckpt="v2-1_512-ema-pruned.safetensors",
         is_guided=True,
-        default_params=SamplingParams(
-            width=512,
-            height=512,
-            steps=40,
-            scale=7.0,
-        ),
     ),
     ModelArchitecture.SD_2_1_768: SamplingSpec(
         height=768,

@@ -127,12 +111,6 @@ model_specs = {
         config="sd_2_1_768.yaml",
         ckpt="v2-1_768-ema-pruned.safetensors",
         is_guided=True,
-        default_params=SamplingParams(
-            width=768,
-            height=768,
-            steps=40,
-            scale=7.0,
-        ),
     ),
     ModelArchitecture.SDXL_V0_9_BASE: SamplingSpec(
         height=1024,

@@ -143,7 +121,6 @@ model_specs = {
         config="sd_xl_base.yaml",
         ckpt="sd_xl_base_0.9.safetensors",
         is_guided=True,
-        default_params=SamplingParams(width=1024, height=1024, steps=40, scale=5.0),
     ),
     ModelArchitecture.SDXL_V0_9_REFINER: SamplingSpec(
         height=1024,

@@ -154,11 +131,8 @@ model_specs = {
         config="sd_xl_refiner.yaml",
         ckpt="sd_xl_refiner_0.9.safetensors",
         is_guided=True,
-        default_params=SamplingParams(
-            width=1024, height=1024, steps=40, scale=5.0, img2img_strength=0.15
-        ),
     ),
-    ModelArchitecture.SDXL_V1_0_BASE: SamplingSpec(
+    ModelArchitecture.SDXL_V1_BASE: SamplingSpec(
         height=1024,
         width=1024,
         channels=4,

@@ -167,9 +141,8 @@ model_specs = {
         config="sd_xl_base.yaml",
         ckpt="sd_xl_base_1.0.safetensors",
         is_guided=True,
-        default_params=SamplingParams(width=1024, height=1024, steps=40, scale=5.0),
     ),
-    ModelArchitecture.SDXL_V1_0_REFINER: SamplingSpec(
+    ModelArchitecture.SDXL_V1_REFINER: SamplingSpec(
         height=1024,
         width=1024,
         channels=4,

@@ -178,97 +151,34 @@ model_specs = {
         config="sd_xl_refiner.yaml",
         ckpt="sd_xl_refiner_1.0.safetensors",
         is_guided=True,
-        default_params=SamplingParams(
-            width=1024, height=1024, steps=40, scale=5.0, img2img_strength=0.15
-        ),
     ),
 }


-def wrap_discretization(
-    discretization, image_strength=None, noise_strength=None, steps=None
-):
-    if isinstance(discretization, Img2ImgDiscretizationWrapper) or isinstance(
-        discretization, Txt2NoisyDiscretizationWrapper
-    ):
-        return discretization  # Already wrapped
-    if image_strength is not None and image_strength < 1.0 and image_strength > 0.0:
-        discretization = Img2ImgDiscretizationWrapper(
-            discretization, strength=image_strength
-        )
-
-    if (
-        noise_strength is not None
-        and noise_strength < 1.0
-        and noise_strength > 0.0
-        and steps is not None
-    ):
-        discretization = Txt2NoisyDiscretizationWrapper(
-            discretization,
-            strength=noise_strength,
-            original_steps=steps,
-        )
-    return discretization
-
-
 class SamplingPipeline:
     def __init__(
         self,
-        model_id: Optional[ModelArchitecture] = None,
-        model_spec: Optional[SamplingSpec] = None,
-        model_path: Optional[str] = None,
-        config_path: Optional[str] = None,
-        use_fp16: bool = True,
-        device: Optional[Union[DeviceModelManager, str, torch.device]] = None,
+        model_id: ModelArchitecture,
+        model_path="checkpoints",
+        config_path="configs/inference",
+        device="cuda",
+        use_fp16=True,
     ) -> None:
-        """
-        Sampling pipeline for generating images from a model.
-
-        @param model_id: Model architecture to use. If not specified, model_spec must be specified.
-        @param model_spec: Model specification to use. If not specified, model_id must be specified.
-        @param model_path: Path to model checkpoints folder.
-        @param config_path: Path to model config folder.
-        @param use_fp16: Whether to use fp16 for sampling.
-        @param device: Device manager to use with this pipeline. If a string or torch.device is passed, a device manager will be created based on device type if possible.
-        """
+        if model_id not in model_specs:
+            raise ValueError(f"Model {model_id} not supported")
         self.model_id = model_id
-        if model_spec is not None:
-            self.specs = model_spec
-        elif model_id is not None:
-            if model_id not in model_specs:
-                raise ValueError(f"Model {model_id} not supported")
-            self.specs = model_specs[model_id]
-        else:
-            raise ValueError("Either model_id or model_spec should be provided")
-
-        if model_path is None:
-            model_path = get_checkpoints_path()
-        if config_path is None:
-            config_path = get_configs_path()
-        self.config = os.path.join(config_path, "inference", self.specs.config)
-        self.ckpt = os.path.join(model_path, self.specs.ckpt)
-        if not os.path.exists(self.config):
-            raise ValueError(
-                f"Config {self.config} not found, check model spec or config_path"
-            )
-        if not os.path.exists(self.ckpt):
-            raise ValueError(
-                f"Checkpoint {self.ckpt} not found, check model spec or config_path"
-            )
-        self.device_manager = get_model_manager(device)
-        self.model = self._load_model(
-            device_manager=self.device_manager, use_fp16=use_fp16
-        )
+        self.specs = model_specs[self.model_id]
+        self.config = str(pathlib.Path(config_path, self.specs.config))
+        self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt))
+        self.device = device
+        self.model = self._load_model(device=device, use_fp16=use_fp16)

-    def _load_model(self, device_manager: DeviceModelManager, use_fp16=True):
+    def _load_model(self, device="cuda", use_fp16=True):
         config = OmegaConf.load(self.config)
         model = load_model_from_config(config, self.ckpt)
         if model is None:
             raise ValueError(f"Model {self.model_id} could not be loaded")
-        device_manager.load(model)
+        model.to(device)
         if use_fp16:
             model.conditioner.half()
             model.model.half()

@@ -276,34 +186,13 @@ class SamplingPipeline:
     def text_to_image(
         self,
+        params: SamplingParams,
         prompt: str,
-        params: Optional[SamplingParams] = None,
         negative_prompt: str = "",
         samples: int = 1,
         return_latents: bool = False,
-        noise_strength: Optional[float] = None,
-        filter=None,
     ):
-        if params is None:
-            params = self.specs.default_params
-        else:
-            # Set defaults if optional params are not specified
-            if params.width is None:
-                params.width = self.specs.default_params.width
-            if params.height is None:
-                params.height = self.specs.default_params.height
-            if params.steps is None:
-                params.steps = self.specs.default_params.steps
-
         sampler = get_sampler_config(params)
-
-        sampler.discretization = wrap_discretization(
-            sampler.discretization,
-            image_strength=None,
-            noise_strength=noise_strength,
-            steps=params.steps,
-        )
-
         value_dict = asdict(params)
         value_dict["prompt"] = prompt
         value_dict["negative_prompt"] = negative_prompt

@@ -320,40 +209,31 @@ class SamplingPipeline:
             self.specs.factor,
             force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
             return_latents=return_latents,
-            filter=filter,
-            device=self.device_manager,
+            filter=None,
         )

     def image_to_image(
         self,
-        image: torch.Tensor,
+        params: SamplingParams,
+        image,
         prompt: str,
-        params: Optional[SamplingParams] = None,
         negative_prompt: str = "",
         samples: int = 1,
         return_latents: bool = False,
-        noise_strength: Optional[float] = None,
-        filter=None,
     ):
-        if params is None:
-            params = self.specs.default_params
         sampler = get_sampler_config(params)

-        sampler.discretization = wrap_discretization(
-            sampler.discretization,
-            image_strength=params.img2img_strength,
-            noise_strength=noise_strength,
-            steps=params.steps,
-        )
+        if params.img2img_strength < 1.0:
+            sampler.discretization = Img2ImgDiscretizationWrapper(
+                sampler.discretization,
+                strength=params.img2img_strength,
+            )
         height, width = image.shape[2], image.shape[3]
         value_dict = asdict(params)
         value_dict["prompt"] = prompt
         value_dict["negative_prompt"] = negative_prompt
         value_dict["target_width"] = width
         value_dict["target_height"] = height
+        value_dict["orig_width"] = width
+        value_dict["orig_height"] = height
         return do_img2img(
             image,
             self.model,

@@ -362,24 +242,18 @@ class SamplingPipeline:
             samples,
             force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
             return_latents=return_latents,
-            filter=filter,
-            device=self.device_manager,
+            filter=None,
         )

     def refiner(
         self,
-        image: torch.Tensor,
+        params: SamplingParams,
+        image,
         prompt: str,
-        negative_prompt: str = "",
-        params: Optional[SamplingParams] = None,
+        negative_prompt: Optional[str] = None,
         samples: int = 1,
         return_latents: bool = False,
-        filter: Any = None,
-        add_noise: bool = False,
     ):
-        if params is None:
-            params = self.specs.default_params
         sampler = get_sampler_config(params)
         value_dict = {
             "orig_width": image.shape[3] * 8,

@@ -394,10 +268,6 @@ class SamplingPipeline:
             "negative_aesthetic_score": 2.5,
         }

-        sampler.discretization = wrap_discretization(
-            sampler.discretization, image_strength=params.img2img_strength
-        )
-
         return do_img2img(
             image,
             self.model,

@@ -406,14 +276,11 @@ class SamplingPipeline:
             samples,
             skip_encode=True,
             return_latents=return_latents,
-            filter=filter,
-            add_noise=add_noise,
-            device=self.device_manager,
+            filter=None,
         )


-def get_guider_config(params: SamplingParams) -> Dict[str, Any]:
-    guider_config: Dict[str, Any]
+def get_guider_config(params: SamplingParams):
     if params.guider == Guider.IDENTITY:
         guider_config = {
             "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"

@@ -439,8 +306,7 @@ def get_guider_config(params: SamplingParams) -> Dict[str, Any]:
     return guider_config


-def get_discretization_config(params: SamplingParams) -> Dict[str, Any]:
-    discretization_config: Dict[str, Any]
+def get_discretization_config(params: SamplingParams):
     if params.discretization == Discretization.LEGACY_DDPM:
         discretization_config = {
             "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",

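For reference, the pared-down pipeline on the right-hand side of this diff is driven as follows. This is a hedged sketch, not part of the commit: the prompt is illustrative, and the call assumes sd_xl_base_1.0.safetensors already sits in the default checkpoints/ folder.

from sgm.inference.api import (
    ModelArchitecture,
    SamplingParams,
    SamplingPipeline,
)

# model_path="checkpoints" and config_path="configs/inference" are the new
# defaults from the diff, so only the architecture needs to be named
pipeline = SamplingPipeline(ModelArchitecture.SDXL_V1_BASE)

# SamplingParams now carries concrete defaults (1024x1024, 50 steps, DPM++ 2M,
# scale 6.0), so only overrides need to be spelled out
samples = pipeline.text_to_image(
    params=SamplingParams(steps=30),
    prompt="a professional photograph of an astronaut riding a horse",
)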
View File

@@ -1,4 +1,3 @@
import contextlib
import os import os
from typing import Union, List, Optional from typing import Union, List, Optional
@@ -9,6 +8,7 @@ from PIL import Image
from einops import rearrange
from imwatermark import WatermarkEncoder
from omegaconf import ListConfig
+from torch import autocast

from sgm.util import append_dims
@@ -58,73 +58,6 @@ WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
embed_watermark = WatermarkEmbedder(WATERMARK_BITS)

-class DeviceModelManager(object):
-    """
-    Default model loading class, should work for all device classes.
-    """
-
-    def __init__(
-        self,
-        device: Union[torch.device, str],
-        swap_device: Optional[Union[torch.device, str]] = None,
-    ) -> None:
-        """
-        Args:
-            device (Union[torch.device, str]): The device to use for the model.
-        """
-        self.device = torch.device(device)
-        self.swap_device = (
-            torch.device(swap_device) if swap_device is not None else self.device
-        )
-
-    def load(self, model: torch.nn.Module) -> None:
-        """
-        Loads a model to the (swap) device.
-        """
-        model.to(self.swap_device)
-
-    def autocast(self):
-        """
-        Context manager that enables autocast for the device if supported.
-        """
-        if self.device.type not in ("cuda", "cpu"):
-            return contextlib.nullcontext()
-        return torch.autocast(self.device.type)
-
-    @contextlib.contextmanager
-    def use(self, model: torch.nn.Module):
-        """
-        Context manager that ensures a model is on the correct device during use.
-        The default model loader does not perform any swapping, so the model will
-        stay on device.
-        """
-        try:
-            model.to(self.device)
-            yield
-        finally:
-            if self.device != self.swap_device:
-                model.to(self.swap_device)
-
-class CudaModelManager(DeviceModelManager):
-    """
-    Device manager that loads a model to a CUDA device, optionally swapping to CPU when not in use.
-    """
-
-    @contextlib.contextmanager
-    def use(self, model: Union[torch.nn.Module, torch.Tensor]):
-        """
-        Context manager that ensures a model is on the correct device during use.
-        If a swap device was provided, this will move the model to it after use and clear cache.
-        """
-        model.to(self.device)
-        yield
-        if self.device != self.swap_device:
-            model.to(self.swap_device)
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
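Since this diff deletes the device-manager API, here is a compact sketch of what it did, reconstructed from the removed code above and the deleted test at the end of this diff: weights live on the swap device and visit the GPU only inside use().

import sgm.inference.helpers as helpers
from sgm.inference.api import ModelArchitecture, SamplingPipeline

manager = helpers.CudaModelManager(device="cuda", swap_device="cpu")
pipeline = SamplingPipeline(model_id=ModelArchitecture.SD_2_1, device=manager)
with pipeline.device_manager.use(pipeline.model.model):
    ...  # model is on cuda here
# afterwards it is back on cpu and the CUDA cache has been emptied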
def get_unique_embedder_keys_from_conditioner(conditioner):
    return list({x.input_key for x in conditioner.embedders})
@@ -141,20 +74,6 @@ def perform_save_locally(save_path, samples):
            base_count += 1

-def get_model_manager(
-    device: Optional[Union[DeviceModelManager, str, torch.device]]
-) -> DeviceModelManager:
-    if isinstance(device, DeviceModelManager):
-        return device
-    if device is None:
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-    device = torch.device(device)
-    if device.type == "cuda":
-        return CudaModelManager(device=device)
-    else:
-        return DeviceModelManager(device=device)
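The removed factory normalized every accepted input to a manager; a quick sketch of its dispatch, using only names from the deleted code (existing_manager is a hypothetical DeviceModelManager instance):

get_model_manager(None)                  # cuda if available, else cpu
get_model_manager("cuda:1")              # -> CudaModelManager on that GPU
get_model_manager(torch.device("cpu"))   # -> plain DeviceModelManager
get_model_manager(existing_manager)      # returned unchanged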
class Img2ImgDiscretizationWrapper:
    """
    wraps a discretizer, and prunes the sigmas
@@ -179,36 +98,6 @@ class Img2ImgDiscretizationWrapper:
        return sigmas
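To make the pruning concrete, a minimal sketch with a stand-in linear schedule (not one of sgm's real discretizations): with strength 0.3, only the lowest ~30% of a ten-sigma schedule survives, so denoising starts late and edits stay mild.

import torch

base = lambda n: torch.linspace(10.0, 0.1, n)  # stand-in: largest sigma first
wrapped = Img2ImgDiscretizationWrapper(base, strength=0.3)
print(wrapped(10))  # ~tensor([2.3, 1.2, 0.1]): the three smallest sigmas, largest first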
-class Txt2NoisyDiscretizationWrapper:
-    """
-    wraps a discretizer, and prunes the sigmas
-    params:
-        strength: float between 0.0 and 1.0. 0.0 means full sampling (all sigmas are returned)
-    """
-
-    def __init__(self, discretization, strength: float = 0.0, original_steps=None):
-        self.discretization = discretization
-        self.strength = strength
-        self.original_steps = original_steps
-        assert 0.0 <= self.strength <= 1.0
-
-    def __call__(self, *args, **kwargs):
-        # sigmas start large first, and decrease then
-        sigmas = self.discretization(*args, **kwargs)
-        print(f"sigmas after discretization, before pruning img2img: ", sigmas)
-        sigmas = torch.flip(sigmas, (0,))
-        if self.original_steps is None:
-            steps = len(sigmas)
-        else:
-            steps = self.original_steps + 1
-        prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0)
-        sigmas = sigmas[prune_index:]
-        print("prune index:", prune_index)
-        sigmas = torch.flip(sigmas, (0,))
-        print(f"sigmas after pruning: ", sigmas)
-        return sigmas
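Worked numbers for the removed wrapper, using the strength the deleted tests passed (0.15) and original_steps=40:

steps = 40 + 1                                               # original_steps + 1
prune_index = max(min(int(0.15 * steps) - 1, steps - 1), 0)  # int(6.15) - 1 -> 5
# The 5 smallest sigmas are dropped, so the base pass stops before fully
# denoising and hands a slightly noisy latent to the refiner.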
def do_sample(
    model,
    sampler,
@@ -222,45 +111,39 @@ def do_sample(
    batch2model_input: Optional[List] = None,
    return_latents=False,
    filter=None,
-    device: Optional[Union[DeviceModelManager, str, torch.device]] = None,
+    device="cuda",
):
    if force_uc_zero_embeddings is None:
        force_uc_zero_embeddings = []
    if batch2model_input is None:
        batch2model_input = []
-    device_manager = get_model_manager(device=device)

    with torch.no_grad():
-        with device_manager.autocast():
+        with autocast(device) as precision_scope:
            with model.ema_scope():
                num_samples = [num_samples]
-                with device_manager.use(model.conditioner):
-                    batch, batch_uc = get_batch(
-                        get_unique_embedder_keys_from_conditioner(model.conditioner),
-                        value_dict,
-                        num_samples,
-                    )
-                    for key in batch:
-                        if isinstance(batch[key], torch.Tensor):
-                            print(key, batch[key].shape)
-                        elif isinstance(batch[key], list):
-                            print(key, [len(l) for l in batch[key]])
-                        else:
-                            print(key, batch[key])
-                    c, uc = model.conditioner.get_unconditional_conditioning(
-                        batch,
-                        batch_uc=batch_uc,
-                        force_uc_zero_embeddings=force_uc_zero_embeddings,
-                    )
+                batch, batch_uc = get_batch(
+                    get_unique_embedder_keys_from_conditioner(model.conditioner),
+                    value_dict,
+                    num_samples,
+                )
+                for key in batch:
+                    if isinstance(batch[key], torch.Tensor):
+                        print(key, batch[key].shape)
+                    elif isinstance(batch[key], list):
+                        print(key, [len(l) for l in batch[key]])
+                    else:
+                        print(key, batch[key])
+                c, uc = model.conditioner.get_unconditional_conditioning(
+                    batch,
+                    batch_uc=batch_uc,
+                    force_uc_zero_embeddings=force_uc_zero_embeddings,
+                )
                for k in c:
                    if not k == "crossattn":
                        c[k], uc[k] = map(
-                            lambda y: y[k][: math.prod(num_samples)].to(
-                                device_manager.device
-                            ),
-                            (c, uc),
+                            lambda y: y[k][: math.prod(num_samples)].to(device), (c, uc)
                        )
                additional_model_inputs = {}
@@ -268,20 +151,16 @@ def do_sample(
                        additional_model_inputs[k] = batch[k]

                shape = (math.prod(num_samples), C, H // F, W // F)
-                randn = torch.randn(shape).to(device_manager.device)
+                randn = torch.randn(shape).to(device)
                def denoiser(input, sigma, c):
                    return model.denoiser(
                        model.model, input, sigma, c, **additional_model_inputs
                    )
-                with device_manager.use(model.denoiser):
-                    with device_manager.use(model.model):
-                        samples_z = sampler(denoiser, randn, cond=c, uc=uc)
-                with device_manager.use(model.first_stage_model):
-                    samples_x = model.decode_first_stage(samples_z)
-                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
+                samples_z = sampler(denoiser, randn, cond=c, uc=uc)
+                samples_x = model.decode_first_stage(samples_z)
+                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
                if filter is not None:
                    samples = filter(samples)
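One concrete instance of the latent shape computed above, with SDXL-style values assumed for do_sample's C and F parameters (4 latent channels, 8x downsampling; neither value is restated in this hunk):

import math

num_samples, C, H, W, F = [2], 4, 1024, 1024, 8
shape = (math.prod(num_samples), C, H // F, W // F)
print(shape)  # (2, 4, 128, 128): sampling runs in the 8x-downsampled latent space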
@@ -373,40 +252,32 @@ def do_img2img(
    return_latents=False,
    skip_encode=False,
    filter=None,
-    add_noise=True,
-    device: Optional[Union[DeviceModelManager, str, torch.device]] = None,
+    device="cuda",
):
-    device_manager = get_model_manager(device)
    with torch.no_grad():
-        with device_manager.autocast():
+        with autocast(device) as precision_scope:
            with model.ema_scope():
-                with device_manager.use(model.conditioner):
-                    batch, batch_uc = get_batch(
-                        get_unique_embedder_keys_from_conditioner(model.conditioner),
-                        value_dict,
-                        [num_samples],
-                    )
-                    c, uc = model.conditioner.get_unconditional_conditioning(
-                        batch,
-                        batch_uc=batch_uc,
-                        force_uc_zero_embeddings=force_uc_zero_embeddings,
-                    )
+                batch, batch_uc = get_batch(
+                    get_unique_embedder_keys_from_conditioner(model.conditioner),
+                    value_dict,
+                    [num_samples],
+                )
+                c, uc = model.conditioner.get_unconditional_conditioning(
+                    batch,
+                    batch_uc=batch_uc,
+                    force_uc_zero_embeddings=force_uc_zero_embeddings,
+                )
                for k in c:
-                    c[k], uc[k] = map(
-                        lambda y: y[k][:num_samples].to(device_manager.device), (c, uc)
-                    )
+                    c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc))
                for k in additional_kwargs:
                    c[k] = uc[k] = additional_kwargs[k]
                if skip_encode:
                    z = img
                else:
-                    with device_manager.use(model.first_stage_model):
-                        z = model.encode_first_stage(img)
+                    z = model.encode_first_stage(img)
                noise = torch.randn_like(z)
                sigmas = sampler.discretization(sampler.num_steps)
                sigma = sigmas[0].to(z.device)
@@ -414,24 +285,17 @@ def do_img2img(
                    noise = noise + offset_noise_level * append_dims(
                        torch.randn(z.shape[0], device=z.device), z.ndim
                    )
-                if add_noise:
-                    noised_z = z + noise * append_dims(sigma, z.ndim).cuda()
-                    noised_z = noised_z / torch.sqrt(
-                        1.0 + sigmas[0] ** 2.0
-                    )  # Note: hardcoded to DDPM-like scaling. need to generalize later.
-                else:
-                    noised_z = z / torch.sqrt(1.0 + sigmas[0] ** 2.0)
+                noised_z = z + noise * append_dims(sigma, z.ndim)
+                noised_z = noised_z / torch.sqrt(
+                    1.0 + sigmas[0] ** 2.0
+                )  # Note: hardcoded to DDPM-like scaling. need to generalize later.
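The surviving branch normalizes variance: for roughly unit-variance z and noise, Var(z + sigma * eps) = 1 + sigma**2, so dividing by sqrt(1 + sigma**2) restores unit scale — the "DDPM-like scaling" the comment refers to. A quick numeric check, illustrative only:

import torch

sigma = 14.6  # an illustrative large starting sigma
z, eps = torch.randn(100_000), torch.randn(100_000)
noised = (z + sigma * eps) / torch.sqrt(torch.tensor(1.0 + sigma**2))
print(noised.var())  # ~1.0, matching the model's expected input scale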
                def denoiser(x, sigma, c):
                    return model.denoiser(model.model, x, sigma, c)
-                with device_manager.use(model.denoiser):
-                    with device_manager.use(model.model):
-                        samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
-                with device_manager.use(model.first_stage_model):
-                    samples_x = model.decode_first_stage(samples_z)
-                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
+                samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
+                samples_x = model.decode_first_stage(samples_z)
+                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
                if filter is not None:
                    samples = filter(samples)

View File

@@ -26,7 +26,7 @@ class Discretization:
class EDMDiscretization(Discretization):
-    def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0):
+    def __init__(self, sigma_min=0.02, sigma_max=80.0, rho=7.0):
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.rho = rho
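For context on the sigma_min bump from 0.002 to 0.02: EDM interpolates the schedule in sigma**(1/rho) space (Karras et al., 2022), so this change only lifts the low-noise tail while leaving sigma_max untouched. A standalone sketch of that formula, not sgm's exact implementation:

import torch

def karras_sigmas(n, sigma_min=0.02, sigma_max=80.0, rho=7.0):
    # interpolate between sigma_max and sigma_min in sigma**(1/rho) space
    ramp = torch.linspace(0, 1, n)
    lo, hi = sigma_min ** (1 / rho), sigma_max ** (1 / rho)
    return (hi + ramp * (lo - hi)) ** rho  # largest sigma first

print(karras_sigmas(5))  # ~[80.0, ..., 0.02]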

View File

@@ -230,24 +230,6 @@ def load_model_from_config(config, ckpt, verbose=True, freeze=True):
    return model
-def get_checkpoints_path() -> str:
-    """
-    Get the `checkpoints` directory.
-    This could be in the root of the repository for a working copy,
-    or in the cwd for other use cases.
-    """
-    this_dir = os.path.dirname(__file__)
-    candidates = (
-        os.path.join(this_dir, "checkpoints"),
-        os.path.join(os.getcwd(), "checkpoints"),
-    )
-    for candidate in candidates:
-        candidate = os.path.abspath(candidate)
-        if os.path.isdir(candidate):
-            return candidate
-    raise FileNotFoundError(f"Could not find SGM checkpoints in {candidates}")
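With the helper gone, callers resolve checkpoint locations themselves; a hypothetical one-liner covering the same two candidate directories the removed function searched:

import os

ckpt_dir = next(
    (os.path.abspath(p)
     for p in ("checkpoints", os.path.join(os.getcwd(), "checkpoints"))
     if os.path.isdir(p)),
    None,  # no checkpoints directory found
)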
def get_configs_path() -> str:
    """
    Get the `configs` directory.

View File

@@ -27,7 +27,7 @@ class TestInference:
    @fixture(
        scope="class",
        params=[
-            [ModelArchitecture.SDXL_V1_0_BASE, ModelArchitecture.SDXL_V1_0_REFINER],
+            [ModelArchitecture.SDXL_V1_BASE, ModelArchitecture.SDXL_V1_REFINER],
            [ModelArchitecture.SDXL_V0_9_BASE, ModelArchitecture.SDXL_V0_9_REFINER],
        ],
        ids=["SDXL_V1", "SDXL_V0_9"],
@@ -68,7 +68,9 @@ class TestInference:
        assert output is not None

    @pytest.mark.parametrize("sampler_enum", Sampler)
-    @pytest.mark.parametrize("use_init_image", [True, False], ids=["img2img", "txt2img"])
+    @pytest.mark.parametrize(
+        "use_init_image", [True, False], ids=["img2img", "txt2img"]
+    )
    def test_sdxl_with_refiner(
        self,
        sdxl_pipelines: Tuple[SamplingPipeline, SamplingPipeline],
@@ -79,12 +81,13 @@ class TestInference:
        if use_init_image:
            output = base_pipeline.image_to_image(
                params=SamplingParams(sampler=sampler_enum.value, steps=10),
-                image=self.create_init_image(base_pipeline.specs.height, base_pipeline.specs.width),
+                image=self.create_init_image(
+                    base_pipeline.specs.height, base_pipeline.specs.width
+                ),
                prompt="A professional photograph of an astronaut riding a pig",
                negative_prompt="",
                samples=1,
                return_latents=True,
-                noise_strength=0.15,
            )
        else:
            output = base_pipeline.text_to_image(
@@ -93,7 +96,6 @@ class TestInference:
negative_prompt="", negative_prompt="",
samples=1, samples=1,
return_latents=True, return_latents=True,
noise_strength=0.15,
) )
assert isinstance(output, (tuple, list)) assert isinstance(output, (tuple, list))
@@ -101,9 +103,9 @@ class TestInference:
        assert samples is not None
        assert samples_z is not None
        refiner_pipeline.refiner(
+            params=SamplingParams(sampler=sampler_enum.value, steps=10),
            image=samples_z,
            prompt="A professional photograph of an astronaut riding a pig",
-            params=SamplingParams(sampler=sampler_enum.value, steps=40, img2img_strength=0.15),
            negative_prompt="",
            samples=1,
        )

View File

@@ -1,44 +0,0 @@
-import pytest
-import torch
-
-from sgm.inference.api import (
-    SamplingPipeline,
-    ModelArchitecture,
-)
-import sgm.inference.helpers as helpers
-
-
-def get_torch_device(model: torch.nn.Module) -> torch.device:
-    param = next(model.parameters(), None)
-    if param is not None:
-        return param.device
-    else:
-        buf = next(model.buffers(), None)
-        if buf is not None:
-            return buf.device
-        else:
-            raise TypeError("Could not determine device of input model")
-
-
-@pytest.mark.inference
-def test_default_loading():
-    pipeline = SamplingPipeline(model_id=ModelArchitecture.SD_2_1)
-    assert get_torch_device(pipeline.model.model).type == "cuda"
-    assert get_torch_device(pipeline.model.conditioner).type == "cuda"
-    with pipeline.device_manager.use(pipeline.model.model):
-        assert get_torch_device(pipeline.model.model).type == "cuda"
-    assert get_torch_device(pipeline.model.model).type == "cuda"
-    with pipeline.device_manager.use(pipeline.model.conditioner):
-        assert get_torch_device(pipeline.model.conditioner).type == "cuda"
-    assert get_torch_device(pipeline.model.conditioner).type == "cuda"
-
-
-@pytest.mark.inference
-def test_model_swapping():
-    pipeline = SamplingPipeline(model_id=ModelArchitecture.SD_2_1, device=helpers.CudaModelManager(device="cuda", swap_device="cpu"))
-    assert get_torch_device(pipeline.model.model).type == "cpu"
-    assert get_torch_device(pipeline.model.conditioner).type == "cpu"
-    with pipeline.device_manager.use(pipeline.model.model):
-        assert get_torch_device(pipeline.model.model).type == "cuda"
-    assert get_torch_device(pipeline.model.model).type == "cpu"
-    with pipeline.device_manager.use(pipeline.model.conditioner):
-        assert get_torch_device(pipeline.model.conditioner).type == "cuda"
-    assert get_torch_device(pipeline.model.conditioner).type == "cpu"