Streamlit refactor (#105)

* initial streamlit refactoring pass

* cleanup and fixes

* fix refiner strength

* Modify params correctly

* fix exception
This commit is contained in:
Stephan Auerhahn
2023-08-06 19:58:52 -07:00
committed by GitHub
parent 7e7fee3f0f
commit c4b7baf896
3 changed files with 243 additions and 391 deletions

View File

@@ -1,8 +1,13 @@
from dataclasses import asdict
from pytorch_lightning import seed_everything
from sgm.inference.api import (
SamplingParams,
ModelArchitecture,
SamplingPipeline,
model_specs,
)
from sgm.inference.helpers import (
do_img2img,
do_sample,
get_unique_embedder_keys_from_conditioner,
perform_save_locally,
)
@@ -39,63 +44,6 @@ SD_XL_BASE_RATIOS = {
"3.0": (1728, 576),
}
VERSION2SPECS = {
"SDXL-base-1.0": {
"H": 1024,
"W": 1024,
"C": 4,
"f": 8,
"is_legacy": False,
"config": "configs/inference/sd_xl_base.yaml",
"ckpt": "checkpoints/sd_xl_base_1.0.safetensors",
},
"SDXL-base-0.9": {
"H": 1024,
"W": 1024,
"C": 4,
"f": 8,
"is_legacy": False,
"config": "configs/inference/sd_xl_base.yaml",
"ckpt": "checkpoints/sd_xl_base_0.9.safetensors",
},
"SD-2.1": {
"H": 512,
"W": 512,
"C": 4,
"f": 8,
"is_legacy": True,
"config": "configs/inference/sd_2_1.yaml",
"ckpt": "checkpoints/v2-1_512-ema-pruned.safetensors",
},
"SD-2.1-768": {
"H": 768,
"W": 768,
"C": 4,
"f": 8,
"is_legacy": True,
"config": "configs/inference/sd_2_1_768.yaml",
"ckpt": "checkpoints/v2-1_768-ema-pruned.safetensors",
},
"SDXL-refiner-0.9": {
"H": 1024,
"W": 1024,
"C": 4,
"f": 8,
"is_legacy": True,
"config": "configs/inference/sd_xl_refiner.yaml",
"ckpt": "checkpoints/sd_xl_refiner_0.9.safetensors",
},
"SDXL-refiner-1.0": {
"H": 1024,
"W": 1024,
"C": 4,
"f": 8,
"is_legacy": True,
"config": "configs/inference/sd_xl_refiner.yaml",
"ckpt": "checkpoints/sd_xl_refiner_1.0.safetensors",
},
}
def load_img(display=True, key=None, device="cuda"):
image = get_interactive_image(key=key)
@@ -117,52 +65,48 @@ def load_img(display=True, key=None, device="cuda"):
def run_txt2img(
state,
version,
version_dict,
is_legacy=False,
version: str,
prompt: str,
negative_prompt: str,
return_latents=False,
filter=None,
stage2strength=None,
):
if version.startswith("SDXL-base"):
W, H = st.selectbox("Resolution:", list(SD_XL_BASE_RATIOS.values()), 10)
spec: SamplingSpec = state.get("spec")
model: SamplingPipeline = state.get("model")
params: SamplingParams = state.get("params")
if version.startswith("stable-diffusion-xl") and version.endswith("-base"):
params.width, params.height = st.selectbox(
"Resolution:", list(SD_XL_BASE_RATIOS.values()), 10
)
else:
H = st.number_input("H", value=version_dict["H"], min_value=64, max_value=2048)
W = st.number_input("W", value=version_dict["W"], min_value=64, max_value=2048)
C = version_dict["C"]
F = version_dict["f"]
params.height = int(
st.number_input("H", value=spec.height, min_value=64, max_value=2048)
)
params.width = int(
st.number_input("W", value=spec.width, min_value=64, max_value=2048)
)
init_dict = {
"orig_width": W,
"orig_height": H,
"target_width": W,
"target_height": H,
}
value_dict = init_embedder_options(
get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
init_dict,
params = init_embedder_options(
get_unique_embedder_keys_from_conditioner(model.model.conditioner),
params=params,
prompt=prompt,
negative_prompt=negative_prompt,
)
sampler, num_rows, num_cols = init_sampling(stage2strength=stage2strength)
params, num_rows, num_cols = init_sampling(params=params)
num_samples = num_rows * num_cols
if st.button("Sample"):
st.write(f"**Model I:** {version}")
outputs = st.empty()
st.text("Sampling")
out = do_sample(
state["model"],
sampler,
value_dict,
num_samples,
H,
W,
C,
F,
force_uc_zero_embeddings=["txt"] if not is_legacy else [],
out = model.text_to_image(
params=params,
prompt=prompt,
negative_prompt=negative_prompt,
samples=int(num_samples),
return_latents=return_latents,
filter=filter,
noise_strength=stage2strength,
filter=state.get("filter"),
)
show_samples(out, outputs)
@@ -172,51 +116,45 @@ def run_txt2img(
def run_img2img(
state,
version_dict,
is_legacy=False,
prompt: str,
negative_prompt: str,
return_latents=False,
filter=None,
stage2strength=None,
):
model: SamplingPipeline = state.get("model")
params: SamplingParams = state.get("params")
img = load_img()
if img is None:
return None
H, W = img.shape[2], img.shape[3]
params.height, params.width = img.shape[2], img.shape[3]
init_dict = {
"orig_width": W,
"orig_height": H,
"target_width": W,
"target_height": H,
}
value_dict = init_embedder_options(
get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
init_dict,
params = init_embedder_options(
get_unique_embedder_keys_from_conditioner(model.model.conditioner),
params=params,
prompt=prompt,
negative_prompt=negative_prompt,
)
strength = st.number_input(
params.img2img_strength = st.number_input(
"**Img2Img Strength**", value=0.75, min_value=0.0, max_value=1.0
)
sampler, num_rows, num_cols = init_sampling(
img2img_strength=strength,
stage2strength=stage2strength,
)
params, num_rows, num_cols = init_sampling(params=params)
num_samples = num_rows * num_cols
if st.button("Sample"):
outputs = st.empty()
st.text("Sampling")
out = do_img2img(
repeat(img, "1 ... -> n ...", n=num_samples),
state["model"],
sampler,
value_dict,
num_samples,
force_uc_zero_embeddings=["txt"] if not is_legacy else [],
out = model.image_to_image(
image=repeat(img, "1 ... -> n ...", n=num_samples),
params=params,
prompt=prompt,
negative_prompt=negative_prompt,
samples=int(num_samples),
return_latents=return_latents,
filter=filter,
noise_strength=stage2strength,
filter=state.get("filter"),
)
show_samples(out, outputs)
return out
@@ -224,39 +162,29 @@ def run_img2img(
def apply_refiner(
input,
state,
sampler,
num_samples,
prompt,
negative_prompt,
filter=None,
num_samples: int,
prompt: str,
negative_prompt: str,
finish_denoising=False,
):
init_dict = {
"orig_width": input.shape[3] * 8,
"orig_height": input.shape[2] * 8,
"target_width": input.shape[3] * 8,
"target_height": input.shape[2] * 8,
}
model: SamplingPipeline = state.get("model")
params: SamplingParams = state.get("params")
value_dict = init_dict
value_dict["prompt"] = prompt
value_dict["negative_prompt"] = negative_prompt
value_dict["crop_coords_top"] = 0
value_dict["crop_coords_left"] = 0
value_dict["aesthetic_score"] = 6.0
value_dict["negative_aesthetic_score"] = 2.5
params.orig_width = input.shape[3] * 8
params.orig_height = input.shape[2] * 8
params.width = input.shape[3] * 8
params.height = input.shape[2] * 8
st.warning(f"refiner input shape: {input.shape}")
samples = do_img2img(
input,
state["model"],
sampler,
value_dict,
num_samples,
skip_encode=True,
filter=filter,
samples = model.refiner(
image=input,
params=params,
prompt=prompt,
negative_prompt=negative_prompt,
samples=num_samples,
return_latents=False,
filter=state.get("filter"),
add_noise=not finish_denoising,
)
@@ -265,28 +193,34 @@ def apply_refiner(
if __name__ == "__main__":
st.title("Stable Diffusion")
version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
version_dict = VERSION2SPECS[version]
version = st.selectbox(
"Model Version",
[member.value for member in ModelArchitecture],
0,
)
version_enum = ModelArchitecture(version)
specs = model_specs[version_enum]
mode = st.radio("Mode", ("txt2img", "img2img"), 0)
st.write("__________________________")
set_lowvram_mode(st.checkbox("Low vram mode", True))
if version.startswith("SDXL-base"):
if str(version).startswith("stable-diffusion-xl"):
add_pipeline = st.checkbox("Load SDXL-refiner?", False)
st.write("__________________________")
else:
add_pipeline = False
seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
seed = int(
st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
)
seed_everything(seed)
save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))
state = init_st(version_dict, load_filter=True)
save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, str(version)))
state = init_st(model_specs[version_enum], load_filter=True)
model = state["model"]
is_legacy = version_dict["is_legacy"]
is_legacy = specs.is_legacy
prompt = st.text_input(
"prompt",
@@ -302,46 +236,59 @@ if __name__ == "__main__":
if add_pipeline:
st.write("__________________________")
version2 = st.selectbox("Refiner:", ["SDXL-refiner-1.0", "SDXL-refiner-0.9"])
version2 = ModelArchitecture(
st.selectbox(
"Refiner:",
[
ModelArchitecture.SDXL_V1_REFINER.value,
ModelArchitecture.SDXL_V0_9_REFINER.value,
],
)
)
st.warning(
f"Running with {version2} as the second stage model. Make sure to provide (V)RAM :) "
)
st.write("**Refiner Options:**")
version_dict2 = VERSION2SPECS[version2]
state2 = init_st(version_dict2, load_filter=False)
specs2 = model_specs[version2]
state2 = init_st(specs2, load_filter=False)
params2 = state2["params"]
stage2strength = st.number_input(
params2.img2img_strength = st.number_input(
"**Refinement strength**", value=0.15, min_value=0.0, max_value=1.0
)
sampler2, *_ = init_sampling(
params2, *_ = init_sampling(
key=2,
img2img_strength=stage2strength,
params=state2["params"],
specify_num_samples=False,
)
st.write("__________________________")
finish_denoising = st.checkbox("Finish denoising with refiner.", True)
if not finish_denoising:
if finish_denoising:
stage2strength = params2.img2img_strength
else:
stage2strength = None
else:
state2 = None
params2 = None
stage2strength = None
if mode == "txt2img":
out = run_txt2img(
state,
version,
version_dict,
is_legacy=is_legacy,
state=state,
version=str(version),
prompt=prompt,
negative_prompt=negative_prompt,
return_latents=add_pipeline,
filter=state.get("filter"),
stage2strength=stage2strength,
)
elif mode == "img2img":
out = run_img2img(
state,
version_dict,
is_legacy=is_legacy,
state=state,
prompt=prompt,
negative_prompt=negative_prompt,
return_latents=add_pipeline,
filter=state.get("filter"),
stage2strength=stage2strength,
)
else:
@@ -356,13 +303,11 @@ if __name__ == "__main__":
outputs = st.empty()
st.write("**Running Refinement Stage**")
samples = apply_refiner(
samples_z,
state2,
sampler2,
samples_z.shape[0],
input=samples_z,
state=state2,
num_samples=samples_z.shape[0],
prompt=prompt,
negative_prompt=negative_prompt if is_legacy else "",
filter=state.get("filter"),
finish_denoising=finish_denoising,
)
show_samples(samples, outputs)

View File

@@ -4,43 +4,46 @@ import numpy as np
import streamlit as st
import torch
from einops import rearrange, repeat
from omegaconf import OmegaConf
from PIL import Image
from torchvision import transforms
from typing import Optional, Tuple
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
from sgm.modules.diffusionmodules.sampling import (
DPMPP2MSampler,
DPMPP2SAncestralSampler,
EulerAncestralSampler,
EulerEDMSampler,
HeunEDMSampler,
LinearMultistepSampler,
from sgm.inference.api import (
Discretization,
Guider,
Sampler,
SamplingParams,
SamplingSpec,
SamplingPipeline,
Thresholder,
)
from sgm.inference.helpers import (
Img2ImgDiscretizationWrapper,
Txt2NoisyDiscretizationWrapper,
embed_watermark,
)
from sgm.util import load_model_from_config
@st.cache_resource()
def init_st(version_dict, load_ckpt=True, load_filter=True):
def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True):
global lowvram_mode
state = dict()
if not "model" in state:
config = version_dict["config"]
ckpt = version_dict["ckpt"]
config = spec.config
ckpt = spec.ckpt
config = OmegaConf.load(config)
model = load_model_from_config(
config, ckpt if load_ckpt else None, freeze=False
pipeline = SamplingPipeline(
model_spec=spec,
use_fp16=lowvram_mode,
device="cpu" if lowvram_mode else "cuda",
)
state["model"] = model
state["spec"] = spec
state["model"] = pipeline
state["ckpt"] = ckpt if load_ckpt else None
state["config"] = config
state["params"] = SamplingParams()
if load_filter:
state["filter"] = DeepFloydDataFiltering(verbose=False)
return state
@@ -54,23 +57,13 @@ def set_lowvram_mode(mode):
lowvram_mode = mode
def initial_model_load(model):
global lowvram_mode
if lowvram_mode:
model.model.half()
else:
model.cuda()
return model
def get_unique_embedder_keys_from_conditioner(conditioner):
    """Return the distinct ``input_key`` values of a conditioner's embedders.

    Keys are returned in first-seen order. De-duplicating through ``set``
    would make the order depend on string hashing, which can differ between
    interpreter launches; callers iterate this list, so a stable order keeps
    their behavior reproducible.

    Args:
        conditioner: Object exposing an iterable ``embedders`` attribute whose
            items each have an ``input_key`` attribute.

    Returns:
        A list of unique input keys, ordered by first occurrence.
    """
    # dict preserves insertion order, so this is an order-stable de-dup.
    return list(dict.fromkeys(x.input_key for x in conditioner.embedders))
def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
# Hardcoded demo settings; might undergo some changes in the future
value_dict = {}
def init_embedder_options(
keys, params: SamplingParams, prompt=None, negative_prompt=None
) -> SamplingParams:
for key in keys:
if key == "txt":
if prompt is None:
@@ -80,40 +73,32 @@ def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
if negative_prompt is None:
negative_prompt = st.text_input("Negative prompt", "")
value_dict["prompt"] = prompt
value_dict["negative_prompt"] = negative_prompt
if key == "original_size_as_tuple":
orig_width = st.number_input(
"orig_width",
value=init_dict["orig_width"],
value=params.orig_width,
min_value=16,
)
orig_height = st.number_input(
"orig_height",
value=init_dict["orig_height"],
value=params.orig_height,
min_value=16,
)
value_dict["orig_width"] = orig_width
value_dict["orig_height"] = orig_height
params.orig_width = int(orig_width)
params.orig_height = int(orig_height)
if key == "crop_coords_top_left":
crop_coord_top = st.number_input("crop_coords_top", value=0, min_value=0)
crop_coord_left = st.number_input("crop_coords_left", value=0, min_value=0)
crop_coord_top = st.number_input(
"crop_coords_top", value=params.crop_coords_top, min_value=0
)
crop_coord_left = st.number_input(
"crop_coords_left", value=params.crop_coords_left, min_value=0
)
value_dict["crop_coords_top"] = crop_coord_top
value_dict["crop_coords_left"] = crop_coord_left
if key == "aesthetic_score":
value_dict["aesthetic_score"] = 6.0
value_dict["negative_aesthetic_score"] = 2.5
if key == "target_size_as_tuple":
value_dict["target_width"] = init_dict["target_width"]
value_dict["target_height"] = init_dict["target_height"]
return value_dict
params.crop_coords_top = int(crop_coord_top)
params.crop_coords_left = int(crop_coord_left)
return params
def perform_save_locally(save_path, samples):
@@ -146,24 +131,18 @@ def show_samples(samples, outputs):
outputs.image(grid.cpu().numpy())
def get_guider(key):
guider = st.sidebar.selectbox(
f"Discretization #{key}",
[
"VanillaCFG",
"IdentityGuider",
],
def get_guider(key, params: SamplingParams) -> SamplingParams:
params.guider = Guider(
st.sidebar.selectbox(
f"Discretization #{key}", [member.value for member in Guider]
)
)
if guider == "IdentityGuider":
guider_config = {
"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
}
elif guider == "VanillaCFG":
if params.guider == Guider.VANILLA:
scale = st.number_input(
f"cfg-scale #{key}", value=5.0, min_value=0.0, max_value=100.0
f"cfg-scale #{key}", value=params.scale, min_value=0.0, max_value=100.0
)
params.scale = scale
thresholder = st.sidebar.selectbox(
f"Thresholder #{key}",
[
@@ -172,173 +151,97 @@ def get_guider(key):
)
if thresholder == "None":
dyn_thresh_config = {
"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
}
params.thresholder = Thresholder.NONE
else:
raise NotImplementedError
guider_config = {
"target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
"params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
}
else:
raise NotImplementedError
return guider_config
return params
def init_sampling(
key=1,
img2img_strength=1.0,
params: SamplingParams = SamplingParams(),
specify_num_samples=True,
stage2strength=None,
):
) -> Tuple[SamplingParams, int, int]:
params = SamplingParams(img2img_strength=params.img2img_strength)
num_rows, num_cols = 1, 1
if specify_num_samples:
num_cols = st.number_input(
f"num cols #{key}", value=2, min_value=1, max_value=10
)
steps = st.sidebar.number_input(
f"steps #{key}", value=40, min_value=1, max_value=1000
)
sampler = st.sidebar.selectbox(
f"Sampler #{key}",
[
"EulerEDMSampler",
"HeunEDMSampler",
"EulerAncestralSampler",
"DPMPP2SAncestralSampler",
"DPMPP2MSampler",
"LinearMultistepSampler",
],
0,
)
discretization = st.sidebar.selectbox(
f"Discretization #{key}",
[
"LegacyDDPMDiscretization",
"EDMDiscretization",
],
params.steps = int(
st.sidebar.number_input(
f"steps #{key}", value=params.steps, min_value=1, max_value=1000
)
)
discretization_config = get_discretization(discretization, key=key)
guider_config = get_guider(key=key)
sampler = get_sampler(sampler, steps, discretization_config, guider_config, key=key)
if img2img_strength < 1.0:
st.warning(
f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper"
params.sampler = Sampler(
st.sidebar.selectbox(
f"Sampler #{key}",
[member.value for member in Sampler],
0,
)
sampler.discretization = Img2ImgDiscretizationWrapper(
sampler.discretization, strength=img2img_strength
)
params.discretization = Discretization(
st.sidebar.selectbox(
f"Discretization #{key}",
[member.value for member in Discretization],
)
if stage2strength is not None:
sampler.discretization = Txt2NoisyDiscretizationWrapper(
sampler.discretization, strength=stage2strength, original_steps=steps
)
params = get_discretization(params, key=key)
params = get_guider(key=key, params=params)
params = get_sampler(params, key=key)
return params, num_rows, num_cols
def get_discretization(params: SamplingParams, key=1) -> SamplingParams:
    """Fill discretization-specific fields of ``params`` from Streamlit inputs.

    Extra inputs are shown only when ``params.discretization`` is
    ``Discretization.EDM``; for any other value ``params`` is returned as is.

    Args:
        params: Sampling parameters; updated in place.
        key: Value appended to every widget label as ``#<key>``.

    Returns:
        The same ``params`` object that was passed in.
    """
    if params.discretization != Discretization.EDM:
        return params

    # Each field's current value seeds its own input; inputs are created in
    # the order sigma_min, sigma_max, rho.
    sigma_min = st.number_input(f"sigma_min #{key}", value=params.sigma_min)  # 0.0292
    params.sigma_min = sigma_min

    sigma_max = st.number_input(f"sigma_max #{key}", value=params.sigma_max)  # 14.6146
    params.sigma_max = sigma_max

    rho = st.number_input(f"rho #{key}", value=params.rho)
    params.rho = rho

    return params
def get_sampler(params: SamplingParams, key=1) -> SamplingParams:
if params.sampler == Sampler.EULER_EDM or params.sampler == Sampler.HEUN_EDM:
params.s_churn = st.sidebar.number_input(
f"s_churn #{key}", value=params.s_churn, min_value=0.0
)
params.s_tmin = st.sidebar.number_input(
f"s_tmin #{key}", value=params.s_tmin, min_value=0.0
)
params.s_tmax = st.sidebar.number_input(
f"s_tmax #{key}", value=params.s_tmax, min_value=0.0
)
params.s_noise = st.sidebar.number_input(
f"s_noise #{key}", value=params.s_noise, min_value=0.0
)
return sampler, num_rows, num_cols
def get_discretization(discretization, key=1):
if discretization == "LegacyDDPMDiscretization":
discretization_config = {
"target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
}
elif discretization == "EDMDiscretization":
sigma_min = st.number_input(f"sigma_min #{key}", value=0.03) # 0.0292
sigma_max = st.number_input(f"sigma_max #{key}", value=14.61) # 14.6146
rho = st.number_input(f"rho #{key}", value=3.0)
discretization_config = {
"target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
"params": {
"sigma_min": sigma_min,
"sigma_max": sigma_max,
"rho": rho,
},
}
return discretization_config
def get_sampler(sampler_name, steps, discretization_config, guider_config, key=1):
if sampler_name == "EulerEDMSampler" or sampler_name == "HeunEDMSampler":
s_churn = st.sidebar.number_input(f"s_churn #{key}", value=0.0, min_value=0.0)
s_tmin = st.sidebar.number_input(f"s_tmin #{key}", value=0.0, min_value=0.0)
s_tmax = st.sidebar.number_input(f"s_tmax #{key}", value=999.0, min_value=0.0)
s_noise = st.sidebar.number_input(f"s_noise #{key}", value=1.0, min_value=0.0)
if sampler_name == "EulerEDMSampler":
sampler = EulerEDMSampler(
num_steps=steps,
discretization_config=discretization_config,
guider_config=guider_config,
s_churn=s_churn,
s_tmin=s_tmin,
s_tmax=s_tmax,
s_noise=s_noise,
verbose=True,
)
elif sampler_name == "HeunEDMSampler":
sampler = HeunEDMSampler(
num_steps=steps,
discretization_config=discretization_config,
guider_config=guider_config,
s_churn=s_churn,
s_tmin=s_tmin,
s_tmax=s_tmax,
s_noise=s_noise,
verbose=True,
)
elif (
sampler_name == "EulerAncestralSampler"
or sampler_name == "DPMPP2SAncestralSampler"
params.sampler == Sampler.EULER_ANCESTRAL
or params.sampler == Sampler.DPMPP2S_ANCESTRAL
):
s_noise = st.sidebar.number_input("s_noise", value=1.0, min_value=0.0)
eta = st.sidebar.number_input("eta", value=1.0, min_value=0.0)
if sampler_name == "EulerAncestralSampler":
sampler = EulerAncestralSampler(
num_steps=steps,
discretization_config=discretization_config,
guider_config=guider_config,
eta=eta,
s_noise=s_noise,
verbose=True,
)
elif sampler_name == "DPMPP2SAncestralSampler":
sampler = DPMPP2SAncestralSampler(
num_steps=steps,
discretization_config=discretization_config,
guider_config=guider_config,
eta=eta,
s_noise=s_noise,
verbose=True,
)
elif sampler_name == "DPMPP2MSampler":
sampler = DPMPP2MSampler(
num_steps=steps,
discretization_config=discretization_config,
guider_config=guider_config,
verbose=True,
params.s_noise = st.sidebar.number_input(
"s_noise", value=params.s_noise, min_value=0.0
)
elif sampler_name == "LinearMultistepSampler":
order = st.sidebar.number_input("order", value=4, min_value=1)
sampler = LinearMultistepSampler(
num_steps=steps,
discretization_config=discretization_config,
guider_config=guider_config,
order=order,
verbose=True,
params.eta = st.sidebar.number_input("eta", value=params.eta, min_value=0.0)
elif params.sampler == Sampler.LINEAR_MULTISTEP:
params.order = int(
st.sidebar.number_input("order", value=params.order, min_value=1)
)
else:
raise ValueError(f"unknown sampler {sampler_name}!")
return sampler
return params
def get_interactive_image(key=None) -> Image.Image:
def get_interactive_image(key=None) -> Optional[Image.Image]:
image = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
if image is not None:
image = Image.open(image)
@@ -347,7 +250,7 @@ def get_interactive_image(key=None) -> Image.Image:
return image
def load_img(display=True, key=None):
def load_img(display=True, key=None) -> torch.Tensor:
image = get_interactive_image(key=key)
if image is None:
return None

View File

@@ -22,12 +22,12 @@ from typing import Optional
class ModelArchitecture(str, Enum):
SD_2_1 = "stable-diffusion-v2-1"
SD_2_1_768 = "stable-diffusion-v2-1-768"
SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base"
SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner"
SDXL_V1_BASE = "stable-diffusion-xl-v1-base"
SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner"
SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base"
SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner"
SD_2_1 = "stable-diffusion-v2-1"
SD_2_1_768 = "stable-diffusion-v2-1-768"
class Sampler(str, Enum):
@@ -58,7 +58,7 @@ class SamplingParams:
width: int = 1024
height: int = 1024
steps: int = 40
sampler: Sampler = Sampler.DPMPP2M
sampler: Sampler = Sampler.EULER_EDM
discretization: Discretization = Discretization.LEGACY_DDPM
guider: Guider = Guider.VANILLA
thresholder: Thresholder = Thresholder.NONE
@@ -227,6 +227,7 @@ class SamplingPipeline:
samples: int = 1,
return_latents: bool = False,
noise_strength=None,
filter=None,
):
sampler = get_sampler_config(params)
@@ -253,7 +254,7 @@ class SamplingPipeline:
self.specs.factor,
force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
return_latents=return_latents,
filter=None,
filter=filter,
)
def image_to_image(
@@ -265,6 +266,7 @@ class SamplingPipeline:
samples: int = 1,
return_latents: bool = False,
noise_strength=None,
filter=None,
):
sampler = get_sampler_config(params)
@@ -289,7 +291,7 @@ class SamplingPipeline:
samples,
force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
return_latents=return_latents,
filter=None,
filter=filter,
)
def wrap_discretization(
@@ -327,6 +329,8 @@ class SamplingPipeline:
),
samples: int = 1,
return_latents: bool = False,
filter=None,
add_noise=False,
):
sampler = get_sampler_config(params)
value_dict = {
@@ -354,8 +358,8 @@ class SamplingPipeline:
samples,
skip_encode=True,
return_latents=return_latents,
filter=None,
add_noise=False,
filter=filter,
add_noise=add_noise,
)