Mirror of https://github.com/Stability-AI/generative-models.git, synced 2026-02-02 04:14:27 +01:00
finish device manager refactor
@@ -20,7 +20,7 @@ from sgm.inference.api import (
     SamplingPipeline,
     Thresholder,
 )
-from sgm.inference.helpers import embed_watermark, CudaModelLoader
+from sgm.inference.helpers import embed_watermark, CudaModelManager
 
 
 @st.cache_resource()
@@ -35,7 +35,7 @@ def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True) -> Dict[str, Any]:
         pipeline = SamplingPipeline(
             model_spec=spec,
             use_fp16=True,
-            model_loader=CudaModelLoader(device="cuda", swap_device="cpu"),
+            model_loader=CudaModelManager(device="cuda", swap_device="cpu"),
         )
     else:
         pipeline = SamplingPipeline(model_spec=spec, use_fp16=False)
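
For reference, a minimal sketch (not part of the commit) of the two ways the renamed manager can be configured, grounded in the constructor added in sgm.inference.helpers below:

from sgm.inference.helpers import CudaModelManager

# Offload configuration, as used by the demo above: weights are swapped to
# CPU whenever the model is not inside a `use()` block.
offload = CudaModelManager(device="cuda", swap_device="cpu")

# Resident configuration: with no swap_device, models stay on the GPU.
resident = CudaModelManager(device="cuda")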
@@ -10,7 +10,6 @@ from einops import rearrange
 from imwatermark import WatermarkEncoder
 from omegaconf import ListConfig
 from torch import autocast
-from abc import ABC, abstractmethod
 
 from sgm.util import append_dims
 
@@ -60,6 +59,80 @@ WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
 embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
 
 
+class DeviceModelManager(object):
+    """
+    Default model loading class, should work for all device classes.
+    """
+
+    def __init__(
+        self,
+        device: Union[torch.device, str],
+        swap_device: Optional[Union[torch.device, str]] = None,
+    ):
+        """
+        Args:
+            device (Union[torch.device, str]): The device to use for the model.
+        """
+        self.device = torch.device(device)
+        self.swap_device = (
+            torch.device(swap_device) if swap_device is not None else self.device
+        )
+
+    def load(self, model: torch.nn.Module):
+        """
+        Loads a model to the device.
+        """
+        return model.to(self.device)
+
+    @contextlib.contextmanager
+    def use(self, model: torch.nn.Module):
+        """
+        Context manager that ensures a model is on the correct device during use.
+        The default model loader does not perform any swapping, so the model will
+        stay on device.
+        """
+        model.to(self.device)
+        yield
+        if self.device != self.swap_device:
+            model.to(self.swap_device)
+
+
+class CudaModelManager(DeviceModelManager):
+    """
+    Device manager that loads a model to a CUDA device, optionally swapping to CPU when not in use.
+    """
+
+    def __init__(
+        self,
+        device: Union[torch.device, str] = "cuda",
+        swap_device: Optional[Union[torch.device, str]] = None,
+    ):
+        """
+        Args:
+            device (Union[torch.device, str]): The device to use for the model.
+        """
+        super().__init__(device, swap_device)
+
+    def load(self, model: Union[torch.nn.Module, torch.Tensor]):
+        """
+        Loads a model to the device.
+        """
+        return model.to(self.device)
+
+    @contextlib.contextmanager
+    def use(self, model: Union[torch.nn.Module, torch.Tensor]):
+        """
+        Context manager that ensures a model is on the correct device during use.
+        If a swap device was provided, this will move the model to it after use and clear cache.
+        """
+        model.to(self.device)
+        yield
+        if self.device != self.swap_device:
+            model.to(self.swap_device)
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
+
 def get_unique_embedder_keys_from_conditioner(conditioner):
     return list({x.input_key for x in conditioner.embedders})
 
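
A minimal sketch of the load/use contract implemented by the classes above (illustrative only; the Linear module is a hypothetical stand-in for a real submodule):

import torch
from sgm.inference.helpers import CudaModelManager

manager = CudaModelManager(device="cuda", swap_device="cpu")
model = torch.nn.Linear(4, 4)

manager.load(model)  # moves the weights to manager.device ("cuda")
with manager.use(model):
    # Inside the block the model is guaranteed to be on manager.device.
    out = model(torch.randn(1, 4, device=manager.device))
# On exit, because swap_device differs from device, the model was moved
# back to CPU and torch.cuda.empty_cache() was called.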
@@ -143,7 +216,7 @@ def do_sample(
     batch2model_input: Optional[List] = None,
     return_latents=False,
     filter=None,
-    device="cuda",
+    device_manager: DeviceModelManager = DeviceModelManager("cuda"),
 ):
     if force_uc_zero_embeddings is None:
         force_uc_zero_embeddings = []
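
Callers now pass a manager object instead of a raw device string. A hedged sketch of a call site (the positional arguments model, sampler, value_dict, num_samples, H, W, C, F are assumed from the unchanged parts of do_sample):

manager = CudaModelManager(device="cuda", swap_device="cpu")
samples = do_sample(
    model,
    sampler,
    value_dict,
    num_samples,
    H, W, C, F,
    device_manager=manager,  # replaces the removed device="cuda" parameter
)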
@@ -151,10 +224,10 @@ def do_sample(
         batch2model_input = []
 
     with torch.no_grad():
-        with autocast(device) as precision_scope:
+        with autocast(device_manager.device):
             with model.ema_scope():
                 num_samples = [num_samples]
-                with swap_to_device(model.conditioner, device):
+                with device_manager.use(model.conditioner):
                     batch, batch_uc = get_batch(
                         get_unique_embedder_keys_from_conditioner(model.conditioner),
                         value_dict,
@@ -176,7 +249,10 @@ def do_sample(
                 for k in c:
                     if not k == "crossattn":
                         c[k], uc[k] = map(
-                            lambda y: y[k][: math.prod(num_samples)].to(device), (c, uc)
+                            lambda y: y[k][: math.prod(num_samples)].to(
+                                device_manager.device
+                            ),
+                            (c, uc),
                         )
 
                 additional_model_inputs = {}
@@ -184,18 +260,18 @@ def do_sample(
                     additional_model_inputs[k] = batch[k]
 
                 shape = (math.prod(num_samples), C, H // F, W // F)
-                randn = torch.randn(shape).to(device)
+                randn = torch.randn(shape).to(device_manager.device)
 
                 def denoiser(input, sigma, c):
                     return model.denoiser(
                         model.model, input, sigma, c, **additional_model_inputs
                     )
 
-                with swap_to_device(model.denoiser, device):
-                    with swap_to_device(model.model, device):
+                with device_manager.use(model.denoiser):
+                    with device_manager.use(model.model):
                         samples_z = sampler(denoiser, randn, cond=c, uc=uc)
 
-                with swap_to_device(model.first_stage_model, device):
+                with device_manager.use(model.first_stage_model):
                     samples_x = model.decode_first_stage(samples_z)
                 samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
 
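
The nested use() blocks above hold both the denoiser and the diffusion model on the GPU only for the duration of sampling. An equivalent formulation, shown here as an alternative rather than what the commit does, flattens the nesting with contextlib.ExitStack:

import contextlib

with contextlib.ExitStack() as stack:
    # Both submodules are moved to manager.device; when the stack unwinds,
    # both are swapped back and the CUDA cache is cleared.
    stack.enter_context(device_manager.use(model.denoiser))
    stack.enter_context(device_manager.use(model.model))
    samples_z = sampler(denoiser, randn, cond=c, uc=uc)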
@@ -290,12 +366,12 @@ def do_img2img(
     skip_encode=False,
     filter=None,
    add_noise=True,
-    device="cuda",
+    device_manager=DeviceModelManager("cuda"),
 ):
     with torch.no_grad():
-        with autocast(device):
+        with autocast(device_manager.device):
             with model.ema_scope():
-                with swap_to_device(model.conditioner, device):
+                with device_manager.use(model.conditioner):
                     batch, batch_uc = get_batch(
                         get_unique_embedder_keys_from_conditioner(model.conditioner),
                         value_dict,
@@ -308,14 +384,16 @@ def do_img2img(
                 )
 
                 for k in c:
-                    c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc))
+                    c[k], uc[k] = map(
+                        lambda y: y[k][:num_samples].to(device_manager.device), (c, uc)
+                    )
 
                 for k in additional_kwargs:
                     c[k] = uc[k] = additional_kwargs[k]
                 if skip_encode:
                     z = img
                 else:
-                    with swap_to_device(model.first_stage_model, device):
+                    with device_manager.use(model.first_stage_model):
                         z = model.encode_first_stage(img)
 
                 noise = torch.randn_like(z)
@@ -338,11 +416,11 @@ def do_img2img(
                 def denoiser(x, sigma, c):
                     return model.denoiser(model.model, x, sigma, c)
 
-                with swap_to_device(model.denoiser, device):
-                    with swap_to_device(model.model, device):
+                with device_manager.use(model.denoiser):
+                    with device_manager.use(model.model):
                         samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
 
-                with swap_to_device(model.first_stage_model, device):
+                with device_manager.use(model.first_stage_model):
                     samples_x = model.decode_first_stage(samples_z)
                 samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
 
@@ -352,69 +430,3 @@ def do_img2img(
     if return_latents:
         return samples, samples_z
     return samples
-
-
-class BaseDeviceModelLoader(ABC):
-    """
-    Base class for device managers. Device managers are used to manage the device used for a model.
-    """
-
-    @abstractmethod
-    def __init__(self, device: Union[torch.device, str]):
-        """
-        Args:
-            device (Union[torch.device, str]): The device to use for the model.
-        """
-        pass
-
-    def load(self, model: torch.nn.Module):
-        """
-        Loads a model to the device.
-        """
-        pass
-
-    @contextlib.contextmanager
-    def use(self, model: torch.nn.Module):
-        """
-        Context manager that ensures a model is on the correct device during use.
-        """
-        yield
-
-
-class CudaModelLoader(BaseDeviceModelLoader):
-    """
-    Device manager that loads a model to a CUDA device, optionally swapping to CPU when not in use.
-    """
-
-    def __init__(
-        self,
-        device: Union[torch.device, str] = "cuda",
-        swap_device: Union[torch.device, str] = None,
-    ):
-        """
-        Args:
-            device (Union[torch.device, str]): The device to use for the model.
-        """
-        self.device = torch.device(device)
-        self.swap_device = (
-            torch.device(swap_device) if swap_device is not None else self.device
-        )
-
-    def load(self, model: Union[torch.nn.Module, torch.Tensor]):
-        """
-        Loads a model to the device.
-        """
-        model.to(self.swap_device)
-
-    @contextlib.contextmanager
-    def use(self, model: Union[torch.nn.Module, torch.Tensor]):
-        """
-        Context manager that ensures a model is on the correct device during use.
-        """
-        if self.device != self.swap_device:
-            model.to(self.device)
-        yield
-        if self.device != self.swap_device:
-            model.to(self.swap_device)
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
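
For downstream callers, the migration implied by this commit is mechanical; a before/after sketch (model and device are placeholders for any module and device string previously passed to swap_to_device):

# Before: free function taking an explicit device argument.
with swap_to_device(model.first_stage_model, device):
    samples_x = model.decode_first_stage(samples_z)

# After: the swapping policy lives on the manager object.
manager = CudaModelManager(device="cuda", swap_device="cpu")
with manager.use(model.first_stage_model):
    samples_x = model.decode_first_stage(samples_z)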