mirror of
https://github.com/Stability-AI/generative-models.git
synced 2026-02-08 07:14:25 +01:00
finish device manager refactor
This commit is contained in:
@@ -20,7 +20,7 @@ from sgm.inference.api import (
|
||||
SamplingPipeline,
|
||||
Thresholder,
|
||||
)
|
||||
from sgm.inference.helpers import embed_watermark, CudaModelLoader
|
||||
from sgm.inference.helpers import embed_watermark, CudaModelManager
|
||||
|
||||
|
||||
@st.cache_resource()
|
||||
@@ -35,7 +35,7 @@ def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True) -> Dict[str, A
|
||||
pipeline = SamplingPipeline(
|
||||
model_spec=spec,
|
||||
use_fp16=True,
|
||||
model_loader=CudaModelLoader(device="cuda", swap_device="cpu"),
|
||||
model_loader=CudaModelManager(device="cuda", swap_device="cpu"),
|
||||
)
|
||||
else:
|
||||
pipeline = SamplingPipeline(model_spec=spec, use_fp16=False)
|
||||
|
||||
Reference in New Issue
Block a user