From 47805f233cae54019860ff4adb07e3d19e54cbb9 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Thu, 10 Aug 2023 04:55:43 -0700 Subject: [PATCH] finish device manager refactor --- scripts/demo/streamlit_helpers.py | 4 +- sgm/inference/helpers.py | 178 ++++++++++++++++-------------- 2 files changed, 97 insertions(+), 85 deletions(-) diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index 5838741..5b0214a 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -20,7 +20,7 @@ from sgm.inference.api import ( SamplingPipeline, Thresholder, ) -from sgm.inference.helpers import embed_watermark, CudaModelLoader +from sgm.inference.helpers import embed_watermark, CudaModelManager @st.cache_resource() @@ -35,7 +35,7 @@ def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True) -> Dict[str, A pipeline = SamplingPipeline( model_spec=spec, use_fp16=True, - model_loader=CudaModelLoader(device="cuda", swap_device="cpu"), + model_loader=CudaModelManager(device="cuda", swap_device="cpu"), ) else: pipeline = SamplingPipeline(model_spec=spec, use_fp16=False) diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py index bef2fb3..095cf08 100644 --- a/sgm/inference/helpers.py +++ b/sgm/inference/helpers.py @@ -10,7 +10,6 @@ from einops import rearrange from imwatermark import WatermarkEncoder from omegaconf import ListConfig from torch import autocast -from abc import ABC, abstractmethod from sgm.util import append_dims @@ -60,6 +59,80 @@ WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]] embed_watermark = WatermarkEmbedder(WATERMARK_BITS) +class DeviceModelManager(object): + """ + Default model loading class, should work for all device classes. + """ + + def __init__( + self, + device: Union[torch.device, str], + swap_device: Optional[Union[torch.device, str]] = None, + ): + """ + Args: + device (Union[torch.device, str]): The device to use for the model. 
+ """ + self.device = torch.device(device) + self.swap_device = ( + torch.device(swap_device) if swap_device is not None else self.device + ) + + def load(self, model: torch.nn.Module): + """ + Loads a model to the device. + """ + return model.to(self.device) + + @contextlib.contextmanager + def use(self, model: torch.nn.Module): + """ + Context manager that ensures a model is on the correct device during use. + If a swap device different from the main device was provided, the model is + moved to it after use; otherwise it stays on device. + """ + model.to(self.device) + yield + if self.device != self.swap_device: + model.to(self.swap_device) + + +class CudaModelManager(DeviceModelManager): + """ + Device manager that loads a model to a CUDA device, optionally swapping to CPU when not in use. + """ + + def __init__( + self, + device: Union[torch.device, str] = "cuda", + swap_device: Union[torch.device, str] = None, + ): + """ + Args: + device (Union[torch.device, str]): The device to use for the model. + """ + super().__init__(device, swap_device) + + def load(self, model: Union[torch.nn.Module, torch.Tensor]): + """ + Loads a model to the device. + """ + return model.to(self.device) + + @contextlib.contextmanager + def use(self, model: Union[torch.nn.Module, torch.Tensor]): + """ + Context manager that ensures a model is on the correct device during use. + If a swap device was provided, this will move the model to it after use and clear cache. 
+ """ + model.to(self.device) + yield + if self.device != self.swap_device: + model.to(self.swap_device) + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + def get_unique_embedder_keys_from_conditioner(conditioner): return list({x.input_key for x in conditioner.embedders}) @@ -143,7 +216,7 @@ def do_sample( batch2model_input: Optional[List] = None, return_latents=False, filter=None, - device="cuda", + device_manager: DeviceModelManager = DeviceModelManager("cuda"), ): if force_uc_zero_embeddings is None: force_uc_zero_embeddings = [] @@ -151,10 +224,10 @@ def do_sample( batch2model_input = [] with torch.no_grad(): - with autocast(device) as precision_scope: + with autocast(device_manager.device): with model.ema_scope(): num_samples = [num_samples] - with swap_to_device(model.conditioner, device): + with device_manager.use(model.conditioner): batch, batch_uc = get_batch( get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, @@ -176,7 +249,10 @@ def do_sample( for k in c: if not k == "crossattn": c[k], uc[k] = map( - lambda y: y[k][: math.prod(num_samples)].to(device), (c, uc) + lambda y: y[k][: math.prod(num_samples)].to( + device_manager.device + ), + (c, uc), ) additional_model_inputs = {} @@ -184,18 +260,18 @@ def do_sample( additional_model_inputs[k] = batch[k] shape = (math.prod(num_samples), C, H // F, W // F) - randn = torch.randn(shape).to(device) + randn = torch.randn(shape).to(device_manager.device) def denoiser(input, sigma, c): return model.denoiser( model.model, input, sigma, c, **additional_model_inputs ) - with swap_to_device(model.denoiser, device): - with swap_to_device(model.model, device): + with device_manager.use(model.denoiser): + with device_manager.use(model.model): samples_z = sampler(denoiser, randn, cond=c, uc=uc) - with swap_to_device(model.first_stage_model, device): + with device_manager.use(model.first_stage_model): samples_x = model.decode_first_stage(samples_z) samples = torch.clamp((samples_x + 1.0) / 
2.0, min=0.0, max=1.0) @@ -290,12 +366,12 @@ def do_img2img( skip_encode=False, filter=None, add_noise=True, - device="cuda", + device_manager=DeviceModelManager("cuda"), ): with torch.no_grad(): - with autocast(device): + with autocast(device_manager.device): with model.ema_scope(): - with swap_to_device(model.conditioner, device): + with device_manager.use(model.conditioner): batch, batch_uc = get_batch( get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, @@ -308,14 +384,16 @@ def do_img2img( ) for k in c: - c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc)) + c[k], uc[k] = map( + lambda y: y[k][:num_samples].to(device_manager.device), (c, uc) + ) for k in additional_kwargs: c[k] = uc[k] = additional_kwargs[k] if skip_encode: z = img else: - with swap_to_device(model.first_stage_model, device): + with device_manager.use(model.first_stage_model): z = model.encode_first_stage(img) noise = torch.randn_like(z) @@ -338,11 +416,11 @@ def do_img2img( def denoiser(x, sigma, c): return model.denoiser(model.model, x, sigma, c) - with swap_to_device(model.denoiser, device): - with swap_to_device(model.model, device): + with device_manager.use(model.denoiser): + with device_manager.use(model.model): samples_z = sampler(denoiser, noised_z, cond=c, uc=uc) - with swap_to_device(model.first_stage_model, device): + with device_manager.use(model.first_stage_model): samples_x = model.decode_first_stage(samples_z) samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) @@ -352,69 +430,3 @@ def do_img2img( if return_latents: return samples, samples_z return samples - - -class BaseDeviceModelLoader(ABC): - """ - Base class for device managers. Device managers are used to manage the device used for a model. - """ - - @abstractmethod - def __init__(self, device: Union[torch.device, str]): - """ - Args: - device (Union[torch.device, str]): The device to use for the model. 
- """ - pass - - def load(self, model: torch.nn.Module): - """ - Loads a model to the device. - """ - pass - - @contextlib.contextmanager - def use(self, model: torch.nn.Module): - """ - Context manager that ensures a model is on the correct device during use. - """ - yield - - -class CudaModelLoader(BaseDeviceModelLoader): - """ - Device manager that loads a model to a CUDA device, optionally swapping to CPU when not in use. - """ - - def __init__( - self, - device: Union[torch.device, str] = "cuda", - swap_device: Union[torch.device, str] = None, - ): - """ - Args: - device (Union[torch.device, str]): The device to use for the model. - """ - self.device = torch.device(device) - self.swap_device = ( - torch.device(swap_device) if swap_device is not None else self.device - ) - - def load(self, model: Union[torch.nn.Module, torch.Tensor]): - """ - Loads a model to the device. - """ - model.to(self.swap_device) - - @contextlib.contextmanager - def use(self, model: Union[torch.nn.Module, torch.Tensor]): - """ - Context manager that ensures a model is on the correct device during use. - """ - if self.device != self.swap_device: - model.to(self.device) - yield - if self.device != self.swap_device: - model.to(self.swap_device) - if torch.cuda.is_available(): - torch.cuda.empty_cache()