mirror of
https://github.com/Stability-AI/generative-models.git
synced 2026-02-18 12:14:26 +01:00
path helper & model swapping rewrite
This commit is contained in:
@@ -20,9 +20,7 @@ from sgm.inference.api import (
|
||||
SamplingPipeline,
|
||||
Thresholder,
|
||||
)
|
||||
from sgm.inference.helpers import (
|
||||
embed_watermark,
|
||||
)
|
||||
from sgm.inference.helpers import embed_watermark, CudaModelLoader
|
||||
|
||||
|
||||
@st.cache_resource()
|
||||
@@ -35,10 +33,12 @@ def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True) -> Dict[str, A
|
||||
|
||||
if lowvram_mode:
|
||||
pipeline = SamplingPipeline(
|
||||
model_spec=spec, use_fp16=True, device="cuda", swap_device="cpu"
|
||||
model_spec=spec,
|
||||
use_fp16=True,
|
||||
model_loader=CudaModelLoader(device="cuda", swap_device="cpu"),
|
||||
)
|
||||
else:
|
||||
pipeline = SamplingPipeline(model_spec=spec, use_fp16=True, device="cuda")
|
||||
pipeline = SamplingPipeline(model_spec=spec, use_fp16=False)
|
||||
|
||||
state["spec"] = spec
|
||||
state["model"] = pipeline
|
||||
|
||||
Reference in New Issue
Block a user