mirror of
https://github.com/Stability-AI/generative-models.git
synced 2026-02-02 04:14:27 +01:00
low vram checkbox fix, remove magic strings
This commit is contained in:
@@ -23,7 +23,6 @@ from scripts.demo.streamlit_helpers import (
|
||||
init_sampling,
|
||||
init_save_locally,
|
||||
init_st,
|
||||
set_lowvram_mode,
|
||||
show_samples,
|
||||
)
|
||||
|
||||
@@ -205,6 +204,16 @@ def apply_refiner(
|
||||
return samples
|
||||
|
||||
|
||||
sdxl_base_model_list = [
|
||||
ModelArchitecture.SDXL_V1_BASE,
|
||||
ModelArchitecture.SDXL_V0_9_BASE,
|
||||
]
|
||||
|
||||
sdxl_refiner_model_list = [
|
||||
ModelArchitecture.SDXL_V1_REFINER,
|
||||
ModelArchitecture.SDXL_V0_9_REFINER,
|
||||
]
|
||||
|
||||
if __name__ == "__main__":
|
||||
st.title("Stable Diffusion")
|
||||
version = st.selectbox(
|
||||
@@ -217,9 +226,7 @@ if __name__ == "__main__":
|
||||
mode = st.radio("Mode", ("txt2img", "img2img"), 0)
|
||||
st.write("__________________________")
|
||||
|
||||
set_lowvram_mode(st.checkbox("Low vram mode", True))
|
||||
|
||||
if str(version).startswith("stable-diffusion-xl"):
|
||||
if version_enum in sdxl_base_model_list:
|
||||
add_pipeline = st.checkbox("Load SDXL-refiner?", False)
|
||||
st.write("__________________________")
|
||||
else:
|
||||
@@ -253,10 +260,7 @@ if __name__ == "__main__":
|
||||
version2 = ModelArchitecture(
|
||||
st.selectbox(
|
||||
"Refiner:",
|
||||
[
|
||||
ModelArchitecture.SDXL_V1_REFINER.value,
|
||||
ModelArchitecture.SDXL_V0_9_REFINER.value,
|
||||
],
|
||||
[member.value for member in sdxl_refiner_model_list],
|
||||
)
|
||||
)
|
||||
st.warning(
|
||||
|
||||
@@ -25,12 +25,13 @@ from sgm.inference.helpers import embed_watermark, CudaModelManager
|
||||
|
||||
@st.cache_resource()
|
||||
def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True) -> Dict[str, Any]:
|
||||
global lowvram_mode
|
||||
state: Dict[str, Any] = dict()
|
||||
if not "model" in state:
|
||||
config = spec.config
|
||||
ckpt = spec.ckpt
|
||||
|
||||
lowvram_mode = st.checkbox("Low VRAM mode", value=False)
|
||||
|
||||
if lowvram_mode:
|
||||
pipeline = SamplingPipeline(
|
||||
model_spec=spec,
|
||||
@@ -52,14 +53,6 @@ def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True) -> Dict[str, A
|
||||
return state
|
||||
|
||||
|
||||
lowvram_mode = False
|
||||
|
||||
|
||||
def set_lowvram_mode(mode):
|
||||
global lowvram_mode
|
||||
lowvram_mode = mode
|
||||
|
||||
|
||||
def get_unique_embedder_keys_from_conditioner(conditioner):
|
||||
return list(set([x.input_key for x in conditioner.embedders]))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user