From fbe93fc53b3407acab4cf3394b8c0645ced98a7c Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Sat, 12 Aug 2023 05:33:16 -0700 Subject: [PATCH] PR fixes, model specific defaults --- scripts/demo/sampling.py | 35 +++------ scripts/demo/streamlit_helpers.py | 22 ++---- sgm/inference/api.py | 115 ++++++++++++++++++------------ tests/inference/test_inference.py | 2 +- 4 files changed, 88 insertions(+), 86 deletions(-) diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py index 5415551..20a8f03 100644 --- a/scripts/demo/sampling.py +++ b/scripts/demo/sampling.py @@ -66,9 +66,7 @@ def load_img(display=True, key=None, device="cuda"): st.image(image) w, h = image.size print(f"loaded input image of size ({w}, {h})") - width, height = map( - lambda x: x - x % 64, (w, h) - ) # resize to integer multiple of 64 + width, height = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64 image = image.resize((width, height)) image = np.array(image.convert("RGB")) image = image[None].transpose(0, 3, 1, 2) @@ -78,26 +76,19 @@ def load_img(display=True, key=None, device="cuda"): def run_txt2img( state, - version: str, + model_id: ModelArchitecture, prompt: str, negative_prompt: str, return_latents=False, stage2strength=None, ): - spec: SamplingSpec = state["spec"] model: SamplingPipeline = state["model"] params: SamplingParams = state["params"] - if version.startswith("stable-diffusion-xl") and version.endswith("-base"): - width, height = st.selectbox( - "Resolution:", list(SD_XL_BASE_RATIOS.values()), 10 - ) + if model_id in sdxl_base_model_list: + width, height = st.selectbox("Resolution:", list(SD_XL_BASE_RATIOS.values()), 10) else: - height = int( - st.number_input("H", value=spec.height, min_value=64, max_value=2048) - ) - width = int( - st.number_input("W", value=spec.width, min_value=64, max_value=2048) - ) + height = int(st.number_input("H", value=params.height, min_value=64, max_value=2048)) + width = int(st.number_input("W", value=params.width, min_value=64, max_value=2048)) params = init_embedder_options( get_unique_embedder_keys_from_conditioner(model.model.conditioner), @@ -207,12 +198,12 @@ def apply_refiner( sdxl_base_model_list = [ - ModelArchitecture.SDXL_V1_BASE, + ModelArchitecture.SDXL_V1_0_BASE, ModelArchitecture.SDXL_V0_9_BASE, ] sdxl_refiner_model_list = [ - ModelArchitecture.SDXL_V1_REFINER, + ModelArchitecture.SDXL_V1_0_REFINER, ModelArchitecture.SDXL_V0_9_REFINER, ] @@ -239,9 +230,7 @@ if __name__ == "__main__": else: add_pipeline = False - seed = int( - st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9)) - ) + seed = int(st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))) seed_everything(seed) save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, str(version))) @@ -281,9 +270,7 @@ if __name__ == "__main__": st.write("**Refiner Options:**") specs2 = model_specs[version2] - state2 = init_st( - specs2, load_filter=False, use_fp16=use_fp16, enable_swap=enable_swap - ) + state2 = init_st(specs2, load_filter=False, use_fp16=use_fp16, enable_swap=enable_swap) params2 = state2["params"] params2.img2img_strength = st.number_input( @@ -309,7 +296,7 @@ if __name__ == "__main__": if mode == "txt2img": out = run_txt2img( state=state, - version=str(version), + model_id=version_enum, prompt=prompt, negative_prompt=negative_prompt, return_latents=add_pipeline, diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index 5c35069..a1770a8 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -48,7 +48,7 @@ def init_st( state["model"] = pipeline state["ckpt"] = ckpt if load_ckpt else None state["config"] = config - state["params"] = SamplingParams() + state["params"] = spec.default_params if load_filter: state["filter"] = DeepFloydDataFiltering(verbose=False) else: @@ -132,9 +132,7 @@ def show_samples(samples, outputs): def get_guider(params: SamplingParams, key=1) -> SamplingParams: params.guider = Guider( - st.sidebar.selectbox( - f"Discretization #{key}", [member.value for member in Guider] - ) + st.sidebar.selectbox(f"Discretization #{key}", [member.value for member in Guider]) ) if params.guider == Guider.VANILLA: @@ -165,14 +163,10 @@ def init_sampling( num_rows, num_cols = 1, 1 if specify_num_samples: - num_cols = st.number_input( - f"num cols #{key}", value=2, min_value=1, max_value=10 - ) + num_cols = st.number_input(f"num cols #{key}", value=2, min_value=1, max_value=10) params.steps = int( - st.sidebar.number_input( - f"steps #{key}", value=params.steps, min_value=1, max_value=1000 - ) + st.sidebar.number_input(f"steps #{key}", value=params.steps, min_value=1, max_value=1000) ) params.sampler = Sampler( @@ -220,15 +214,11 @@ def get_sampler(params: SamplingParams, key=1) -> SamplingParams: ) elif params.sampler in (Sampler.EULER_ANCESTRAL, Sampler.DPMPP2S_ANCESTRAL): - params.s_noise = st.sidebar.number_input( - "s_noise", value=params.s_noise, min_value=0.0 - ) + params.s_noise = st.sidebar.number_input("s_noise", value=params.s_noise, min_value=0.0) params.eta = st.sidebar.number_input("eta", value=params.eta, min_value=0.0) elif params.sampler == Sampler.LINEAR_MULTISTEP: - params.order = int( - st.sidebar.number_input("order", value=params.order, min_value=1) - ) + params.order = int(st.sidebar.number_input("order", value=params.order, min_value=1)) return params diff --git a/sgm/inference/api.py b/sgm/inference/api.py index 9ca1111..d2d5a7d 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -24,8 +24,8 @@ from typing import Optional, Dict, Any, Union class ModelArchitecture(str, Enum): - SDXL_V1_BASE = "stable-diffusion-xl-v1-base" - SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner" + SDXL_V1_0_BASE = "stable-diffusion-xl-v1-base" + SDXL_V1_0_REFINER = "stable-diffusion-xl-v1-refiner" SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base" SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner" SD_2_1 = "stable-diffusion-v2-1" @@ -59,24 +59,21 @@ class Thresholder(str, Enum): class SamplingParams: """ Parameters for sampling. - The defaults here are derived from user preference testing. - They will be subject to change in the future, likely pulled - from model specs instead of global defaults. """ - width: int = 1024 - height: int = 1024 - steps: int = 40 + width: int + height: int + steps: int sampler: Sampler = Sampler.EULER_EDM discretization: Discretization = Discretization.LEGACY_DDPM guider: Guider = Guider.VANILLA thresholder: Thresholder = Thresholder.NONE - scale: float = 5.0 + scale: float aesthetic_score: float = 6.0 negative_aesthetic_score: float = 2.5 img2img_strength: float = 1.0 - orig_width: int = width - orig_height: int = height + orig_width: int = 1024 + orig_height: int = 1024 crop_coords_top: int = 0 crop_coords_left: int = 0 sigma_min: float = 0.0292 @@ -100,8 +97,10 @@ class SamplingSpec: config: str ckpt: str is_guided: bool + default_params: SamplingParams +# The defaults here are derived from user preference testing. model_specs = { ModelArchitecture.SD_2_1: SamplingSpec( height=512, @@ -112,6 +111,12 @@ model_specs = { config="sd_2_1.yaml", ckpt="v2-1_512-ema-pruned.safetensors", is_guided=True, + default_params=SamplingParams( + width=512, + height=512, + steps=40, + scale=7.0, + ), ), ModelArchitecture.SD_2_1_768: SamplingSpec( height=768, @@ -122,6 +127,12 @@ model_specs = { config="sd_2_1_768.yaml", ckpt="v2-1_768-ema-pruned.safetensors", is_guided=True, + default_params=SamplingParams( + width=768, + height=768, + steps=40, + scale=7.0, + ), ), ModelArchitecture.SDXL_V0_9_BASE: SamplingSpec( height=1024, @@ -132,6 +143,7 @@ model_specs = { config="sd_xl_base.yaml", ckpt="sd_xl_base_0.9.safetensors", is_guided=True, + default_params=SamplingParams(width=1024, height=1024, steps=40, scale=5.0), ), ModelArchitecture.SDXL_V0_9_REFINER: SamplingSpec( height=1024, @@ -142,8 +154,11 @@ model_specs = { config="sd_xl_refiner.yaml", ckpt="sd_xl_refiner_0.9.safetensors", is_guided=True, + default_params=SamplingParams( + width=1024, height=1024, steps=40, scale=5.0, img2img_strength=0.15 + ), ), - ModelArchitecture.SDXL_V1_BASE: SamplingSpec( + ModelArchitecture.SDXL_V1_0_BASE: SamplingSpec( height=1024, width=1024, channels=4, @@ -152,8 +167,9 @@ model_specs = { config="sd_xl_base.yaml", ckpt="sd_xl_base_1.0.safetensors", is_guided=True, + default_params=SamplingParams(width=1024, height=1024, steps=40, scale=5.0), ), - ModelArchitecture.SDXL_V1_REFINER: SamplingSpec( + ModelArchitecture.SDXL_V1_0_REFINER: SamplingSpec( height=1024, width=1024, channels=4, @@ -162,10 +178,39 @@ model_specs = { config="sd_xl_refiner.yaml", ckpt="sd_xl_refiner_1.0.safetensors", is_guided=True, + default_params=SamplingParams( + width=1024, height=1024, steps=40, scale=5.0, img2img_strength=0.15 + ), ), } +def wrap_discretization( + discretization, image_strength=None, noise_strength=None, steps=None +): + if isinstance(discretization, Img2ImgDiscretizationWrapper) or isinstance( + discretization, Txt2NoisyDiscretizationWrapper + ): + return discretization # Already wrapped + if image_strength is not None and image_strength < 1.0 and image_strength > 0.0: + discretization = Img2ImgDiscretizationWrapper( + discretization, strength=image_strength + ) + + if ( + noise_strength is not None + and noise_strength < 1.0 + and noise_strength > 0.0 + and steps is not None + ): + discretization = Txt2NoisyDiscretizationWrapper( + discretization, + strength=noise_strength, + original_steps=steps, + ) + return discretization + + class SamplingPipeline: def __init__( self, @@ -231,17 +276,19 @@ class SamplingPipeline: def text_to_image( self, - params: SamplingParams, prompt: str, + params: Optional[SamplingParams] = None, negative_prompt: str = "", samples: int = 1, return_latents: bool = False, noise_strength: Optional[float] = None, filter=None, ): + if params is None: + params = self.specs.default_params sampler = get_sampler_config(params) - sampler.discretization = self.wrap_discretization( + sampler.discretization = wrap_discretization( sampler.discretization, image_strength=None, noise_strength=noise_strength, @@ -270,18 +317,20 @@ class SamplingPipeline: def image_to_image( self, - params: SamplingParams, image: torch.Tensor, prompt: str, + params: Optional[SamplingParams] = None, negative_prompt: str = "", samples: int = 1, return_latents: bool = False, noise_strength: Optional[float] = None, filter=None, ): + if params is None: + params = self.specs.default_params sampler = get_sampler_config(params) - sampler.discretization = self.wrap_discretization( + sampler.discretization = wrap_discretization( sampler.discretization, image_strength=params.img2img_strength, noise_strength=noise_strength, @@ -308,44 +357,20 @@ class SamplingPipeline: device=self.device_manager, ) - def wrap_discretization( - self, discretization, image_strength=None, noise_strength=None, steps=None - ): - if isinstance(discretization, Img2ImgDiscretizationWrapper) or isinstance( - discretization, Txt2NoisyDiscretizationWrapper - ): - return discretization # Already wrapped - if image_strength is not None and image_strength < 1.0 and image_strength > 0.0: - discretization = Img2ImgDiscretizationWrapper( - discretization, strength=image_strength - ) - - if ( - noise_strength is not None - and noise_strength < 1.0 - and noise_strength > 0.0 - and steps is not None - ): - discretization = Txt2NoisyDiscretizationWrapper( - discretization, - strength=noise_strength, - original_steps=steps, - ) - return discretization - def refiner( self, image: torch.Tensor, prompt: str, negative_prompt: str = "", - params: SamplingParams = SamplingParams( - sampler=Sampler.EULER_EDM, steps=40, img2img_strength=0.15 - ), + params: Optional[SamplingParams] = None, samples: int = 1, return_latents: bool = False, filter: Any = None, add_noise: bool = False, ): + if params is None: + params = self.specs.default_params + sampler = get_sampler_config(params) value_dict = { "orig_width": image.shape[3] * 8, diff --git a/tests/inference/test_inference.py b/tests/inference/test_inference.py index 617e408..04eceb7 100644 --- a/tests/inference/test_inference.py +++ b/tests/inference/test_inference.py @@ -27,7 +27,7 @@ class TestInference: @fixture( scope="class", params=[ - [ModelArchitecture.SDXL_V1_BASE, ModelArchitecture.SDXL_V1_REFINER], + [ModelArchitecture.SDXL_V1_0_BASE, ModelArchitecture.SDXL_V1_0_REFINER], [ModelArchitecture.SDXL_V0_9_BASE, ModelArchitecture.SDXL_V0_9_REFINER], ], ids=["SDXL_V1", "SDXL_V0_9"],