diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py
index 9bf0dff..13d8db3 100644
--- a/scripts/demo/sampling.py
+++ b/scripts/demo/sampling.py
@@ -226,6 +226,11 @@ if __name__ == "__main__":
     mode = st.radio("Mode", ("txt2img", "img2img"), 0)
     st.write("__________________________")
 
+    st.write("### Performance Options")
+    use_fp16 = st.checkbox("Use fp16", True)
+    enable_swap = st.checkbox("Enable model swapping to CPU", False)
+    st.write("__________________________")
+
     if version_enum in sdxl_base_model_list:
         add_pipeline = st.checkbox("Load SDXL-refiner?", False)
         st.write("__________________________")
@@ -237,11 +242,12 @@ if __name__ == "__main__":
     )
     seed_everything(seed)
 
-    lowvram_mode = st.checkbox("Low vram mode", True)
-
     save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, str(version)))
 
     state = init_st(
-        model_specs[version_enum], load_filter=True, lowvram_mode=lowvram_mode
+        model_specs[version_enum],
+        load_filter=True,
+        use_fp16=use_fp16,
+        enable_swap=enable_swap,
     )
     model = state["model"]
diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py
index fec7d33..84c4e62 100644
--- a/scripts/demo/streamlit_helpers.py
+++ b/scripts/demo/streamlit_helpers.py
@@ -25,21 +25,25 @@
 from sgm.inference.helpers import embed_watermark, CudaModelManager
 
 @st.cache_resource()
 def init_st(
-    spec: SamplingSpec, load_ckpt=True, load_filter=True, lowvram_mode=True
+    spec: SamplingSpec,
+    load_ckpt=True,
+    load_filter=True,
+    use_fp16=True,
+    enable_swap=True,
 ) -> Dict[str, Any]:
     state: Dict[str, Any] = dict()
     if not "model" in state:
         config = spec.config
         ckpt = spec.ckpt
 
-        if lowvram_mode:
+        if enable_swap:
             pipeline = SamplingPipeline(
                 model_spec=spec,
-                use_fp16=True,
+                use_fp16=use_fp16,
                 device=CudaModelManager(device="cuda", swap_device="cpu"),
             )
         else:
-            pipeline = SamplingPipeline(model_spec=spec, use_fp16=False)
+            pipeline = SamplingPipeline(model_spec=spec, use_fp16=use_fp16)
         state["spec"] = spec
         state["model"] = pipeline
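
For context, the removed `lowvram_mode` checkbox coupled precision and placement: when on, the pipeline ran in fp16 with weights swapped to CPU through `CudaModelManager`; when off, it ran in fp32 fully resident on the GPU. The two new checkboxes decouple those choices. A minimal usage sketch of the updated `init_st` follows; the `model_specs[version_enum]` lookup is reused from sampling.py, and the four calls are illustrative only, not part of the patch:

```python
# Illustrative only -- init_st, model_specs and version_enum are the objects
# already defined in scripts/demo/streamlit_helpers.py and scripts/demo/sampling.py.

# Equivalent of the old "Low vram mode" == True: fp16 weights, swapped to CPU when idle.
state = init_st(model_specs[version_enum], load_filter=True,
                use_fp16=True, enable_swap=True)

# Equivalent of the old "Low vram mode" == False: fp32 weights, kept on the GPU.
state = init_st(model_specs[version_enum], load_filter=True,
                use_fp16=False, enable_swap=False)

# Newly expressible combinations: fp16 without swapping, or fp32 with swapping.
state = init_st(model_specs[version_enum], load_filter=True,
                use_fp16=True, enable_swap=False)
state = init_st(model_specs[version_enum], load_filter=True,
                use_fp16=False, enable_swap=True)
```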