fix autocast

This commit is contained in:
Stephan Auerhahn
2023-08-10 05:42:31 -07:00
parent de7a627978
commit 3e7ada70c5

View File

@@ -9,7 +9,6 @@ from PIL import Image
from einops import rearrange
from imwatermark import WatermarkEncoder
from omegaconf import ListConfig
from torch import autocast
from sgm.util import append_dims
@@ -84,6 +83,14 @@ class DeviceModelManager(object):
"""
return model.to(self.device)
def autocast(self):
"""
Context manager that enables autocast for the device if supported.
"""
if self.device.type not in ("cuda", "cpu"):
return contextlib.nullcontext()
return torch.autocast(self.device.type)
@contextlib.contextmanager
def use(self, model: torch.nn.Module):
"""
@@ -104,23 +111,6 @@ class CudaModelManager(DeviceModelManager):
Device manager that loads a model to a CUDA device, optionally swapping to CPU when not in use.
"""
def __init__(
self,
device: Union[torch.device, str] = "cuda",
swap_device: Union[torch.device, str] = None,
):
"""
Args:
device (Union[torch.device, str]): The device to use for the model.
"""
super().__init__(device, swap_device)
def load(self, model: Union[torch.nn.Module, torch.Tensor]):
"""
Loads a model to the device.
"""
return model.to(self.device)
@contextlib.contextmanager
def use(self, model: Union[torch.nn.Module, torch.Tensor]):
"""
@@ -226,7 +216,7 @@ def do_sample(
batch2model_input = []
with torch.no_grad():
with autocast(device_manager.device):
with device_manager.autocast():
with model.ema_scope():
num_samples = [num_samples]
with device_manager.use(model.conditioner):
@@ -371,7 +361,7 @@ def do_img2img(
device_manager=DeviceModelManager("cuda"),
):
with torch.no_grad():
with autocast(device_manager.device):
with device_manager.autocast():
with model.ema_scope():
with device_manager.use(model.conditioner):
batch, batch_uc = get_batch(