diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py index 29cb305..d638717 100644 --- a/scripts/demo/sampling.py +++ b/scripts/demo/sampling.py @@ -1,8 +1,13 @@ +from dataclasses import asdict from pytorch_lightning import seed_everything +from sgm.inference.api import ( + SamplingParams, + ModelArchitecture, + SamplingPipeline, + model_specs, +) from sgm.inference.helpers import ( - do_img2img, - do_sample, get_unique_embedder_keys_from_conditioner, perform_save_locally, ) @@ -39,63 +44,6 @@ SD_XL_BASE_RATIOS = { "3.0": (1728, 576), } -VERSION2SPECS = { - "SDXL-base-1.0": { - "H": 1024, - "W": 1024, - "C": 4, - "f": 8, - "is_legacy": False, - "config": "configs/inference/sd_xl_base.yaml", - "ckpt": "checkpoints/sd_xl_base_1.0.safetensors", - }, - "SDXL-base-0.9": { - "H": 1024, - "W": 1024, - "C": 4, - "f": 8, - "is_legacy": False, - "config": "configs/inference/sd_xl_base.yaml", - "ckpt": "checkpoints/sd_xl_base_0.9.safetensors", - }, - "SD-2.1": { - "H": 512, - "W": 512, - "C": 4, - "f": 8, - "is_legacy": True, - "config": "configs/inference/sd_2_1.yaml", - "ckpt": "checkpoints/v2-1_512-ema-pruned.safetensors", - }, - "SD-2.1-768": { - "H": 768, - "W": 768, - "C": 4, - "f": 8, - "is_legacy": True, - "config": "configs/inference/sd_2_1_768.yaml", - "ckpt": "checkpoints/v2-1_768-ema-pruned.safetensors", - }, - "SDXL-refiner-0.9": { - "H": 1024, - "W": 1024, - "C": 4, - "f": 8, - "is_legacy": True, - "config": "configs/inference/sd_xl_refiner.yaml", - "ckpt": "checkpoints/sd_xl_refiner_0.9.safetensors", - }, - "SDXL-refiner-1.0": { - "H": 1024, - "W": 1024, - "C": 4, - "f": 8, - "is_legacy": True, - "config": "configs/inference/sd_xl_refiner.yaml", - "ckpt": "checkpoints/sd_xl_refiner_1.0.safetensors", - }, -} - def load_img(display=True, key=None, device="cuda"): image = get_interactive_image(key=key) @@ -117,52 +65,48 @@ def load_img(display=True, key=None, device="cuda"): def run_txt2img( state, - version, - version_dict, - is_legacy=False, + version: str, + prompt: str, + negative_prompt: str, return_latents=False, - filter=None, stage2strength=None, ): - if version.startswith("SDXL-base"): - W, H = st.selectbox("Resolution:", list(SD_XL_BASE_RATIOS.values()), 10) + spec: SamplingSpec = state.get("spec") + model: SamplingPipeline = state.get("model") + params: SamplingParams = state.get("params") + if version.startswith("stable-diffusion-xl") and version.endswith("-base"): + params.width, params.height = st.selectbox( + "Resolution:", list(SD_XL_BASE_RATIOS.values()), 10 + ) else: - H = st.number_input("H", value=version_dict["H"], min_value=64, max_value=2048) - W = st.number_input("W", value=version_dict["W"], min_value=64, max_value=2048) - C = version_dict["C"] - F = version_dict["f"] + params.height = int( + st.number_input("H", value=spec.height, min_value=64, max_value=2048) + ) + params.width = int( + st.number_input("W", value=spec.width, min_value=64, max_value=2048) + ) - init_dict = { - "orig_width": W, - "orig_height": H, - "target_width": W, - "target_height": H, - } - value_dict = init_embedder_options( - get_unique_embedder_keys_from_conditioner(state["model"].conditioner), - init_dict, + params = init_embedder_options( + get_unique_embedder_keys_from_conditioner(model.model.conditioner), + params=params, prompt=prompt, negative_prompt=negative_prompt, ) - sampler, num_rows, num_cols = init_sampling(stage2strength=stage2strength) + params, num_rows, num_cols = init_sampling(params=params) num_samples = num_rows * num_cols if st.button("Sample"): st.write(f"**Model I:** {version}") outputs = st.empty() st.text("Sampling") - out = do_sample( - state["model"], - sampler, - value_dict, - num_samples, - H, - W, - C, - F, - force_uc_zero_embeddings=["txt"] if not is_legacy else [], + out = model.text_to_image( + params=params, + prompt=prompt, + negative_prompt=negative_prompt, + samples=int(num_samples), return_latents=return_latents, - filter=filter, + noise_strength=stage2strength, + filter=state.get("filter"), ) show_samples(out, outputs) @@ -172,51 +116,45 @@ def run_txt2img( def run_img2img( state, - version_dict, - is_legacy=False, + prompt: str, + negative_prompt: str, return_latents=False, - filter=None, stage2strength=None, ): + model: SamplingPipeline = state.get("model") + params: SamplingParams = state.get("params") + img = load_img() if img is None: return None - H, W = img.shape[2], img.shape[3] + params.height, params.width = img.shape[2], img.shape[3] - init_dict = { - "orig_width": W, - "orig_height": H, - "target_width": W, - "target_height": H, - } - value_dict = init_embedder_options( - get_unique_embedder_keys_from_conditioner(state["model"].conditioner), - init_dict, + params = init_embedder_options( + get_unique_embedder_keys_from_conditioner(model.model.conditioner), + params=params, prompt=prompt, negative_prompt=negative_prompt, ) - strength = st.number_input( + params.img2img_strength = st.number_input( "**Img2Img Strength**", value=0.75, min_value=0.0, max_value=1.0 ) - sampler, num_rows, num_cols = init_sampling( - img2img_strength=strength, - stage2strength=stage2strength, - ) + params, num_rows, num_cols = init_sampling(params=params) num_samples = num_rows * num_cols if st.button("Sample"): outputs = st.empty() st.text("Sampling") - out = do_img2img( - repeat(img, "1 ... -> n ...", n=num_samples), - state["model"], - sampler, - value_dict, - num_samples, - force_uc_zero_embeddings=["txt"] if not is_legacy else [], + out = model.image_to_image( + image=repeat(img, "1 ... -> n ...", n=num_samples), + params=params, + prompt=prompt, + negative_prompt=negative_prompt, + samples=int(num_samples), return_latents=return_latents, - filter=filter, + noise_strength=stage2strength, + filter=state.get("filter"), ) + show_samples(out, outputs) return out @@ -224,39 +162,29 @@ def run_img2img( def apply_refiner( input, state, - sampler, - num_samples, - prompt, - negative_prompt, - filter=None, + num_samples: int, + prompt: str, + negative_prompt: str, finish_denoising=False, ): - init_dict = { - "orig_width": input.shape[3] * 8, - "orig_height": input.shape[2] * 8, - "target_width": input.shape[3] * 8, - "target_height": input.shape[2] * 8, - } + model: SamplingPipeline = state.get("model") + params: SamplingParams = state.get("params") - value_dict = init_dict - value_dict["prompt"] = prompt - value_dict["negative_prompt"] = negative_prompt - - value_dict["crop_coords_top"] = 0 - value_dict["crop_coords_left"] = 0 - - value_dict["aesthetic_score"] = 6.0 - value_dict["negative_aesthetic_score"] = 2.5 + params.orig_width = input.shape[3] * 8 + params.orig_height = input.shape[2] * 8 + params.width = input.shape[3] * 8 + params.height = input.shape[2] * 8 st.warning(f"refiner input shape: {input.shape}") - samples = do_img2img( - input, - state["model"], - sampler, - value_dict, - num_samples, - skip_encode=True, - filter=filter, + + samples = model.refiner( + image=input, + params=params, + prompt=prompt, + negative_prompt=negative_prompt, + samples=num_samples, + return_latents=False, + filter=state.get("filter"), add_noise=not finish_denoising, ) @@ -265,28 +193,34 @@ def apply_refiner( if __name__ == "__main__": st.title("Stable Diffusion") - version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0) - version_dict = VERSION2SPECS[version] + version = st.selectbox( + "Model Version", + [member.value for member in ModelArchitecture], + 0, + ) + version_enum = ModelArchitecture(version) + specs = model_specs[version_enum] mode = st.radio("Mode", ("txt2img", "img2img"), 0) st.write("__________________________") set_lowvram_mode(st.checkbox("Low vram mode", True)) - if version.startswith("SDXL-base"): + if str(version).startswith("stable-diffusion-xl"): add_pipeline = st.checkbox("Load SDXL-refiner?", False) st.write("__________________________") else: add_pipeline = False - seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9)) + seed = int( + st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9)) + ) seed_everything(seed) - save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version)) - - state = init_st(version_dict, load_filter=True) + save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, str(version))) + state = init_st(model_specs[version_enum], load_filter=True) model = state["model"] - is_legacy = version_dict["is_legacy"] + is_legacy = specs.is_legacy prompt = st.text_input( "prompt", @@ -302,46 +236,59 @@ if __name__ == "__main__": if add_pipeline: st.write("__________________________") - version2 = st.selectbox("Refiner:", ["SDXL-refiner-1.0", "SDXL-refiner-0.9"]) + version2 = ModelArchitecture( + st.selectbox( + "Refiner:", + [ + ModelArchitecture.SDXL_V1_REFINER.value, + ModelArchitecture.SDXL_V0_9_REFINER.value, + ], + ) + ) st.warning( f"Running with {version2} as the second stage model. Make sure to provide (V)RAM :) " ) st.write("**Refiner Options:**") - version_dict2 = VERSION2SPECS[version2] - state2 = init_st(version_dict2, load_filter=False) + specs2 = model_specs[version2] + state2 = init_st(specs2, load_filter=False) + params2 = state2["params"] - stage2strength = st.number_input( + params2.img2img_strength = st.number_input( "**Refinement strength**", value=0.15, min_value=0.0, max_value=1.0 ) - sampler2, *_ = init_sampling( + params2, *_ = init_sampling( key=2, - img2img_strength=stage2strength, + params=state2["params"], specify_num_samples=False, ) st.write("__________________________") finish_denoising = st.checkbox("Finish denoising with refiner.", True) - if not finish_denoising: + if finish_denoising: + stage2strength = params2.img2img_strength + else: stage2strength = None + else: + state2 = None + params2 = None + stage2strength = None if mode == "txt2img": out = run_txt2img( - state, - version, - version_dict, - is_legacy=is_legacy, + state=state, + version=str(version), + prompt=prompt, + negative_prompt=negative_prompt, return_latents=add_pipeline, - filter=state.get("filter"), stage2strength=stage2strength, ) elif mode == "img2img": out = run_img2img( - state, - version_dict, - is_legacy=is_legacy, + state=state, + prompt=prompt, + negative_prompt=negative_prompt, return_latents=add_pipeline, - filter=state.get("filter"), stage2strength=stage2strength, ) else: @@ -356,13 +303,11 @@ if __name__ == "__main__": outputs = st.empty() st.write("**Running Refinement Stage**") samples = apply_refiner( - samples_z, - state2, - sampler2, - samples_z.shape[0], + input=samples_z, + state=state2, + num_samples=samples_z.shape[0], prompt=prompt, negative_prompt=negative_prompt if is_legacy else "", - filter=state.get("filter"), finish_denoising=finish_denoising, ) show_samples(samples, outputs) diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index d3dd7d7..9a90f3b 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -4,43 +4,46 @@ import numpy as np import streamlit as st import torch from einops import rearrange, repeat -from omegaconf import OmegaConf from PIL import Image from torchvision import transforms +from typing import Optional, Tuple from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering -from sgm.modules.diffusionmodules.sampling import ( - DPMPP2MSampler, - DPMPP2SAncestralSampler, - EulerAncestralSampler, - EulerEDMSampler, - HeunEDMSampler, - LinearMultistepSampler, + +from sgm.inference.api import ( + Discretization, + Guider, + Sampler, + SamplingParams, + SamplingSpec, + SamplingPipeline, + Thresholder, ) from sgm.inference.helpers import ( - Img2ImgDiscretizationWrapper, - Txt2NoisyDiscretizationWrapper, embed_watermark, ) -from sgm.util import load_model_from_config @st.cache_resource() -def init_st(version_dict, load_ckpt=True, load_filter=True): +def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True): + global lowvram_mode state = dict() if not "model" in state: - config = version_dict["config"] - ckpt = version_dict["ckpt"] + config = spec.config + ckpt = spec.ckpt - config = OmegaConf.load(config) - model = load_model_from_config( - config, ckpt if load_ckpt else None, freeze=False + pipeline = SamplingPipeline( + model_spec=spec, + use_fp16=lowvram_mode, + device="cpu" if lowvram_mode else "cuda", ) - state["model"] = model + state["spec"] = spec + state["model"] = pipeline state["ckpt"] = ckpt if load_ckpt else None state["config"] = config + state["params"] = SamplingParams() if load_filter: state["filter"] = DeepFloydDataFiltering(verbose=False) return state @@ -54,23 +57,13 @@ def set_lowvram_mode(mode): lowvram_mode = mode -def initial_model_load(model): - global lowvram_mode - if lowvram_mode: - model.model.half() - else: - model.cuda() - return model - - def get_unique_embedder_keys_from_conditioner(conditioner): return list(set([x.input_key for x in conditioner.embedders])) -def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None): - # Hardcoded demo settings; might undergo some changes in the future - - value_dict = {} +def init_embedder_options( + keys, params: SamplingParams, prompt=None, negative_prompt=None +) -> SamplingParams: for key in keys: if key == "txt": if prompt is None: @@ -80,40 +73,32 @@ def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None): if negative_prompt is None: negative_prompt = st.text_input("Negative prompt", "") - value_dict["prompt"] = prompt - value_dict["negative_prompt"] = negative_prompt - if key == "original_size_as_tuple": orig_width = st.number_input( "orig_width", - value=init_dict["orig_width"], + value=params.orig_width, min_value=16, ) orig_height = st.number_input( "orig_height", - value=init_dict["orig_height"], + value=params.orig_height, min_value=16, ) - value_dict["orig_width"] = orig_width - value_dict["orig_height"] = orig_height + params.orig_width = int(orig_width) + params.orig_height = int(orig_height) if key == "crop_coords_top_left": - crop_coord_top = st.number_input("crop_coords_top", value=0, min_value=0) - crop_coord_left = st.number_input("crop_coords_left", value=0, min_value=0) + crop_coord_top = st.number_input( + "crop_coords_top", value=params.crop_coords_top, min_value=0 + ) + crop_coord_left = st.number_input( + "crop_coords_left", value=params.crop_coords_left, min_value=0 + ) - value_dict["crop_coords_top"] = crop_coord_top - value_dict["crop_coords_left"] = crop_coord_left - - if key == "aesthetic_score": - value_dict["aesthetic_score"] = 6.0 - value_dict["negative_aesthetic_score"] = 2.5 - - if key == "target_size_as_tuple": - value_dict["target_width"] = init_dict["target_width"] - value_dict["target_height"] = init_dict["target_height"] - - return value_dict + params.crop_coords_top = int(crop_coord_top) + params.crop_coords_left = int(crop_coord_left) + return params def perform_save_locally(save_path, samples): @@ -146,24 +131,18 @@ def show_samples(samples, outputs): outputs.image(grid.cpu().numpy()) -def get_guider(key): - guider = st.sidebar.selectbox( - f"Discretization #{key}", - [ - "VanillaCFG", - "IdentityGuider", - ], +def get_guider(key, params: SamplingParams) -> SamplingParams: + params.guider = Guider( + st.sidebar.selectbox( + f"Discretization #{key}", [member.value for member in Guider] + ) ) - if guider == "IdentityGuider": - guider_config = { - "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider" - } - elif guider == "VanillaCFG": + if params.guider == Guider.VANILLA: scale = st.number_input( - f"cfg-scale #{key}", value=5.0, min_value=0.0, max_value=100.0 + f"cfg-scale #{key}", value=params.scale, min_value=0.0, max_value=100.0 ) - + params.scale = scale thresholder = st.sidebar.selectbox( f"Thresholder #{key}", [ @@ -172,173 +151,97 @@ def get_guider(key): ) if thresholder == "None": - dyn_thresh_config = { - "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding" - } + params.thresholder = Thresholder.NONE else: raise NotImplementedError - - guider_config = { - "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG", - "params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config}, - } - else: - raise NotImplementedError - return guider_config + return params def init_sampling( key=1, - img2img_strength=1.0, + params: SamplingParams = SamplingParams(), specify_num_samples=True, - stage2strength=None, -): +) -> Tuple[SamplingParams, int, int]: + params = SamplingParams(img2img_strength=params.img2img_strength) + num_rows, num_cols = 1, 1 if specify_num_samples: num_cols = st.number_input( f"num cols #{key}", value=2, min_value=1, max_value=10 ) - steps = st.sidebar.number_input( - f"steps #{key}", value=40, min_value=1, max_value=1000 - ) - sampler = st.sidebar.selectbox( - f"Sampler #{key}", - [ - "EulerEDMSampler", - "HeunEDMSampler", - "EulerAncestralSampler", - "DPMPP2SAncestralSampler", - "DPMPP2MSampler", - "LinearMultistepSampler", - ], - 0, - ) - discretization = st.sidebar.selectbox( - f"Discretization #{key}", - [ - "LegacyDDPMDiscretization", - "EDMDiscretization", - ], + params.steps = int( + st.sidebar.number_input( + f"steps #{key}", value=params.steps, min_value=1, max_value=1000 + ) ) - discretization_config = get_discretization(discretization, key=key) - - guider_config = get_guider(key=key) - - sampler = get_sampler(sampler, steps, discretization_config, guider_config, key=key) - if img2img_strength < 1.0: - st.warning( - f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper" + params.sampler = Sampler( + st.sidebar.selectbox( + f"Sampler #{key}", + [member.value for member in Sampler], + 0, ) - sampler.discretization = Img2ImgDiscretizationWrapper( - sampler.discretization, strength=img2img_strength + ) + params.discretization = Discretization( + st.sidebar.selectbox( + f"Discretization #{key}", + [member.value for member in Discretization], ) - if stage2strength is not None: - sampler.discretization = Txt2NoisyDiscretizationWrapper( - sampler.discretization, strength=stage2strength, original_steps=steps + ) + + params = get_discretization(params, key=key) + + params = get_guider(key=key, params=params) + + params = get_sampler(params, key=key) + return params, num_rows, num_cols + + +def get_discretization(params: SamplingParams, key=1) -> SamplingParams: + if params.discretization == Discretization.EDM: + params.sigma_min = st.number_input( + f"sigma_min #{key}", value=params.sigma_min + ) # 0.0292 + params.sigma_max = st.number_input( + f"sigma_max #{key}", value=params.sigma_max + ) # 14.6146 + params.rho = st.number_input(f"rho #{key}", value=params.rho) + return params + + +def get_sampler(params: SamplingParams, key=1) -> SamplingParams: + if params.sampler == Sampler.EULER_EDM or params.sampler == Sampler.HEUN_EDM: + params.s_churn = st.sidebar.number_input( + f"s_churn #{key}", value=params.s_churn, min_value=0.0 + ) + params.s_tmin = st.sidebar.number_input( + f"s_tmin #{key}", value=params.s_tmin, min_value=0.0 + ) + params.s_tmax = st.sidebar.number_input( + f"s_tmax #{key}", value=params.s_tmax, min_value=0.0 + ) + params.s_noise = st.sidebar.number_input( + f"s_noise #{key}", value=params.s_noise, min_value=0.0 ) - return sampler, num_rows, num_cols - -def get_discretization(discretization, key=1): - if discretization == "LegacyDDPMDiscretization": - discretization_config = { - "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization", - } - elif discretization == "EDMDiscretization": - sigma_min = st.number_input(f"sigma_min #{key}", value=0.03) # 0.0292 - sigma_max = st.number_input(f"sigma_max #{key}", value=14.61) # 14.6146 - rho = st.number_input(f"rho #{key}", value=3.0) - discretization_config = { - "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization", - "params": { - "sigma_min": sigma_min, - "sigma_max": sigma_max, - "rho": rho, - }, - } - - return discretization_config - - -def get_sampler(sampler_name, steps, discretization_config, guider_config, key=1): - if sampler_name == "EulerEDMSampler" or sampler_name == "HeunEDMSampler": - s_churn = st.sidebar.number_input(f"s_churn #{key}", value=0.0, min_value=0.0) - s_tmin = st.sidebar.number_input(f"s_tmin #{key}", value=0.0, min_value=0.0) - s_tmax = st.sidebar.number_input(f"s_tmax #{key}", value=999.0, min_value=0.0) - s_noise = st.sidebar.number_input(f"s_noise #{key}", value=1.0, min_value=0.0) - - if sampler_name == "EulerEDMSampler": - sampler = EulerEDMSampler( - num_steps=steps, - discretization_config=discretization_config, - guider_config=guider_config, - s_churn=s_churn, - s_tmin=s_tmin, - s_tmax=s_tmax, - s_noise=s_noise, - verbose=True, - ) - elif sampler_name == "HeunEDMSampler": - sampler = HeunEDMSampler( - num_steps=steps, - discretization_config=discretization_config, - guider_config=guider_config, - s_churn=s_churn, - s_tmin=s_tmin, - s_tmax=s_tmax, - s_noise=s_noise, - verbose=True, - ) elif ( - sampler_name == "EulerAncestralSampler" - or sampler_name == "DPMPP2SAncestralSampler" + params.sampler == Sampler.EULER_ANCESTRAL + or params.sampler == Sampler.DPMPP2S_ANCESTRAL ): - s_noise = st.sidebar.number_input("s_noise", value=1.0, min_value=0.0) - eta = st.sidebar.number_input("eta", value=1.0, min_value=0.0) - - if sampler_name == "EulerAncestralSampler": - sampler = EulerAncestralSampler( - num_steps=steps, - discretization_config=discretization_config, - guider_config=guider_config, - eta=eta, - s_noise=s_noise, - verbose=True, - ) - elif sampler_name == "DPMPP2SAncestralSampler": - sampler = DPMPP2SAncestralSampler( - num_steps=steps, - discretization_config=discretization_config, - guider_config=guider_config, - eta=eta, - s_noise=s_noise, - verbose=True, - ) - elif sampler_name == "DPMPP2MSampler": - sampler = DPMPP2MSampler( - num_steps=steps, - discretization_config=discretization_config, - guider_config=guider_config, - verbose=True, + params.s_noise = st.sidebar.number_input( + "s_noise", value=params.s_noise, min_value=0.0 ) - elif sampler_name == "LinearMultistepSampler": - order = st.sidebar.number_input("order", value=4, min_value=1) - sampler = LinearMultistepSampler( - num_steps=steps, - discretization_config=discretization_config, - guider_config=guider_config, - order=order, - verbose=True, + params.eta = st.sidebar.number_input("eta", value=params.eta, min_value=0.0) + + elif params.sampler == Sampler.LINEAR_MULTISTEP: + params.order = int( + st.sidebar.number_input("order", value=params.order, min_value=1) ) - else: - raise ValueError(f"unknown sampler {sampler_name}!") - - return sampler + return params -def get_interactive_image(key=None) -> Image.Image: +def get_interactive_image(key=None) -> Optional[Image.Image]: image = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key) if image is not None: image = Image.open(image) @@ -347,7 +250,7 @@ def get_interactive_image(key=None) -> Image.Image: return image -def load_img(display=True, key=None): +def load_img(display=True, key=None) -> torch.Tensor: image = get_interactive_image(key=key) if image is None: return None diff --git a/sgm/inference/api.py b/sgm/inference/api.py index 89b7370..ad6aecc 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -22,12 +22,12 @@ from typing import Optional class ModelArchitecture(str, Enum): - SD_2_1 = "stable-diffusion-v2-1" - SD_2_1_768 = "stable-diffusion-v2-1-768" - SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base" - SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner" SDXL_V1_BASE = "stable-diffusion-xl-v1-base" SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner" + SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base" + SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner" + SD_2_1 = "stable-diffusion-v2-1" + SD_2_1_768 = "stable-diffusion-v2-1-768" class Sampler(str, Enum): @@ -58,7 +58,7 @@ class SamplingParams: width: int = 1024 height: int = 1024 steps: int = 40 - sampler: Sampler = Sampler.DPMPP2M + sampler: Sampler = Sampler.EULER_EDM discretization: Discretization = Discretization.LEGACY_DDPM guider: Guider = Guider.VANILLA thresholder: Thresholder = Thresholder.NONE @@ -227,6 +227,7 @@ class SamplingPipeline: samples: int = 1, return_latents: bool = False, noise_strength=None, + filter=None, ): sampler = get_sampler_config(params) @@ -253,7 +254,7 @@ class SamplingPipeline: self.specs.factor, force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [], return_latents=return_latents, - filter=None, + filter=filter, ) def image_to_image( @@ -265,6 +266,7 @@ class SamplingPipeline: samples: int = 1, return_latents: bool = False, noise_strength=None, + filter=None, ): sampler = get_sampler_config(params) @@ -289,7 +291,7 @@ class SamplingPipeline: samples, force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [], return_latents=return_latents, - filter=None, + filter=filter, ) def wrap_discretization( @@ -327,6 +329,8 @@ class SamplingPipeline: ), samples: int = 1, return_latents: bool = False, + filter=None, + add_noise=False, ): sampler = get_sampler_config(params) value_dict = { @@ -354,8 +358,8 @@ class SamplingPipeline: samples, skip_encode=True, return_latents=return_latents, - filter=None, - add_noise=False, + filter=filter, + add_noise=add_noise, )