mirror of
https://github.com/Stability-AI/generative-models.git
synced 2026-02-02 12:24:27 +01:00
Streamlit refactor (#105)
* initial streamlit refactoring pass * cleanup and fixes * fix refiner strength * Modify params correctly * fix exception
This commit is contained in:
@@ -1,8 +1,13 @@
|
||||
from dataclasses import asdict
|
||||
from pytorch_lightning import seed_everything
|
||||
|
||||
from sgm.inference.api import (
|
||||
SamplingParams,
|
||||
ModelArchitecture,
|
||||
SamplingPipeline,
|
||||
model_specs,
|
||||
)
|
||||
from sgm.inference.helpers import (
|
||||
do_img2img,
|
||||
do_sample,
|
||||
get_unique_embedder_keys_from_conditioner,
|
||||
perform_save_locally,
|
||||
)
|
||||
@@ -39,63 +44,6 @@ SD_XL_BASE_RATIOS = {
|
||||
"3.0": (1728, 576),
|
||||
}
|
||||
|
||||
VERSION2SPECS = {
|
||||
"SDXL-base-1.0": {
|
||||
"H": 1024,
|
||||
"W": 1024,
|
||||
"C": 4,
|
||||
"f": 8,
|
||||
"is_legacy": False,
|
||||
"config": "configs/inference/sd_xl_base.yaml",
|
||||
"ckpt": "checkpoints/sd_xl_base_1.0.safetensors",
|
||||
},
|
||||
"SDXL-base-0.9": {
|
||||
"H": 1024,
|
||||
"W": 1024,
|
||||
"C": 4,
|
||||
"f": 8,
|
||||
"is_legacy": False,
|
||||
"config": "configs/inference/sd_xl_base.yaml",
|
||||
"ckpt": "checkpoints/sd_xl_base_0.9.safetensors",
|
||||
},
|
||||
"SD-2.1": {
|
||||
"H": 512,
|
||||
"W": 512,
|
||||
"C": 4,
|
||||
"f": 8,
|
||||
"is_legacy": True,
|
||||
"config": "configs/inference/sd_2_1.yaml",
|
||||
"ckpt": "checkpoints/v2-1_512-ema-pruned.safetensors",
|
||||
},
|
||||
"SD-2.1-768": {
|
||||
"H": 768,
|
||||
"W": 768,
|
||||
"C": 4,
|
||||
"f": 8,
|
||||
"is_legacy": True,
|
||||
"config": "configs/inference/sd_2_1_768.yaml",
|
||||
"ckpt": "checkpoints/v2-1_768-ema-pruned.safetensors",
|
||||
},
|
||||
"SDXL-refiner-0.9": {
|
||||
"H": 1024,
|
||||
"W": 1024,
|
||||
"C": 4,
|
||||
"f": 8,
|
||||
"is_legacy": True,
|
||||
"config": "configs/inference/sd_xl_refiner.yaml",
|
||||
"ckpt": "checkpoints/sd_xl_refiner_0.9.safetensors",
|
||||
},
|
||||
"SDXL-refiner-1.0": {
|
||||
"H": 1024,
|
||||
"W": 1024,
|
||||
"C": 4,
|
||||
"f": 8,
|
||||
"is_legacy": True,
|
||||
"config": "configs/inference/sd_xl_refiner.yaml",
|
||||
"ckpt": "checkpoints/sd_xl_refiner_1.0.safetensors",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def load_img(display=True, key=None, device="cuda"):
|
||||
image = get_interactive_image(key=key)
|
||||
@@ -117,52 +65,48 @@ def load_img(display=True, key=None, device="cuda"):
|
||||
|
||||
def run_txt2img(
|
||||
state,
|
||||
version,
|
||||
version_dict,
|
||||
is_legacy=False,
|
||||
version: str,
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
return_latents=False,
|
||||
filter=None,
|
||||
stage2strength=None,
|
||||
):
|
||||
if version.startswith("SDXL-base"):
|
||||
W, H = st.selectbox("Resolution:", list(SD_XL_BASE_RATIOS.values()), 10)
|
||||
spec: SamplingSpec = state.get("spec")
|
||||
model: SamplingPipeline = state.get("model")
|
||||
params: SamplingParams = state.get("params")
|
||||
if version.startswith("stable-diffusion-xl") and version.endswith("-base"):
|
||||
params.width, params.height = st.selectbox(
|
||||
"Resolution:", list(SD_XL_BASE_RATIOS.values()), 10
|
||||
)
|
||||
else:
|
||||
H = st.number_input("H", value=version_dict["H"], min_value=64, max_value=2048)
|
||||
W = st.number_input("W", value=version_dict["W"], min_value=64, max_value=2048)
|
||||
C = version_dict["C"]
|
||||
F = version_dict["f"]
|
||||
params.height = int(
|
||||
st.number_input("H", value=spec.height, min_value=64, max_value=2048)
|
||||
)
|
||||
params.width = int(
|
||||
st.number_input("W", value=spec.width, min_value=64, max_value=2048)
|
||||
)
|
||||
|
||||
init_dict = {
|
||||
"orig_width": W,
|
||||
"orig_height": H,
|
||||
"target_width": W,
|
||||
"target_height": H,
|
||||
}
|
||||
value_dict = init_embedder_options(
|
||||
get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
|
||||
init_dict,
|
||||
params = init_embedder_options(
|
||||
get_unique_embedder_keys_from_conditioner(model.model.conditioner),
|
||||
params=params,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
)
|
||||
sampler, num_rows, num_cols = init_sampling(stage2strength=stage2strength)
|
||||
params, num_rows, num_cols = init_sampling(params=params)
|
||||
num_samples = num_rows * num_cols
|
||||
|
||||
if st.button("Sample"):
|
||||
st.write(f"**Model I:** {version}")
|
||||
outputs = st.empty()
|
||||
st.text("Sampling")
|
||||
out = do_sample(
|
||||
state["model"],
|
||||
sampler,
|
||||
value_dict,
|
||||
num_samples,
|
||||
H,
|
||||
W,
|
||||
C,
|
||||
F,
|
||||
force_uc_zero_embeddings=["txt"] if not is_legacy else [],
|
||||
out = model.text_to_image(
|
||||
params=params,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
samples=int(num_samples),
|
||||
return_latents=return_latents,
|
||||
filter=filter,
|
||||
noise_strength=stage2strength,
|
||||
filter=state.get("filter"),
|
||||
)
|
||||
|
||||
show_samples(out, outputs)
|
||||
@@ -172,51 +116,45 @@ def run_txt2img(
|
||||
|
||||
def run_img2img(
|
||||
state,
|
||||
version_dict,
|
||||
is_legacy=False,
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
return_latents=False,
|
||||
filter=None,
|
||||
stage2strength=None,
|
||||
):
|
||||
model: SamplingPipeline = state.get("model")
|
||||
params: SamplingParams = state.get("params")
|
||||
|
||||
img = load_img()
|
||||
if img is None:
|
||||
return None
|
||||
H, W = img.shape[2], img.shape[3]
|
||||
params.height, params.width = img.shape[2], img.shape[3]
|
||||
|
||||
init_dict = {
|
||||
"orig_width": W,
|
||||
"orig_height": H,
|
||||
"target_width": W,
|
||||
"target_height": H,
|
||||
}
|
||||
value_dict = init_embedder_options(
|
||||
get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
|
||||
init_dict,
|
||||
params = init_embedder_options(
|
||||
get_unique_embedder_keys_from_conditioner(model.model.conditioner),
|
||||
params=params,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
)
|
||||
strength = st.number_input(
|
||||
params.img2img_strength = st.number_input(
|
||||
"**Img2Img Strength**", value=0.75, min_value=0.0, max_value=1.0
|
||||
)
|
||||
sampler, num_rows, num_cols = init_sampling(
|
||||
img2img_strength=strength,
|
||||
stage2strength=stage2strength,
|
||||
)
|
||||
params, num_rows, num_cols = init_sampling(params=params)
|
||||
num_samples = num_rows * num_cols
|
||||
|
||||
if st.button("Sample"):
|
||||
outputs = st.empty()
|
||||
st.text("Sampling")
|
||||
out = do_img2img(
|
||||
repeat(img, "1 ... -> n ...", n=num_samples),
|
||||
state["model"],
|
||||
sampler,
|
||||
value_dict,
|
||||
num_samples,
|
||||
force_uc_zero_embeddings=["txt"] if not is_legacy else [],
|
||||
out = model.image_to_image(
|
||||
image=repeat(img, "1 ... -> n ...", n=num_samples),
|
||||
params=params,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
samples=int(num_samples),
|
||||
return_latents=return_latents,
|
||||
filter=filter,
|
||||
noise_strength=stage2strength,
|
||||
filter=state.get("filter"),
|
||||
)
|
||||
|
||||
show_samples(out, outputs)
|
||||
return out
|
||||
|
||||
@@ -224,39 +162,29 @@ def run_img2img(
|
||||
def apply_refiner(
|
||||
input,
|
||||
state,
|
||||
sampler,
|
||||
num_samples,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
filter=None,
|
||||
num_samples: int,
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
finish_denoising=False,
|
||||
):
|
||||
init_dict = {
|
||||
"orig_width": input.shape[3] * 8,
|
||||
"orig_height": input.shape[2] * 8,
|
||||
"target_width": input.shape[3] * 8,
|
||||
"target_height": input.shape[2] * 8,
|
||||
}
|
||||
model: SamplingPipeline = state.get("model")
|
||||
params: SamplingParams = state.get("params")
|
||||
|
||||
value_dict = init_dict
|
||||
value_dict["prompt"] = prompt
|
||||
value_dict["negative_prompt"] = negative_prompt
|
||||
|
||||
value_dict["crop_coords_top"] = 0
|
||||
value_dict["crop_coords_left"] = 0
|
||||
|
||||
value_dict["aesthetic_score"] = 6.0
|
||||
value_dict["negative_aesthetic_score"] = 2.5
|
||||
params.orig_width = input.shape[3] * 8
|
||||
params.orig_height = input.shape[2] * 8
|
||||
params.width = input.shape[3] * 8
|
||||
params.height = input.shape[2] * 8
|
||||
|
||||
st.warning(f"refiner input shape: {input.shape}")
|
||||
samples = do_img2img(
|
||||
input,
|
||||
state["model"],
|
||||
sampler,
|
||||
value_dict,
|
||||
num_samples,
|
||||
skip_encode=True,
|
||||
filter=filter,
|
||||
|
||||
samples = model.refiner(
|
||||
image=input,
|
||||
params=params,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
samples=num_samples,
|
||||
return_latents=False,
|
||||
filter=state.get("filter"),
|
||||
add_noise=not finish_denoising,
|
||||
)
|
||||
|
||||
@@ -265,28 +193,34 @@ def apply_refiner(
|
||||
|
||||
if __name__ == "__main__":
|
||||
st.title("Stable Diffusion")
|
||||
version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
|
||||
version_dict = VERSION2SPECS[version]
|
||||
version = st.selectbox(
|
||||
"Model Version",
|
||||
[member.value for member in ModelArchitecture],
|
||||
0,
|
||||
)
|
||||
version_enum = ModelArchitecture(version)
|
||||
specs = model_specs[version_enum]
|
||||
mode = st.radio("Mode", ("txt2img", "img2img"), 0)
|
||||
st.write("__________________________")
|
||||
|
||||
set_lowvram_mode(st.checkbox("Low vram mode", True))
|
||||
|
||||
if version.startswith("SDXL-base"):
|
||||
if str(version).startswith("stable-diffusion-xl"):
|
||||
add_pipeline = st.checkbox("Load SDXL-refiner?", False)
|
||||
st.write("__________________________")
|
||||
else:
|
||||
add_pipeline = False
|
||||
|
||||
seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
|
||||
seed = int(
|
||||
st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
|
||||
)
|
||||
seed_everything(seed)
|
||||
|
||||
save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))
|
||||
|
||||
state = init_st(version_dict, load_filter=True)
|
||||
save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, str(version)))
|
||||
state = init_st(model_specs[version_enum], load_filter=True)
|
||||
model = state["model"]
|
||||
|
||||
is_legacy = version_dict["is_legacy"]
|
||||
is_legacy = specs.is_legacy
|
||||
|
||||
prompt = st.text_input(
|
||||
"prompt",
|
||||
@@ -302,46 +236,59 @@ if __name__ == "__main__":
|
||||
|
||||
if add_pipeline:
|
||||
st.write("__________________________")
|
||||
version2 = st.selectbox("Refiner:", ["SDXL-refiner-1.0", "SDXL-refiner-0.9"])
|
||||
version2 = ModelArchitecture(
|
||||
st.selectbox(
|
||||
"Refiner:",
|
||||
[
|
||||
ModelArchitecture.SDXL_V1_REFINER.value,
|
||||
ModelArchitecture.SDXL_V0_9_REFINER.value,
|
||||
],
|
||||
)
|
||||
)
|
||||
st.warning(
|
||||
f"Running with {version2} as the second stage model. Make sure to provide (V)RAM :) "
|
||||
)
|
||||
st.write("**Refiner Options:**")
|
||||
|
||||
version_dict2 = VERSION2SPECS[version2]
|
||||
state2 = init_st(version_dict2, load_filter=False)
|
||||
specs2 = model_specs[version2]
|
||||
state2 = init_st(specs2, load_filter=False)
|
||||
params2 = state2["params"]
|
||||
|
||||
stage2strength = st.number_input(
|
||||
params2.img2img_strength = st.number_input(
|
||||
"**Refinement strength**", value=0.15, min_value=0.0, max_value=1.0
|
||||
)
|
||||
|
||||
sampler2, *_ = init_sampling(
|
||||
params2, *_ = init_sampling(
|
||||
key=2,
|
||||
img2img_strength=stage2strength,
|
||||
params=state2["params"],
|
||||
specify_num_samples=False,
|
||||
)
|
||||
st.write("__________________________")
|
||||
finish_denoising = st.checkbox("Finish denoising with refiner.", True)
|
||||
if not finish_denoising:
|
||||
if finish_denoising:
|
||||
stage2strength = params2.img2img_strength
|
||||
else:
|
||||
stage2strength = None
|
||||
else:
|
||||
state2 = None
|
||||
params2 = None
|
||||
stage2strength = None
|
||||
|
||||
if mode == "txt2img":
|
||||
out = run_txt2img(
|
||||
state,
|
||||
version,
|
||||
version_dict,
|
||||
is_legacy=is_legacy,
|
||||
state=state,
|
||||
version=str(version),
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
return_latents=add_pipeline,
|
||||
filter=state.get("filter"),
|
||||
stage2strength=stage2strength,
|
||||
)
|
||||
elif mode == "img2img":
|
||||
out = run_img2img(
|
||||
state,
|
||||
version_dict,
|
||||
is_legacy=is_legacy,
|
||||
state=state,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
return_latents=add_pipeline,
|
||||
filter=state.get("filter"),
|
||||
stage2strength=stage2strength,
|
||||
)
|
||||
else:
|
||||
@@ -356,13 +303,11 @@ if __name__ == "__main__":
|
||||
outputs = st.empty()
|
||||
st.write("**Running Refinement Stage**")
|
||||
samples = apply_refiner(
|
||||
samples_z,
|
||||
state2,
|
||||
sampler2,
|
||||
samples_z.shape[0],
|
||||
input=samples_z,
|
||||
state=state2,
|
||||
num_samples=samples_z.shape[0],
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt if is_legacy else "",
|
||||
filter=state.get("filter"),
|
||||
finish_denoising=finish_denoising,
|
||||
)
|
||||
show_samples(samples, outputs)
|
||||
|
||||
@@ -4,43 +4,46 @@ import numpy as np
|
||||
import streamlit as st
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
from omegaconf import OmegaConf
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
|
||||
from sgm.modules.diffusionmodules.sampling import (
|
||||
DPMPP2MSampler,
|
||||
DPMPP2SAncestralSampler,
|
||||
EulerAncestralSampler,
|
||||
EulerEDMSampler,
|
||||
HeunEDMSampler,
|
||||
LinearMultistepSampler,
|
||||
|
||||
from sgm.inference.api import (
|
||||
Discretization,
|
||||
Guider,
|
||||
Sampler,
|
||||
SamplingParams,
|
||||
SamplingSpec,
|
||||
SamplingPipeline,
|
||||
Thresholder,
|
||||
)
|
||||
from sgm.inference.helpers import (
|
||||
Img2ImgDiscretizationWrapper,
|
||||
Txt2NoisyDiscretizationWrapper,
|
||||
embed_watermark,
|
||||
)
|
||||
from sgm.util import load_model_from_config
|
||||
|
||||
|
||||
@st.cache_resource()
|
||||
def init_st(version_dict, load_ckpt=True, load_filter=True):
|
||||
def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True):
|
||||
global lowvram_mode
|
||||
state = dict()
|
||||
if not "model" in state:
|
||||
config = version_dict["config"]
|
||||
ckpt = version_dict["ckpt"]
|
||||
config = spec.config
|
||||
ckpt = spec.ckpt
|
||||
|
||||
config = OmegaConf.load(config)
|
||||
model = load_model_from_config(
|
||||
config, ckpt if load_ckpt else None, freeze=False
|
||||
pipeline = SamplingPipeline(
|
||||
model_spec=spec,
|
||||
use_fp16=lowvram_mode,
|
||||
device="cpu" if lowvram_mode else "cuda",
|
||||
)
|
||||
|
||||
state["model"] = model
|
||||
state["spec"] = spec
|
||||
state["model"] = pipeline
|
||||
state["ckpt"] = ckpt if load_ckpt else None
|
||||
state["config"] = config
|
||||
state["params"] = SamplingParams()
|
||||
if load_filter:
|
||||
state["filter"] = DeepFloydDataFiltering(verbose=False)
|
||||
return state
|
||||
@@ -54,23 +57,13 @@ def set_lowvram_mode(mode):
|
||||
lowvram_mode = mode
|
||||
|
||||
|
||||
def initial_model_load(model):
|
||||
global lowvram_mode
|
||||
if lowvram_mode:
|
||||
model.model.half()
|
||||
else:
|
||||
model.cuda()
|
||||
return model
|
||||
|
||||
|
||||
def get_unique_embedder_keys_from_conditioner(conditioner):
|
||||
return list(set([x.input_key for x in conditioner.embedders]))
|
||||
|
||||
|
||||
def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
|
||||
# Hardcoded demo settings; might undergo some changes in the future
|
||||
|
||||
value_dict = {}
|
||||
def init_embedder_options(
|
||||
keys, params: SamplingParams, prompt=None, negative_prompt=None
|
||||
) -> SamplingParams:
|
||||
for key in keys:
|
||||
if key == "txt":
|
||||
if prompt is None:
|
||||
@@ -80,40 +73,32 @@ def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
|
||||
if negative_prompt is None:
|
||||
negative_prompt = st.text_input("Negative prompt", "")
|
||||
|
||||
value_dict["prompt"] = prompt
|
||||
value_dict["negative_prompt"] = negative_prompt
|
||||
|
||||
if key == "original_size_as_tuple":
|
||||
orig_width = st.number_input(
|
||||
"orig_width",
|
||||
value=init_dict["orig_width"],
|
||||
value=params.orig_width,
|
||||
min_value=16,
|
||||
)
|
||||
orig_height = st.number_input(
|
||||
"orig_height",
|
||||
value=init_dict["orig_height"],
|
||||
value=params.orig_height,
|
||||
min_value=16,
|
||||
)
|
||||
|
||||
value_dict["orig_width"] = orig_width
|
||||
value_dict["orig_height"] = orig_height
|
||||
params.orig_width = int(orig_width)
|
||||
params.orig_height = int(orig_height)
|
||||
|
||||
if key == "crop_coords_top_left":
|
||||
crop_coord_top = st.number_input("crop_coords_top", value=0, min_value=0)
|
||||
crop_coord_left = st.number_input("crop_coords_left", value=0, min_value=0)
|
||||
crop_coord_top = st.number_input(
|
||||
"crop_coords_top", value=params.crop_coords_top, min_value=0
|
||||
)
|
||||
crop_coord_left = st.number_input(
|
||||
"crop_coords_left", value=params.crop_coords_left, min_value=0
|
||||
)
|
||||
|
||||
value_dict["crop_coords_top"] = crop_coord_top
|
||||
value_dict["crop_coords_left"] = crop_coord_left
|
||||
|
||||
if key == "aesthetic_score":
|
||||
value_dict["aesthetic_score"] = 6.0
|
||||
value_dict["negative_aesthetic_score"] = 2.5
|
||||
|
||||
if key == "target_size_as_tuple":
|
||||
value_dict["target_width"] = init_dict["target_width"]
|
||||
value_dict["target_height"] = init_dict["target_height"]
|
||||
|
||||
return value_dict
|
||||
params.crop_coords_top = int(crop_coord_top)
|
||||
params.crop_coords_left = int(crop_coord_left)
|
||||
return params
|
||||
|
||||
|
||||
def perform_save_locally(save_path, samples):
|
||||
@@ -146,24 +131,18 @@ def show_samples(samples, outputs):
|
||||
outputs.image(grid.cpu().numpy())
|
||||
|
||||
|
||||
def get_guider(key):
|
||||
guider = st.sidebar.selectbox(
|
||||
f"Discretization #{key}",
|
||||
[
|
||||
"VanillaCFG",
|
||||
"IdentityGuider",
|
||||
],
|
||||
def get_guider(key, params: SamplingParams) -> SamplingParams:
|
||||
params.guider = Guider(
|
||||
st.sidebar.selectbox(
|
||||
f"Discretization #{key}", [member.value for member in Guider]
|
||||
)
|
||||
)
|
||||
|
||||
if guider == "IdentityGuider":
|
||||
guider_config = {
|
||||
"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
|
||||
}
|
||||
elif guider == "VanillaCFG":
|
||||
if params.guider == Guider.VANILLA:
|
||||
scale = st.number_input(
|
||||
f"cfg-scale #{key}", value=5.0, min_value=0.0, max_value=100.0
|
||||
f"cfg-scale #{key}", value=params.scale, min_value=0.0, max_value=100.0
|
||||
)
|
||||
|
||||
params.scale = scale
|
||||
thresholder = st.sidebar.selectbox(
|
||||
f"Thresholder #{key}",
|
||||
[
|
||||
@@ -172,173 +151,97 @@ def get_guider(key):
|
||||
)
|
||||
|
||||
if thresholder == "None":
|
||||
dyn_thresh_config = {
|
||||
"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
|
||||
}
|
||||
params.thresholder = Thresholder.NONE
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
guider_config = {
|
||||
"target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
|
||||
"params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
|
||||
}
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return guider_config
|
||||
return params
|
||||
|
||||
|
||||
def init_sampling(
|
||||
key=1,
|
||||
img2img_strength=1.0,
|
||||
params: SamplingParams = SamplingParams(),
|
||||
specify_num_samples=True,
|
||||
stage2strength=None,
|
||||
):
|
||||
) -> Tuple[SamplingParams, int, int]:
|
||||
params = SamplingParams(img2img_strength=params.img2img_strength)
|
||||
|
||||
num_rows, num_cols = 1, 1
|
||||
if specify_num_samples:
|
||||
num_cols = st.number_input(
|
||||
f"num cols #{key}", value=2, min_value=1, max_value=10
|
||||
)
|
||||
|
||||
steps = st.sidebar.number_input(
|
||||
f"steps #{key}", value=40, min_value=1, max_value=1000
|
||||
)
|
||||
sampler = st.sidebar.selectbox(
|
||||
f"Sampler #{key}",
|
||||
[
|
||||
"EulerEDMSampler",
|
||||
"HeunEDMSampler",
|
||||
"EulerAncestralSampler",
|
||||
"DPMPP2SAncestralSampler",
|
||||
"DPMPP2MSampler",
|
||||
"LinearMultistepSampler",
|
||||
],
|
||||
0,
|
||||
)
|
||||
discretization = st.sidebar.selectbox(
|
||||
f"Discretization #{key}",
|
||||
[
|
||||
"LegacyDDPMDiscretization",
|
||||
"EDMDiscretization",
|
||||
],
|
||||
params.steps = int(
|
||||
st.sidebar.number_input(
|
||||
f"steps #{key}", value=params.steps, min_value=1, max_value=1000
|
||||
)
|
||||
)
|
||||
|
||||
discretization_config = get_discretization(discretization, key=key)
|
||||
|
||||
guider_config = get_guider(key=key)
|
||||
|
||||
sampler = get_sampler(sampler, steps, discretization_config, guider_config, key=key)
|
||||
if img2img_strength < 1.0:
|
||||
st.warning(
|
||||
f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper"
|
||||
params.sampler = Sampler(
|
||||
st.sidebar.selectbox(
|
||||
f"Sampler #{key}",
|
||||
[member.value for member in Sampler],
|
||||
0,
|
||||
)
|
||||
sampler.discretization = Img2ImgDiscretizationWrapper(
|
||||
sampler.discretization, strength=img2img_strength
|
||||
)
|
||||
params.discretization = Discretization(
|
||||
st.sidebar.selectbox(
|
||||
f"Discretization #{key}",
|
||||
[member.value for member in Discretization],
|
||||
)
|
||||
if stage2strength is not None:
|
||||
sampler.discretization = Txt2NoisyDiscretizationWrapper(
|
||||
sampler.discretization, strength=stage2strength, original_steps=steps
|
||||
)
|
||||
|
||||
params = get_discretization(params, key=key)
|
||||
|
||||
params = get_guider(key=key, params=params)
|
||||
|
||||
params = get_sampler(params, key=key)
|
||||
return params, num_rows, num_cols
|
||||
|
||||
|
||||
def get_discretization(params: SamplingParams, key=1) -> SamplingParams:
|
||||
if params.discretization == Discretization.EDM:
|
||||
params.sigma_min = st.number_input(
|
||||
f"sigma_min #{key}", value=params.sigma_min
|
||||
) # 0.0292
|
||||
params.sigma_max = st.number_input(
|
||||
f"sigma_max #{key}", value=params.sigma_max
|
||||
) # 14.6146
|
||||
params.rho = st.number_input(f"rho #{key}", value=params.rho)
|
||||
return params
|
||||
|
||||
|
||||
def get_sampler(params: SamplingParams, key=1) -> SamplingParams:
|
||||
if params.sampler == Sampler.EULER_EDM or params.sampler == Sampler.HEUN_EDM:
|
||||
params.s_churn = st.sidebar.number_input(
|
||||
f"s_churn #{key}", value=params.s_churn, min_value=0.0
|
||||
)
|
||||
params.s_tmin = st.sidebar.number_input(
|
||||
f"s_tmin #{key}", value=params.s_tmin, min_value=0.0
|
||||
)
|
||||
params.s_tmax = st.sidebar.number_input(
|
||||
f"s_tmax #{key}", value=params.s_tmax, min_value=0.0
|
||||
)
|
||||
params.s_noise = st.sidebar.number_input(
|
||||
f"s_noise #{key}", value=params.s_noise, min_value=0.0
|
||||
)
|
||||
return sampler, num_rows, num_cols
|
||||
|
||||
|
||||
def get_discretization(discretization, key=1):
|
||||
if discretization == "LegacyDDPMDiscretization":
|
||||
discretization_config = {
|
||||
"target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
|
||||
}
|
||||
elif discretization == "EDMDiscretization":
|
||||
sigma_min = st.number_input(f"sigma_min #{key}", value=0.03) # 0.0292
|
||||
sigma_max = st.number_input(f"sigma_max #{key}", value=14.61) # 14.6146
|
||||
rho = st.number_input(f"rho #{key}", value=3.0)
|
||||
discretization_config = {
|
||||
"target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
|
||||
"params": {
|
||||
"sigma_min": sigma_min,
|
||||
"sigma_max": sigma_max,
|
||||
"rho": rho,
|
||||
},
|
||||
}
|
||||
|
||||
return discretization_config
|
||||
|
||||
|
||||
def get_sampler(sampler_name, steps, discretization_config, guider_config, key=1):
|
||||
if sampler_name == "EulerEDMSampler" or sampler_name == "HeunEDMSampler":
|
||||
s_churn = st.sidebar.number_input(f"s_churn #{key}", value=0.0, min_value=0.0)
|
||||
s_tmin = st.sidebar.number_input(f"s_tmin #{key}", value=0.0, min_value=0.0)
|
||||
s_tmax = st.sidebar.number_input(f"s_tmax #{key}", value=999.0, min_value=0.0)
|
||||
s_noise = st.sidebar.number_input(f"s_noise #{key}", value=1.0, min_value=0.0)
|
||||
|
||||
if sampler_name == "EulerEDMSampler":
|
||||
sampler = EulerEDMSampler(
|
||||
num_steps=steps,
|
||||
discretization_config=discretization_config,
|
||||
guider_config=guider_config,
|
||||
s_churn=s_churn,
|
||||
s_tmin=s_tmin,
|
||||
s_tmax=s_tmax,
|
||||
s_noise=s_noise,
|
||||
verbose=True,
|
||||
)
|
||||
elif sampler_name == "HeunEDMSampler":
|
||||
sampler = HeunEDMSampler(
|
||||
num_steps=steps,
|
||||
discretization_config=discretization_config,
|
||||
guider_config=guider_config,
|
||||
s_churn=s_churn,
|
||||
s_tmin=s_tmin,
|
||||
s_tmax=s_tmax,
|
||||
s_noise=s_noise,
|
||||
verbose=True,
|
||||
)
|
||||
elif (
|
||||
sampler_name == "EulerAncestralSampler"
|
||||
or sampler_name == "DPMPP2SAncestralSampler"
|
||||
params.sampler == Sampler.EULER_ANCESTRAL
|
||||
or params.sampler == Sampler.DPMPP2S_ANCESTRAL
|
||||
):
|
||||
s_noise = st.sidebar.number_input("s_noise", value=1.0, min_value=0.0)
|
||||
eta = st.sidebar.number_input("eta", value=1.0, min_value=0.0)
|
||||
|
||||
if sampler_name == "EulerAncestralSampler":
|
||||
sampler = EulerAncestralSampler(
|
||||
num_steps=steps,
|
||||
discretization_config=discretization_config,
|
||||
guider_config=guider_config,
|
||||
eta=eta,
|
||||
s_noise=s_noise,
|
||||
verbose=True,
|
||||
)
|
||||
elif sampler_name == "DPMPP2SAncestralSampler":
|
||||
sampler = DPMPP2SAncestralSampler(
|
||||
num_steps=steps,
|
||||
discretization_config=discretization_config,
|
||||
guider_config=guider_config,
|
||||
eta=eta,
|
||||
s_noise=s_noise,
|
||||
verbose=True,
|
||||
)
|
||||
elif sampler_name == "DPMPP2MSampler":
|
||||
sampler = DPMPP2MSampler(
|
||||
num_steps=steps,
|
||||
discretization_config=discretization_config,
|
||||
guider_config=guider_config,
|
||||
verbose=True,
|
||||
params.s_noise = st.sidebar.number_input(
|
||||
"s_noise", value=params.s_noise, min_value=0.0
|
||||
)
|
||||
elif sampler_name == "LinearMultistepSampler":
|
||||
order = st.sidebar.number_input("order", value=4, min_value=1)
|
||||
sampler = LinearMultistepSampler(
|
||||
num_steps=steps,
|
||||
discretization_config=discretization_config,
|
||||
guider_config=guider_config,
|
||||
order=order,
|
||||
verbose=True,
|
||||
params.eta = st.sidebar.number_input("eta", value=params.eta, min_value=0.0)
|
||||
|
||||
elif params.sampler == Sampler.LINEAR_MULTISTEP:
|
||||
params.order = int(
|
||||
st.sidebar.number_input("order", value=params.order, min_value=1)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unknown sampler {sampler_name}!")
|
||||
|
||||
return sampler
|
||||
return params
|
||||
|
||||
|
||||
def get_interactive_image(key=None) -> Image.Image:
|
||||
def get_interactive_image(key=None) -> Optional[Image.Image]:
|
||||
image = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
|
||||
if image is not None:
|
||||
image = Image.open(image)
|
||||
@@ -347,7 +250,7 @@ def get_interactive_image(key=None) -> Image.Image:
|
||||
return image
|
||||
|
||||
|
||||
def load_img(display=True, key=None):
|
||||
def load_img(display=True, key=None) -> torch.Tensor:
|
||||
image = get_interactive_image(key=key)
|
||||
if image is None:
|
||||
return None
|
||||
|
||||
@@ -22,12 +22,12 @@ from typing import Optional
|
||||
|
||||
|
||||
class ModelArchitecture(str, Enum):
|
||||
SD_2_1 = "stable-diffusion-v2-1"
|
||||
SD_2_1_768 = "stable-diffusion-v2-1-768"
|
||||
SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base"
|
||||
SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner"
|
||||
SDXL_V1_BASE = "stable-diffusion-xl-v1-base"
|
||||
SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner"
|
||||
SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base"
|
||||
SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner"
|
||||
SD_2_1 = "stable-diffusion-v2-1"
|
||||
SD_2_1_768 = "stable-diffusion-v2-1-768"
|
||||
|
||||
|
||||
class Sampler(str, Enum):
|
||||
@@ -58,7 +58,7 @@ class SamplingParams:
|
||||
width: int = 1024
|
||||
height: int = 1024
|
||||
steps: int = 40
|
||||
sampler: Sampler = Sampler.DPMPP2M
|
||||
sampler: Sampler = Sampler.EULER_EDM
|
||||
discretization: Discretization = Discretization.LEGACY_DDPM
|
||||
guider: Guider = Guider.VANILLA
|
||||
thresholder: Thresholder = Thresholder.NONE
|
||||
@@ -227,6 +227,7 @@ class SamplingPipeline:
|
||||
samples: int = 1,
|
||||
return_latents: bool = False,
|
||||
noise_strength=None,
|
||||
filter=None,
|
||||
):
|
||||
sampler = get_sampler_config(params)
|
||||
|
||||
@@ -253,7 +254,7 @@ class SamplingPipeline:
|
||||
self.specs.factor,
|
||||
force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
|
||||
return_latents=return_latents,
|
||||
filter=None,
|
||||
filter=filter,
|
||||
)
|
||||
|
||||
def image_to_image(
|
||||
@@ -265,6 +266,7 @@ class SamplingPipeline:
|
||||
samples: int = 1,
|
||||
return_latents: bool = False,
|
||||
noise_strength=None,
|
||||
filter=None,
|
||||
):
|
||||
sampler = get_sampler_config(params)
|
||||
|
||||
@@ -289,7 +291,7 @@ class SamplingPipeline:
|
||||
samples,
|
||||
force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
|
||||
return_latents=return_latents,
|
||||
filter=None,
|
||||
filter=filter,
|
||||
)
|
||||
|
||||
def wrap_discretization(
|
||||
@@ -327,6 +329,8 @@ class SamplingPipeline:
|
||||
),
|
||||
samples: int = 1,
|
||||
return_latents: bool = False,
|
||||
filter=None,
|
||||
add_noise=False,
|
||||
):
|
||||
sampler = get_sampler_config(params)
|
||||
value_dict = {
|
||||
@@ -354,8 +358,8 @@ class SamplingPipeline:
|
||||
samples,
|
||||
skip_encode=True,
|
||||
return_latents=return_latents,
|
||||
filter=None,
|
||||
add_noise=False,
|
||||
filter=filter,
|
||||
add_noise=add_noise,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user