finish device manager refactor

This commit is contained in:
Stephan Auerhahn
2023-08-10 04:55:43 -07:00
parent e190ecc60b
commit 47805f233c
2 changed files with 97 additions and 85 deletions

View File

@@ -20,7 +20,7 @@ from sgm.inference.api import (
SamplingPipeline,
Thresholder,
)
from sgm.inference.helpers import embed_watermark, CudaModelLoader
from sgm.inference.helpers import embed_watermark, CudaModelManager
@st.cache_resource()
@@ -35,7 +35,7 @@ def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True) -> Dict[str, A
pipeline = SamplingPipeline(
model_spec=spec,
use_fp16=True,
model_loader=CudaModelLoader(device="cuda", swap_device="cpu"),
model_loader=CudaModelManager(device="cuda", swap_device="cpu"),
)
else:
pipeline = SamplingPipeline(model_spec=spec, use_fp16=False)