mirror of https://github.com/Stability-AI/generative-models.git
fix autocast
@@ -9,7 +9,6 @@ from PIL import Image
 from einops import rearrange
 from imwatermark import WatermarkEncoder
 from omegaconf import ListConfig
-from torch import autocast
 
 from sgm.util import append_dims
 
@@ -84,6 +83,14 @@ class DeviceModelManager(object):
         """
         return model.to(self.device)
 
+    def autocast(self):
+        """
+        Context manager that enables autocast for the device if supported.
+        """
+        if self.device.type not in ("cuda", "cpu"):
+            return contextlib.nullcontext()
+        return torch.autocast(self.device.type)
+
     @contextlib.contextmanager
     def use(self, model: torch.nn.Module):
         """
@@ -104,23 +111,6 @@ class CudaModelManager(DeviceModelManager):
     Device manager that loads a model to a CUDA device, optionally swapping to CPU when not in use.
     """
 
-    def __init__(
-        self,
-        device: Union[torch.device, str] = "cuda",
-        swap_device: Union[torch.device, str] = None,
-    ):
-        """
-        Args:
-            device (Union[torch.device, str]): The device to use for the model.
-        """
-        super().__init__(device, swap_device)
-
-    def load(self, model: Union[torch.nn.Module, torch.Tensor]):
-        """
-        Loads a model to the device.
-        """
-        return model.to(self.device)
-
     @contextlib.contextmanager
     def use(self, model: Union[torch.nn.Module, torch.Tensor]):
         """
@@ -226,7 +216,7 @@ def do_sample(
         batch2model_input = []
 
     with torch.no_grad():
-        with autocast(device_manager.device):
+        with device_manager.autocast():
            with model.ema_scope():
                 num_samples = [num_samples]
                 with device_manager.use(model.conditioner):
@@ -371,7 +361,7 @@ def do_img2img(
    device_manager=DeviceModelManager("cuda"),
 ):
     with torch.no_grad():
-        with autocast(device_manager.device):
+        with device_manager.autocast():
             with model.ema_scope():
                 with device_manager.use(model.conditioner):
                     batch, batch_uc = get_batch(
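
For reference, here is a minimal, self-contained sketch of the pattern this commit introduces: callers stop importing torch.autocast directly and instead ask the device manager for an autocast context. Only the body of DeviceModelManager.autocast() is taken from the diff; the simplified constructor and the usage below are illustrative assumptions, not the repository's actual code.

import contextlib
from typing import Union

import torch


class DeviceModelManager:
    def __init__(self, device: Union[torch.device, str] = "cuda"):
        # Illustrative constructor; the real class also handles an optional swap device.
        self.device = torch.device(device)

    def autocast(self):
        # torch.autocast only supports "cuda" and "cpu" device types here;
        # for anything else, fall back to a no-op context manager.
        if self.device.type not in ("cuda", "cpu"):
            return contextlib.nullcontext()
        return torch.autocast(self.device.type)


# Usage mirroring the change in do_sample / do_img2img:
device_manager = DeviceModelManager("cuda" if torch.cuda.is_available() else "cpu")
with torch.no_grad():
    with device_manager.autocast():
        pass  # model forward pass would run here under autocast

The benefit of routing autocast through the manager is that callers no longer need to special-case devices that torch.autocast does not support; the manager returns a no-op context instead.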