finish device manager refactor

Stephan Auerhahn
2023-08-10 04:55:43 -07:00
parent e190ecc60b
commit 47805f233c
2 changed files with 97 additions and 85 deletions

View File

@@ -20,7 +20,7 @@ from sgm.inference.api import (
     SamplingPipeline,
     Thresholder,
 )
-from sgm.inference.helpers import embed_watermark, CudaModelLoader
+from sgm.inference.helpers import embed_watermark, CudaModelManager
 
 
 @st.cache_resource()
@@ -35,7 +35,7 @@ def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True) -> Dict[str, Any]:
         pipeline = SamplingPipeline(
             model_spec=spec,
             use_fp16=True,
-            model_loader=CudaModelLoader(device="cuda", swap_device="cpu"),
+            model_loader=CudaModelManager(device="cuda", swap_device="cpu"),
         )
     else:
         pipeline = SamplingPipeline(model_spec=spec, use_fp16=False)

View File: sgm/inference/helpers.py

@@ -10,7 +10,6 @@ from einops import rearrange
 from imwatermark import WatermarkEncoder
 from omegaconf import ListConfig
 from torch import autocast
-from abc import ABC, abstractmethod
 
 from sgm.util import append_dims
@@ -60,6 +59,80 @@ WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
 embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
+
+
+class DeviceModelManager(object):
+    """
+    Default model manager; should work for all device classes.
+    """
+
+    def __init__(
+        self,
+        device: Union[torch.device, str],
+        swap_device: Optional[Union[torch.device, str]] = None,
+    ):
+        """
+        Args:
+            device (Union[torch.device, str]): The device to use for the model.
+            swap_device (Optional[Union[torch.device, str]]): The device to swap
+                the model to when not in use; defaults to `device`.
+        """
+        self.device = torch.device(device)
+        self.swap_device = (
+            torch.device(swap_device) if swap_device is not None else self.device
+        )
+
+    def load(self, model: torch.nn.Module):
+        """
+        Loads a model to the device.
+        """
+        return model.to(self.device)
+
+    @contextlib.contextmanager
+    def use(self, model: torch.nn.Module):
+        """
+        Context manager that ensures a model is on the correct device during use.
+        Unless a separate swap device was provided, the model stays on device.
+        """
+        model.to(self.device)
+        yield
+        if self.device != self.swap_device:
+            model.to(self.swap_device)
+
+
+class CudaModelManager(DeviceModelManager):
+    """
+    Device manager that loads a model to a CUDA device, optionally swapping
+    to CPU when not in use.
+    """
+
+    def __init__(
+        self,
+        device: Union[torch.device, str] = "cuda",
+        swap_device: Optional[Union[torch.device, str]] = None,
+    ):
+        """
+        Args:
+            device (Union[torch.device, str]): The CUDA device to use for the model.
+            swap_device (Optional[Union[torch.device, str]]): The device to swap
+                the model to when not in use; defaults to `device`.
+        """
+        super().__init__(device, swap_device)
+
+    def load(self, model: Union[torch.nn.Module, torch.Tensor]):
+        """
+        Loads a model to the device.
+        """
+        return model.to(self.device)
+
+    @contextlib.contextmanager
+    def use(self, model: Union[torch.nn.Module, torch.Tensor]):
+        """
+        Context manager that ensures a model is on the correct device during use.
+        If a swap device was provided, the model is moved there after use and the
+        CUDA cache is cleared.
+        """
+        model.to(self.device)
+        yield
+        if self.device != self.swap_device:
+            model.to(self.swap_device)
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
 
 
 def get_unique_embedder_keys_from_conditioner(conditioner):
     return list({x.input_key for x in conditioner.embedders})
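For orientation, a minimal sketch of how the managers added above behave at runtime; the toy `torch.nn.Linear` module is illustrative only (not part of the diff), and the sketch assumes a CUDA-capable machine:

    import torch

    # a stand-in for a real diffusion submodule
    model = torch.nn.Linear(4, 4)

    manager = CudaModelManager(device="cuda", swap_device="cpu")
    model = manager.load(model)  # parameters now live on the CUDA device

    with manager.use(model):
        # inside the block the model is guaranteed to be on manager.device
        out = model(torch.randn(1, 4, device=manager.device))

    # on exit the model was swapped back to "cpu" and, since the devices
    # differ, torch.cuda.empty_cache() was called to release cached VRAM
    assert next(model.parameters()).device.type == "cpu"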
@@ -143,7 +216,7 @@ def do_sample(
     batch2model_input: Optional[List] = None,
     return_latents=False,
     filter=None,
-    device="cuda",
+    device_manager: DeviceModelManager = DeviceModelManager("cuda"),
 ):
     if force_uc_zero_embeddings is None:
         force_uc_zero_embeddings = []
@@ -151,10 +224,10 @@ def do_sample(
         batch2model_input = []
 
     with torch.no_grad():
-        with autocast(device) as precision_scope:
+        with autocast(device_manager.device):
             with model.ema_scope():
                 num_samples = [num_samples]
-                with swap_to_device(model.conditioner, device):
+                with device_manager.use(model.conditioner):
                     batch, batch_uc = get_batch(
                         get_unique_embedder_keys_from_conditioner(model.conditioner),
                         value_dict,
@@ -176,7 +249,10 @@ def do_sample(
                 for k in c:
                     if not k == "crossattn":
                         c[k], uc[k] = map(
-                            lambda y: y[k][: math.prod(num_samples)].to(device), (c, uc)
+                            lambda y: y[k][: math.prod(num_samples)].to(
+                                device_manager.device
+                            ),
+                            (c, uc),
                         )
 
                 additional_model_inputs = {}
@@ -184,18 +260,18 @@ def do_sample(
                     additional_model_inputs[k] = batch[k]
 
                 shape = (math.prod(num_samples), C, H // F, W // F)
-                randn = torch.randn(shape).to(device)
+                randn = torch.randn(shape).to(device_manager.device)
 
                 def denoiser(input, sigma, c):
                     return model.denoiser(
                         model.model, input, sigma, c, **additional_model_inputs
                     )
 
-                with swap_to_device(model.denoiser, device):
-                    with swap_to_device(model.model, device):
+                with device_manager.use(model.denoiser):
+                    with device_manager.use(model.model):
                         samples_z = sampler(denoiser, randn, cond=c, uc=uc)
 
-                with swap_to_device(model.first_stage_model, device):
+                with device_manager.use(model.first_stage_model):
                     samples_x = model.decode_first_stage(samples_z)
 
                 samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
@@ -290,12 +366,12 @@ def do_img2img(
     skip_encode=False,
     filter=None,
     add_noise=True,
-    device="cuda",
+    device_manager=DeviceModelManager("cuda"),
 ):
     with torch.no_grad():
-        with autocast(device):
+        with autocast(device_manager.device):
             with model.ema_scope():
-                with swap_to_device(model.conditioner, device):
+                with device_manager.use(model.conditioner):
                     batch, batch_uc = get_batch(
                         get_unique_embedder_keys_from_conditioner(model.conditioner),
                         value_dict,
@@ -308,14 +384,16 @@ def do_img2img(
                     )
                 for k in c:
-                    c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc))
+                    c[k], uc[k] = map(
+                        lambda y: y[k][:num_samples].to(device_manager.device), (c, uc)
+                    )
 
                 for k in additional_kwargs:
                     c[k] = uc[k] = additional_kwargs[k]
 
                 if skip_encode:
                     z = img
                 else:
-                    with swap_to_device(model.first_stage_model, device):
+                    with device_manager.use(model.first_stage_model):
                         z = model.encode_first_stage(img)
 
                 noise = torch.randn_like(z)
@@ -338,11 +416,11 @@ def do_img2img(
                 def denoiser(x, sigma, c):
                     return model.denoiser(model.model, x, sigma, c)
 
-                with swap_to_device(model.denoiser, device):
-                    with swap_to_device(model.model, device):
+                with device_manager.use(model.denoiser):
+                    with device_manager.use(model.model):
                         samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
 
-                with swap_to_device(model.first_stage_model, device):
+                with device_manager.use(model.first_stage_model):
                     samples_x = model.decode_first_stage(samples_z)
 
                 samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
@@ -352,69 +430,3 @@ def do_img2img(
                 if return_latents:
                     return samples, samples_z
                 return samples
-
-
-class BaseDeviceModelLoader(ABC):
-    """
-    Base class for device managers. Device managers are used to manage the device used for a model.
-    """
-
-    @abstractmethod
-    def __init__(self, device: Union[torch.device, str]):
-        """
-        Args:
-            device (Union[torch.device, str]): The device to use for the model.
-        """
-        pass
-
-    def load(self, model: torch.nn.Module):
-        """
-        Loads a model to the device.
-        """
-        pass
-
-    @contextlib.contextmanager
-    def use(self, model: torch.nn.Module):
-        """
-        Context manager that ensures a model is on the correct device during use.
-        """
-        yield
-
-
-class CudaModelLoader(BaseDeviceModelLoader):
-    """
-    Device manager that loads a model to a CUDA device, optionally swapping to CPU when not in use.
-    """
-
-    def __init__(
-        self,
-        device: Union[torch.device, str] = "cuda",
-        swap_device: Union[torch.device, str] = None,
-    ):
-        """
-        Args:
-            device (Union[torch.device, str]): The device to use for the model.
-        """
-        self.device = torch.device(device)
-        self.swap_device = (
-            torch.device(swap_device) if swap_device is not None else self.device
-        )
-
-    def load(self, model: Union[torch.nn.Module, torch.Tensor]):
-        """
-        Loads a model to the device.
-        """
-        model.to(self.swap_device)
-
-    @contextlib.contextmanager
-    def use(self, model: Union[torch.nn.Module, torch.Tensor]):
-        """
-        Context manager that ensures a model is on the correct device during use.
-        """
-        if self.device != self.swap_device:
-            model.to(self.device)
-        yield
-        if self.device != self.swap_device:
-            model.to(self.swap_device)
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
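Taken together, the two files replace the abstract `BaseDeviceModelLoader`/`CudaModelLoader` pair with `DeviceModelManager`/`CudaModelManager`, swap the `swap_to_device(...)` calls in `do_sample` and `do_img2img` for `device_manager.use(...)`, and have callers pass a manager object instead of a device string. A sketch of the call-site migration, with the unrelated arguments elided as in the diff:

    # before this commit: device passed as a string, swapping via swap_to_device
    # samples = do_sample(..., device="cuda")

    # after: the manager owns placement and optional CPU offload
    manager = CudaModelManager(device="cuda", swap_device="cpu")
    # samples = do_sample(..., device_manager=manager)

    # the streamlit demo wires the same manager into the pipeline
    pipeline = SamplingPipeline(
        model_spec=spec,  # `spec` as in init_st above
        use_fp16=True,
        model_loader=manager,
    )

Note that the default `DeviceModelManager("cuda")` keeps the old always-on-GPU behavior: with no swap device given, device and swap device coincide and no swapping occurs.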