From f6704532a0c50eaf7961600d1e517a18e2060740 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Sat, 12 Aug 2023 07:27:25 +0000 Subject: [PATCH] abstract device defaults --- sgm/inference/api.py | 15 ++++++--------- sgm/inference/helpers.py | 28 ++++++++++++++++++---------- 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/sgm/inference/api.py b/sgm/inference/api.py index afb1f72..9ca1111 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -174,7 +174,7 @@ class SamplingPipeline: model_path: Optional[str] = None, config_path: Optional[str] = None, use_fp16: bool = True, - device: Optional[Union[DeviceModelManager, str, torch.device]] = None + device: Optional[Union[DeviceModelManager, str, torch.device]] = None, ) -> None: """ Sampling pipeline for generating images from a model. @@ -211,16 +211,13 @@ class SamplingPipeline: raise ValueError( f"Checkpoint {self.ckpt} not found, check model spec or config_path" ) - if not isinstance(device, DeviceModelManager): - self.device_manager = get_model_manager(device=device) - else: - self.device_manager = device + + self.device_manager = get_model_manager(device) self.model = self._load_model( device_manager=self.device_manager, use_fp16=use_fp16 ) - def _load_model(self, device_manager: DeviceModelManager, use_fp16=True): config = OmegaConf.load(self.config) model = load_model_from_config(config, self.ckpt) @@ -268,7 +265,7 @@ class SamplingPipeline: force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [], return_latents=return_latents, filter=filter, - device_manager=self.device_manager, + device=self.device_manager, ) def image_to_image( @@ -308,7 +305,7 @@ class SamplingPipeline: force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [], return_latents=return_latents, filter=filter, - device_manager=self.device_manager, + device=self.device_manager, ) def wrap_discretization( @@ -377,7 +374,7 @@ class SamplingPipeline: return_latents=return_latents, filter=filter, add_noise=add_noise, - device_manager=self.device_manager, + device=self.device_manager, ) diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py index addefe3..e84b8a2 100644 --- a/sgm/inference/helpers.py +++ b/sgm/inference/helpers.py @@ -109,7 +109,7 @@ class DeviceModelManager(object): class CudaModelManager(DeviceModelManager): """ Device manager that loads a model to a CUDA device, optionally swapping to CPU when not in use. - """ + """ @contextlib.contextmanager def use(self, model: Union[torch.nn.Module, torch.Tensor]): @@ -141,14 +141,19 @@ def perform_save_locally(save_path, samples): base_count += 1 -def get_model_manager(device: Union[str,torch.device]) -> DeviceModelManager: - if isinstance(device, torch.device) or isinstance(device, str): - if torch.device(device).type == "cuda": - return CudaModelManager(device=device) - else: - return DeviceModelManager(device=device) - else: +def get_model_manager( + device: Optional[Union[DeviceModelManager, str, torch.device]] +) -> DeviceModelManager: + if isinstance(device, DeviceModelManager): return device + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device(device) + if device.type == "cuda": + return CudaModelManager(device=device) + else: + return DeviceModelManager(device=device) + class Img2ImgDiscretizationWrapper: """ @@ -217,13 +222,15 @@ def do_sample( batch2model_input: Optional[List] = None, return_latents=False, filter=None, - device_manager: DeviceModelManager = DeviceModelManager("cuda"), + device: Optional[Union[DeviceModelManager, str, torch.device]] = None, ): if force_uc_zero_embeddings is None: force_uc_zero_embeddings = [] if batch2model_input is None: batch2model_input = [] + device_manager = get_model_manager(device=device) + with torch.no_grad(): with device_manager.autocast(): with model.ema_scope(): @@ -367,8 +374,9 @@ def do_img2img( skip_encode=False, filter=None, add_noise=True, - device_manager=DeviceModelManager("cuda"), + device: Optional[Union[DeviceModelManager, str, torch.device]] = None, ): + device_manager = get_model_manager(device) with torch.no_grad(): with device_manager.autocast(): with model.ema_scope():