update api module

This commit is contained in:
Stephan Auerhahn
2023-08-10 05:07:22 -07:00
parent 47805f233c
commit 9b18e6fa19
2 changed files with 17 additions and 14 deletions

View File

@@ -5,8 +5,8 @@ import os
from sgm.inference.helpers import (
do_sample,
do_img2img,
BaseDeviceModelLoader,
CudaModelLoader,
DeviceModelManager,
CudaModelManager,
Img2ImgDiscretizationWrapper,
Txt2NoisyDiscretizationWrapper,
)
@@ -167,7 +167,7 @@ class SamplingPipeline:
model_path: Optional[str] = None,
config_path: Optional[str] = None,
use_fp16: bool = True,
model_loader: BaseDeviceModelLoader = CudaModelLoader(device="cuda"),
device_manager: DeviceModelManager = CudaModelManager(device="cuda"),
) -> None:
"""
Sampling pipeline for generating images from a model.
@@ -204,17 +204,18 @@ class SamplingPipeline:
raise ValueError(
f"Checkpoint {self.ckpt} not found, check model spec or config_path"
)
self.device = device
self.swap_device = swap_device
load_device = device if swap_device is None else swap_device
self.model = self._load_model(device=load_device, use_fp16=use_fp16)
def _load_model(self, device="cuda", use_fp16=True):
self.model_manager = device_manager
self.model = self._load_model(
device_manager=self.model_manager, use_fp16=use_fp16
)
def _load_model(self, device_manager: DeviceModelManager, use_fp16=True):
config = OmegaConf.load(self.config)
model = load_model_from_config(config, self.ckpt)
if model is None:
raise ValueError(f"Model {self.model_id} could not be loaded")
model.to(device)
device_manager.load(model)
if use_fp16:
model.conditioner.half()
model.model.half()
@@ -256,7 +257,7 @@ class SamplingPipeline:
force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
return_latents=return_latents,
filter=filter,
device=self.device,
model_manager=self.model_manager,
)
def image_to_image(

View File

@@ -91,10 +91,12 @@ class DeviceModelManager(object):
The default model loader does not perform any swapping, so the model will
stay on device.
"""
model.to(self.device)
yield
if self.device != self.swap_device:
model.to(self.swap_device)
try:
model.to(self.device)
yield
finally:
if self.device != self.swap_device:
model.to(self.swap_device)
class CudaModelManager(DeviceModelManager):