From a25662e969fc9a4e8df74e4917d2adca37621591 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Thu, 10 Aug 2023 12:40:32 -0700 Subject: [PATCH] low vram checkbox fix, remove magic strings --- scripts/demo/sampling.py | 20 ++++++++++++-------- scripts/demo/streamlit_helpers.py | 11 ++--------- 2 files changed, 14 insertions(+), 17 deletions(-) diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py index 5a709e0..4dca18d 100644 --- a/scripts/demo/sampling.py +++ b/scripts/demo/sampling.py @@ -23,7 +23,6 @@ from scripts.demo.streamlit_helpers import ( init_sampling, init_save_locally, init_st, - set_lowvram_mode, show_samples, ) @@ -205,6 +204,16 @@ def apply_refiner( return samples +sdxl_base_model_list = [ + ModelArchitecture.SDXL_V1_BASE, + ModelArchitecture.SDXL_V0_9_BASE, +] + +sdxl_refiner_model_list = [ + ModelArchitecture.SDXL_V1_REFINER, + ModelArchitecture.SDXL_V0_9_REFINER, +] + if __name__ == "__main__": st.title("Stable Diffusion") version = st.selectbox( @@ -217,9 +226,7 @@ if __name__ == "__main__": mode = st.radio("Mode", ("txt2img", "img2img"), 0) st.write("__________________________") - set_lowvram_mode(st.checkbox("Low vram mode", True)) - - if str(version).startswith("stable-diffusion-xl"): + if version_enum in sdxl_base_model_list: add_pipeline = st.checkbox("Load SDXL-refiner?", False) st.write("__________________________") else: @@ -253,10 +260,7 @@ if __name__ == "__main__": version2 = ModelArchitecture( st.selectbox( "Refiner:", - [ - ModelArchitecture.SDXL_V1_REFINER.value, - ModelArchitecture.SDXL_V0_9_REFINER.value, - ], + [member.value for member in sdxl_refiner_model_list], ) ) st.warning( diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index 119ffd7..59cd27b 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -25,12 +25,13 @@ from sgm.inference.helpers import embed_watermark, CudaModelManager @st.cache_resource() def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True) -> Dict[str, Any]: - global lowvram_mode state: Dict[str, Any] = dict() if not "model" in state: config = spec.config ckpt = spec.ckpt + lowvram_mode = st.checkbox("Low VRAM mode", value=False) + if lowvram_mode: pipeline = SamplingPipeline( model_spec=spec, @@ -52,14 +53,6 @@ def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True) -> Dict[str, A return state -lowvram_mode = False - - -def set_lowvram_mode(mode): - global lowvram_mode - lowvram_mode = mode - - def get_unique_embedder_keys_from_conditioner(conditioner): return list(set([x.input_key for x in conditioner.embedders]))