From 3e7ada70c503622b474db9e55c281b63f36b3047 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Thu, 10 Aug 2023 05:42:31 -0700 Subject: [PATCH] fix autocast --- sgm/inference/helpers.py | 30 ++++++++++-------------------- 1 file changed, 10 insertions(+), 20 deletions(-) diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py index 314fe19..f86eda6 100644 --- a/sgm/inference/helpers.py +++ b/sgm/inference/helpers.py @@ -9,7 +9,6 @@ from PIL import Image from einops import rearrange from imwatermark import WatermarkEncoder from omegaconf import ListConfig -from torch import autocast from sgm.util import append_dims @@ -84,6 +83,14 @@ class DeviceModelManager(object): """ return model.to(self.device) + def autocast(self): + """ + Context manager that enables autocast for the device if supported. + """ + if self.device.type not in ("cuda", "cpu"): + return contextlib.nullcontext() + return torch.autocast(self.device.type) + @contextlib.contextmanager def use(self, model: torch.nn.Module): """ @@ -104,23 +111,6 @@ class CudaModelManager(DeviceModelManager): Device manager that loads a model to a CUDA device, optionally swapping to CPU when not in use. """ - def __init__( - self, - device: Union[torch.device, str] = "cuda", - swap_device: Union[torch.device, str] = None, - ): - """ - Args: - device (Union[torch.device, str]): The device to use for the model. - """ - super().__init__(device, swap_device) - - def load(self, model: Union[torch.nn.Module, torch.Tensor]): - """ - Loads a model to the device. - """ - return model.to(self.device) - @contextlib.contextmanager def use(self, model: Union[torch.nn.Module, torch.Tensor]): """ @@ -226,7 +216,7 @@ def do_sample( batch2model_input = [] with torch.no_grad(): - with autocast(device_manager.device): + with device_manager.autocast(): with model.ema_scope(): num_samples = [num_samples] with device_manager.use(model.conditioner): @@ -371,7 +361,7 @@ def do_img2img( device_manager=DeviceModelManager("cuda"), ): with torch.no_grad(): - with autocast(device_manager.device): + with device_manager.autocast(): with model.ema_scope(): with device_manager.use(model.conditioner): batch, batch_uc = get_batch(