mirror of
https://github.com/Stability-AI/generative-models.git
synced 2026-02-02 04:14:27 +01:00
update api module
This commit is contained in:
@@ -5,8 +5,8 @@ import os
|
||||
from sgm.inference.helpers import (
|
||||
do_sample,
|
||||
do_img2img,
|
||||
BaseDeviceModelLoader,
|
||||
CudaModelLoader,
|
||||
DeviceModelManager,
|
||||
CudaModelManager,
|
||||
Img2ImgDiscretizationWrapper,
|
||||
Txt2NoisyDiscretizationWrapper,
|
||||
)
|
||||
@@ -167,7 +167,7 @@ class SamplingPipeline:
|
||||
model_path: Optional[str] = None,
|
||||
config_path: Optional[str] = None,
|
||||
use_fp16: bool = True,
|
||||
model_loader: BaseDeviceModelLoader = CudaModelLoader(device="cuda"),
|
||||
device_manager: DeviceModelManager = CudaModelManager(device="cuda"),
|
||||
) -> None:
|
||||
"""
|
||||
Sampling pipeline for generating images from a model.
|
||||
@@ -204,17 +204,18 @@ class SamplingPipeline:
|
||||
raise ValueError(
|
||||
f"Checkpoint {self.ckpt} not found, check model spec or config_path"
|
||||
)
|
||||
self.device = device
|
||||
self.swap_device = swap_device
|
||||
load_device = device if swap_device is None else swap_device
|
||||
self.model = self._load_model(device=load_device, use_fp16=use_fp16)
|
||||
|
||||
def _load_model(self, device="cuda", use_fp16=True):
|
||||
self.model_manager = device_manager
|
||||
self.model = self._load_model(
|
||||
device_manager=self.model_manager, use_fp16=use_fp16
|
||||
)
|
||||
|
||||
def _load_model(self, device_manager: DeviceModelManager, use_fp16=True):
|
||||
config = OmegaConf.load(self.config)
|
||||
model = load_model_from_config(config, self.ckpt)
|
||||
if model is None:
|
||||
raise ValueError(f"Model {self.model_id} could not be loaded")
|
||||
model.to(device)
|
||||
device_manager.load(model)
|
||||
if use_fp16:
|
||||
model.conditioner.half()
|
||||
model.model.half()
|
||||
@@ -256,7 +257,7 @@ class SamplingPipeline:
|
||||
force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
|
||||
return_latents=return_latents,
|
||||
filter=filter,
|
||||
device=self.device,
|
||||
model_manager=self.model_manager,
|
||||
)
|
||||
|
||||
def image_to_image(
|
||||
|
||||
@@ -91,10 +91,12 @@ class DeviceModelManager(object):
|
||||
The default model loader does not perform any swapping, so the model will
|
||||
stay on device.
|
||||
"""
|
||||
model.to(self.device)
|
||||
yield
|
||||
if self.device != self.swap_device:
|
||||
model.to(self.swap_device)
|
||||
try:
|
||||
model.to(self.device)
|
||||
yield
|
||||
finally:
|
||||
if self.device != self.swap_device:
|
||||
model.to(self.swap_device)
|
||||
|
||||
|
||||
class CudaModelManager(DeviceModelManager):
|
||||
|
||||
Reference in New Issue
Block a user