low vram checkbox fix, remove magic strings

This commit is contained in:
Stephan Auerhahn
2023-08-10 12:40:32 -07:00
parent 26b10f56f3
commit a25662e969
2 changed files with 14 additions and 17 deletions

View File

@@ -23,7 +23,6 @@ from scripts.demo.streamlit_helpers import (
init_sampling,
init_save_locally,
init_st,
set_lowvram_mode,
show_samples,
)
@@ -205,6 +204,16 @@ def apply_refiner(
return samples
sdxl_base_model_list = [
ModelArchitecture.SDXL_V1_BASE,
ModelArchitecture.SDXL_V0_9_BASE,
]
sdxl_refiner_model_list = [
ModelArchitecture.SDXL_V1_REFINER,
ModelArchitecture.SDXL_V0_9_REFINER,
]
if __name__ == "__main__":
st.title("Stable Diffusion")
version = st.selectbox(
@@ -217,9 +226,7 @@ if __name__ == "__main__":
mode = st.radio("Mode", ("txt2img", "img2img"), 0)
st.write("__________________________")
set_lowvram_mode(st.checkbox("Low vram mode", True))
if str(version).startswith("stable-diffusion-xl"):
if version_enum in sdxl_base_model_list:
add_pipeline = st.checkbox("Load SDXL-refiner?", False)
st.write("__________________________")
else:
@@ -253,10 +260,7 @@ if __name__ == "__main__":
version2 = ModelArchitecture(
st.selectbox(
"Refiner:",
[
ModelArchitecture.SDXL_V1_REFINER.value,
ModelArchitecture.SDXL_V0_9_REFINER.value,
],
[member.value for member in sdxl_refiner_model_list],
)
)
st.warning(

View File

@@ -25,12 +25,13 @@ from sgm.inference.helpers import embed_watermark, CudaModelManager
@st.cache_resource()
def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True) -> Dict[str, Any]:
global lowvram_mode
state: Dict[str, Any] = dict()
if not "model" in state:
config = spec.config
ckpt = spec.ckpt
lowvram_mode = st.checkbox("Low VRAM mode", value=False)
if lowvram_mode:
pipeline = SamplingPipeline(
model_spec=spec,
@@ -52,14 +53,6 @@ def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True) -> Dict[str, A
return state
lowvram_mode = False
def set_lowvram_mode(mode):
global lowvram_mode
lowvram_mode = mode
def get_unique_embedder_keys_from_conditioner(conditioner):
return list(set([x.input_key for x in conditioner.embedders]))