split fp16 and swapping functionality

This commit is contained in:
Stephan Auerhahn
2023-08-10 13:14:38 -07:00
parent 3816aaa639
commit 2aebc8882d
2 changed files with 17 additions and 7 deletions

View File

@@ -226,6 +226,11 @@ if __name__ == "__main__":
mode = st.radio("Mode", ("txt2img", "img2img"), 0)
st.write("__________________________")
st.write("### Performance Options")
use_fp16 = st.checkbox("Use fp16", True)
enable_swap = st.checkbox("Enable model swapping to CPU", False)
st.write("__________________________")
if version_enum in sdxl_base_model_list:
add_pipeline = st.checkbox("Load SDXL-refiner?", False)
st.write("__________________________")
@@ -237,11 +242,12 @@ if __name__ == "__main__":
)
seed_everything(seed)
lowvram_mode = st.checkbox("Low vram mode", True)
save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, str(version)))
state = init_st(
model_specs[version_enum], load_filter=True, lowvram_mode=lowvram_mode
model_specs[version_enum],
load_filter=True,
use_fp16=use_fp16,
enable_swap=enable_swap,
)
model = state["model"]

View File

@@ -25,21 +25,25 @@ from sgm.inference.helpers import embed_watermark, CudaModelManager
@st.cache_resource()
def init_st(
spec: SamplingSpec, load_ckpt=True, load_filter=True, lowvram_mode=True
spec: SamplingSpec,
load_ckpt=True,
load_filter=True,
use_fp16=True,
enable_swap=True,
) -> Dict[str, Any]:
state: Dict[str, Any] = dict()
if not "model" in state:
config = spec.config
ckpt = spec.ckpt
if lowvram_mode:
if enable_swap:
pipeline = SamplingPipeline(
model_spec=spec,
use_fp16=True,
use_fp16=use_fp16,
device=CudaModelManager(device="cuda", swap_device="cpu"),
)
else:
pipeline = SamplingPipeline(model_spec=spec, use_fp16=False)
pipeline = SamplingPipeline(model_spec=spec, use_fp16=use_fp16)
state["spec"] = spec
state["model"] = pipeline