mirror of
https://github.com/Stability-AI/generative-models.git
synced 2026-02-02 04:14:27 +01:00
split fp16 and swapping functionality
This commit is contained in:
@@ -226,6 +226,11 @@ if __name__ == "__main__":
|
||||
mode = st.radio("Mode", ("txt2img", "img2img"), 0)
|
||||
st.write("__________________________")
|
||||
|
||||
st.write("### Performance Options")
|
||||
use_fp16 = st.checkbox("Use fp16", True)
|
||||
enable_swap = st.checkbox("Enable model swapping to CPU", False)
|
||||
st.write("__________________________")
|
||||
|
||||
if version_enum in sdxl_base_model_list:
|
||||
add_pipeline = st.checkbox("Load SDXL-refiner?", False)
|
||||
st.write("__________________________")
|
||||
@@ -237,11 +242,12 @@ if __name__ == "__main__":
|
||||
)
|
||||
seed_everything(seed)
|
||||
|
||||
lowvram_mode = st.checkbox("Low vram mode", True)
|
||||
|
||||
save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, str(version)))
|
||||
state = init_st(
|
||||
model_specs[version_enum], load_filter=True, lowvram_mode=lowvram_mode
|
||||
model_specs[version_enum],
|
||||
load_filter=True,
|
||||
use_fp16=use_fp16,
|
||||
enable_swap=enable_swap,
|
||||
)
|
||||
model = state["model"]
|
||||
|
||||
|
||||
@@ -25,21 +25,25 @@ from sgm.inference.helpers import embed_watermark, CudaModelManager
|
||||
|
||||
@st.cache_resource()
|
||||
def init_st(
|
||||
spec: SamplingSpec, load_ckpt=True, load_filter=True, lowvram_mode=True
|
||||
spec: SamplingSpec,
|
||||
load_ckpt=True,
|
||||
load_filter=True,
|
||||
use_fp16=True,
|
||||
enable_swap=True,
|
||||
) -> Dict[str, Any]:
|
||||
state: Dict[str, Any] = dict()
|
||||
if not "model" in state:
|
||||
config = spec.config
|
||||
ckpt = spec.ckpt
|
||||
|
||||
if lowvram_mode:
|
||||
if enable_swap:
|
||||
pipeline = SamplingPipeline(
|
||||
model_spec=spec,
|
||||
use_fp16=True,
|
||||
use_fp16=use_fp16,
|
||||
device=CudaModelManager(device="cuda", swap_device="cpu"),
|
||||
)
|
||||
else:
|
||||
pipeline = SamplingPipeline(model_spec=spec, use_fp16=False)
|
||||
pipeline = SamplingPipeline(model_spec=spec, use_fp16=use_fp16)
|
||||
|
||||
state["spec"] = spec
|
||||
state["model"] = pipeline
|
||||
|
||||
Reference in New Issue
Block a user