From 9b18e6fa19c31abf2bc7b7816c7d287c1a43c23c Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Thu, 10 Aug 2023 05:07:22 -0700 Subject: [PATCH] update api module --- sgm/inference/api.py | 21 +++++++++++---------- sgm/inference/helpers.py | 10 ++++++---- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/sgm/inference/api.py b/sgm/inference/api.py index 0588a26..8516e73 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -5,8 +5,8 @@ import os from sgm.inference.helpers import ( do_sample, do_img2img, - BaseDeviceModelLoader, - CudaModelLoader, + DeviceModelManager, + CudaModelManager, Img2ImgDiscretizationWrapper, Txt2NoisyDiscretizationWrapper, ) @@ -167,7 +167,7 @@ class SamplingPipeline: model_path: Optional[str] = None, config_path: Optional[str] = None, use_fp16: bool = True, - model_loader: BaseDeviceModelLoader = CudaModelLoader(device="cuda"), + device_manager: DeviceModelManager = CudaModelManager(device="cuda"), ) -> None: """ Sampling pipeline for generating images from a model. @@ -204,17 +204,18 @@ class SamplingPipeline: raise ValueError( f"Checkpoint {self.ckpt} not found, check model spec or config_path" ) - self.device = device - self.swap_device = swap_device - load_device = device if swap_device is None else swap_device - self.model = self._load_model(device=load_device, use_fp16=use_fp16) - def _load_model(self, device="cuda", use_fp16=True): + self.model_manager = device_manager + self.model = self._load_model( + device_manager=self.model_manager, use_fp16=use_fp16 + ) + + def _load_model(self, device_manager: DeviceModelManager, use_fp16=True): config = OmegaConf.load(self.config) model = load_model_from_config(config, self.ckpt) if model is None: raise ValueError(f"Model {self.model_id} could not be loaded") - model.to(device) + device_manager.load(model) if use_fp16: model.conditioner.half() model.model.half() @@ -256,7 +257,7 @@ class SamplingPipeline: force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [], return_latents=return_latents, filter=filter, - device=self.device, + model_manager=self.model_manager, ) def image_to_image( diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py index 095cf08..314fe19 100644 --- a/sgm/inference/helpers.py +++ b/sgm/inference/helpers.py @@ -91,10 +91,12 @@ class DeviceModelManager(object): The default model loader does not perform any swapping, so the model will stay on device. """ - model.to(self.device) - yield - if self.device != self.swap_device: - model.to(self.swap_device) + try: + model.to(self.device) + yield + finally: + if self.device != self.swap_device: + model.to(self.swap_device) class CudaModelManager(DeviceModelManager):