Stable Video Diffusion

This commit is contained in:
Tim Dockhorn
2023-11-21 10:40:21 -08:00
parent 477d8b9a77
commit 059d8e9cd9
59 changed files with 5463 additions and 1691 deletions

View File

@@ -1,18 +1,22 @@
import logging
import math
import re
from abc import abstractmethod
from contextlib import contextmanager
from typing import Any, Dict, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import pytorch_lightning as pl
import torch
from omegaconf import ListConfig
import torch.nn as nn
from einops import rearrange
from packaging import version
from safetensors.torch import load_file as load_safetensors
from ..modules.diffusionmodules.model import Decoder, Encoder
from ..modules.distributions.distributions import DiagonalGaussianDistribution
from ..modules.autoencoding.regularizers import AbstractRegularizer
from ..modules.ema import LitEma
from ..util import default, get_obj_from_str, instantiate_from_config
from ..util import (default, get_nested_attribute, get_obj_from_str,
instantiate_from_config)
logpy = logging.getLogger(__name__)
class AbstractAutoencoder(pl.LightningModule):
@@ -27,10 +31,9 @@ class AbstractAutoencoder(pl.LightningModule):
ema_decay: Union[None, float] = None,
monitor: Union[None, str] = None,
input_key: str = "jpg",
ckpt_path: Union[None, str] = None,
ignore_keys: Union[Tuple, list, ListConfig] = (),
):
super().__init__()
self.input_key = input_key
self.use_ema = ema_decay is not None
if monitor is not None:
@@ -38,38 +41,21 @@ class AbstractAutoencoder(pl.LightningModule):
if self.use_ema:
self.model_ema = LitEma(self, decay=ema_decay)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if version.parse(torch.__version__) >= version.parse("2.0.0"):
self.automatic_optimization = False
def init_from_ckpt(
self, path: str, ignore_keys: Union[Tuple, list, ListConfig] = tuple()
) -> None:
if path.endswith("ckpt"):
sd = torch.load(path, map_location="cpu")["state_dict"]
elif path.endswith("safetensors"):
sd = load_safetensors(path)
else:
raise NotImplementedError
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if re.match(ik, k):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False)
print(
f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
)
if len(missing) > 0:
print(f"Missing Keys: {missing}")
if len(unexpected) > 0:
print(f"Unexpected Keys: {unexpected}")
def apply_ckpt(self, ckpt: Union[None, str, dict]):
if ckpt is None:
return
if isinstance(ckpt, str):
ckpt = {
"target": "sgm.modules.checkpoint.CheckpointEngine",
"params": {"ckpt_path": ckpt},
}
engine = instantiate_from_config(ckpt)
engine(self)
@abstractmethod
def get_input(self, batch) -> Any:
@@ -86,14 +72,14 @@ class AbstractAutoencoder(pl.LightningModule):
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
print(f"{context}: Switched to EMA weights")
logpy.info(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
print(f"{context}: Restored training weights")
logpy.info(f"{context}: Restored training weights")
@abstractmethod
def encode(self, *args, **kwargs) -> torch.Tensor:
@@ -104,7 +90,7 @@ class AbstractAutoencoder(pl.LightningModule):
raise NotImplementedError("decode()-method of abstract base class called")
def instantiate_optimizer_from_config(self, params, lr, cfg):
print(f"loading >>> {cfg['target']} <<< optimizer from config")
logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
return get_obj_from_str(cfg["target"])(
params, lr=lr, **cfg.get("params", dict())
)
@@ -129,196 +115,435 @@ class AutoencodingEngine(AbstractAutoencoder):
regularizer_config: Dict,
optimizer_config: Union[Dict, None] = None,
lr_g_factor: float = 1.0,
trainable_ae_params: Optional[List[List[str]]] = None,
ae_optimizer_args: Optional[List[dict]] = None,
trainable_disc_params: Optional[List[List[str]]] = None,
disc_optimizer_args: Optional[List[dict]] = None,
disc_start_iter: int = 0,
diff_boost_factor: float = 3.0,
ckpt_engine: Union[None, str, dict] = None,
ckpt_path: Optional[str] = None,
additional_decode_keys: Optional[List[str]] = None,
**kwargs,
):
super().__init__(*args, **kwargs)
# todo: add options to freeze encoder/decoder
self.encoder = instantiate_from_config(encoder_config)
self.decoder = instantiate_from_config(decoder_config)
self.loss = instantiate_from_config(loss_config)
self.regularization = instantiate_from_config(regularizer_config)
self.automatic_optimization = False # pytorch lightning
self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
self.loss: torch.nn.Module = instantiate_from_config(loss_config)
self.regularization: AbstractRegularizer = instantiate_from_config(
regularizer_config
)
self.optimizer_config = default(
optimizer_config, {"target": "torch.optim.Adam"}
)
self.diff_boost_factor = diff_boost_factor
self.disc_start_iter = disc_start_iter
self.lr_g_factor = lr_g_factor
self.trainable_ae_params = trainable_ae_params
if self.trainable_ae_params is not None:
self.ae_optimizer_args = default(
ae_optimizer_args,
[{} for _ in range(len(self.trainable_ae_params))],
)
assert len(self.ae_optimizer_args) == len(self.trainable_ae_params)
else:
self.ae_optimizer_args = [{}] # makes type consitent
self.trainable_disc_params = trainable_disc_params
if self.trainable_disc_params is not None:
self.disc_optimizer_args = default(
disc_optimizer_args,
[{} for _ in range(len(self.trainable_disc_params))],
)
assert len(self.disc_optimizer_args) == len(self.trainable_disc_params)
else:
self.disc_optimizer_args = [{}] # makes type consitent
if ckpt_path is not None:
assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path"
logpy.warn("Checkpoint path is deprecated, use `checkpoint_egnine` instead")
self.apply_ckpt(default(ckpt_path, ckpt_engine))
self.additional_decode_keys = set(default(additional_decode_keys, []))
def get_input(self, batch: Dict) -> torch.Tensor:
# assuming unified data format, dataloader returns a dict.
# image tensors should be scaled to -1 ... 1 and in channels-first format (e.g., bchw instead if bhwc)
# image tensors should be scaled to -1 ... 1 and in channels-first
# format (e.g., bchw instead if bhwc)
return batch[self.input_key]
def get_autoencoder_params(self) -> list:
params = (
list(self.encoder.parameters())
+ list(self.decoder.parameters())
+ list(self.regularization.get_trainable_parameters())
+ list(self.loss.get_trainable_autoencoder_parameters())
)
params = []
if hasattr(self.loss, "get_trainable_autoencoder_parameters"):
params += list(self.loss.get_trainable_autoencoder_parameters())
if hasattr(self.regularization, "get_trainable_parameters"):
params += list(self.regularization.get_trainable_parameters())
params = params + list(self.encoder.parameters())
params = params + list(self.decoder.parameters())
return params
def get_discriminator_params(self) -> list:
params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
if hasattr(self.loss, "get_trainable_parameters"):
params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
else:
params = []
return params
def get_last_layer(self):
return self.decoder.get_last_layer()
def encode(self, x: Any, return_reg_log: bool = False) -> Any:
def encode(
self,
x: torch.Tensor,
return_reg_log: bool = False,
unregularized: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
z = self.encoder(x)
if unregularized:
return z, dict()
z, reg_log = self.regularization(z)
if return_reg_log:
return z, reg_log
return z
def decode(self, z: Any) -> torch.Tensor:
x = self.decoder(z)
def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
x = self.decoder(z, **kwargs)
return x
def forward(self, x: Any) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
def forward(
self, x: torch.Tensor, **additional_decode_kwargs
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
z, reg_log = self.encode(x, return_reg_log=True)
dec = self.decode(z)
dec = self.decode(z, **additional_decode_kwargs)
return z, dec, reg_log
def training_step(self, batch, batch_idx, optimizer_idx) -> Any:
def inner_training_step(
self, batch: dict, batch_idx: int, optimizer_idx: int = 0
) -> torch.Tensor:
x = self.get_input(batch)
z, xrec, regularization_log = self(x)
additional_decode_kwargs = {
key: batch[key] for key in self.additional_decode_keys.intersection(batch)
}
z, xrec, regularization_log = self(x, **additional_decode_kwargs)
if hasattr(self.loss, "forward_keys"):
extra_info = {
"z": z,
"optimizer_idx": optimizer_idx,
"global_step": self.global_step,
"last_layer": self.get_last_layer(),
"split": "train",
"regularization_log": regularization_log,
"autoencoder": self,
}
extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
else:
extra_info = dict()
if optimizer_idx == 0:
# autoencode
aeloss, log_dict_ae = self.loss(
regularization_log,
x,
xrec,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split="train",
)
out_loss = self.loss(x, xrec, **extra_info)
if isinstance(out_loss, tuple):
aeloss, log_dict_ae = out_loss
else:
# simple loss function
aeloss = out_loss
log_dict_ae = {"train/loss/rec": aeloss.detach()}
self.log_dict(
log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True
log_dict_ae,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=True,
sync_dist=False,
)
self.log(
"loss",
aeloss.mean().detach(),
prog_bar=True,
logger=False,
on_epoch=False,
on_step=True,
)
return aeloss
if optimizer_idx == 1:
elif optimizer_idx == 1:
# discriminator
discloss, log_dict_disc = self.loss(
regularization_log,
x,
xrec,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split="train",
)
discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
# -> discriminator always needs to return a tuple
self.log_dict(
log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
)
return discloss
else:
raise NotImplementedError(f"Unknown optimizer {optimizer_idx}")
def validation_step(self, batch, batch_idx) -> Dict:
def training_step(self, batch: dict, batch_idx: int):
opts = self.optimizers()
if not isinstance(opts, list):
# Non-adversarial case
opts = [opts]
optimizer_idx = batch_idx % len(opts)
if self.global_step < self.disc_start_iter:
optimizer_idx = 0
opt = opts[optimizer_idx]
opt.zero_grad()
with opt.toggle_model():
loss = self.inner_training_step(
batch, batch_idx, optimizer_idx=optimizer_idx
)
self.manual_backward(loss)
opt.step()
def validation_step(self, batch: dict, batch_idx: int) -> Dict:
log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope():
log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
log_dict.update(log_dict_ema)
return log_dict
def _validation_step(self, batch, batch_idx, postfix="") -> Dict:
def _validation_step(self, batch: dict, batch_idx: int, postfix: str = "") -> Dict:
x = self.get_input(batch)
z, xrec, regularization_log = self(x)
aeloss, log_dict_ae = self.loss(
regularization_log,
x,
xrec,
0,
self.global_step,
last_layer=self.get_last_layer(),
split="val" + postfix,
if hasattr(self.loss, "forward_keys"):
extra_info = {
"z": z,
"optimizer_idx": 0,
"global_step": self.global_step,
"last_layer": self.get_last_layer(),
"split": "val" + postfix,
"regularization_log": regularization_log,
"autoencoder": self,
}
extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
else:
extra_info = dict()
out_loss = self.loss(x, xrec, **extra_info)
if isinstance(out_loss, tuple):
aeloss, log_dict_ae = out_loss
else:
# simple loss function
aeloss = out_loss
log_dict_ae = {f"val{postfix}/loss/rec": aeloss.detach()}
full_log_dict = log_dict_ae
if "optimizer_idx" in extra_info:
extra_info["optimizer_idx"] = 1
discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
full_log_dict.update(log_dict_disc)
self.log(
f"val{postfix}/loss/rec",
log_dict_ae[f"val{postfix}/loss/rec"],
sync_dist=True,
)
self.log_dict(full_log_dict, sync_dist=True)
return full_log_dict
discloss, log_dict_disc = self.loss(
regularization_log,
x,
xrec,
1,
self.global_step,
last_layer=self.get_last_layer(),
split="val" + postfix,
)
self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
log_dict_ae.update(log_dict_disc)
self.log_dict(log_dict_ae)
return log_dict_ae
def configure_optimizers(self) -> Any:
ae_params = self.get_autoencoder_params()
disc_params = self.get_discriminator_params()
def get_param_groups(
self, parameter_names: List[List[str]], optimizer_args: List[dict]
) -> Tuple[List[Dict[str, Any]], int]:
groups = []
num_params = 0
for names, args in zip(parameter_names, optimizer_args):
params = []
for pattern_ in names:
pattern_params = []
pattern = re.compile(pattern_)
for p_name, param in self.named_parameters():
if re.match(pattern, p_name):
pattern_params.append(param)
num_params += param.numel()
if len(pattern_params) == 0:
logpy.warn(f"Did not find parameters for pattern {pattern_}")
params.extend(pattern_params)
groups.append({"params": params, **args})
return groups, num_params
def configure_optimizers(self) -> List[torch.optim.Optimizer]:
if self.trainable_ae_params is None:
ae_params = self.get_autoencoder_params()
else:
ae_params, num_ae_params = self.get_param_groups(
self.trainable_ae_params, self.ae_optimizer_args
)
logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
if self.trainable_disc_params is None:
disc_params = self.get_discriminator_params()
else:
disc_params, num_disc_params = self.get_param_groups(
self.trainable_disc_params, self.disc_optimizer_args
)
logpy.info(
f"Number of trainable discriminator parameters: {num_disc_params:,}"
)
opt_ae = self.instantiate_optimizer_from_config(
ae_params,
default(self.lr_g_factor, 1.0) * self.learning_rate,
self.optimizer_config,
)
opt_disc = self.instantiate_optimizer_from_config(
disc_params, self.learning_rate, self.optimizer_config
)
opts = [opt_ae]
if len(disc_params) > 0:
opt_disc = self.instantiate_optimizer_from_config(
disc_params, self.learning_rate, self.optimizer_config
)
opts.append(opt_disc)
return [opt_ae, opt_disc], []
return opts
@torch.no_grad()
def log_images(self, batch: Dict, **kwargs) -> Dict:
def log_images(
self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
) -> dict:
log = dict()
additional_decode_kwargs = {}
x = self.get_input(batch)
_, xrec, _ = self(x)
additional_decode_kwargs.update(
{key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
)
_, xrec, _ = self(x, **additional_decode_kwargs)
log["inputs"] = x
log["reconstructions"] = xrec
diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x)
diff.clamp_(0, 1.0)
log["diff"] = 2.0 * diff - 1.0
# diff_boost shows location of small errors, by boosting their
# brightness.
log["diff_boost"] = (
2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1
)
if hasattr(self.loss, "log_images"):
log.update(self.loss.log_images(x, xrec))
with self.ema_scope():
_, xrec_ema, _ = self(x)
_, xrec_ema, _ = self(x, **additional_decode_kwargs)
log["reconstructions_ema"] = xrec_ema
diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
diff_ema.clamp_(0, 1.0)
log["diff_ema"] = 2.0 * diff_ema - 1.0
log["diff_boost_ema"] = (
2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
)
if additional_log_kwargs:
additional_decode_kwargs.update(additional_log_kwargs)
_, xrec_add, _ = self(x, **additional_decode_kwargs)
log_str = "reconstructions-" + "-".join(
[f"{key}={additional_log_kwargs[key]}" for key in additional_log_kwargs]
)
log[log_str] = xrec_add
return log
class AutoencoderKL(AutoencodingEngine):
class AutoencodingEngineLegacy(AutoencodingEngine):
def __init__(self, embed_dim: int, **kwargs):
self.max_batch_size = kwargs.pop("max_batch_size", None)
ddconfig = kwargs.pop("ddconfig")
ckpt_path = kwargs.pop("ckpt_path", None)
ignore_keys = kwargs.pop("ignore_keys", ())
ckpt_engine = kwargs.pop("ckpt_engine", None)
super().__init__(
encoder_config={"target": "torch.nn.Identity"},
decoder_config={"target": "torch.nn.Identity"},
regularizer_config={"target": "torch.nn.Identity"},
loss_config=kwargs.pop("lossconfig"),
encoder_config={
"target": "sgm.modules.diffusionmodules.model.Encoder",
"params": ddconfig,
},
decoder_config={
"target": "sgm.modules.diffusionmodules.model.Decoder",
"params": ddconfig,
},
**kwargs,
)
assert ddconfig["double_z"]
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
self.quant_conv = torch.nn.Conv2d(
(1 + ddconfig["double_z"]) * ddconfig["z_channels"],
(1 + ddconfig["double_z"]) * embed_dim,
1,
)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
self.apply_ckpt(default(ckpt_path, ckpt_engine))
def encode(self, x):
assert (
not self.training
), f"{self.__class__.__name__} only supports inference currently"
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def get_autoencoder_params(self) -> list:
params = super().get_autoencoder_params()
return params
def encode(
self, x: torch.Tensor, return_reg_log: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
if self.max_batch_size is None:
z = self.encoder(x)
z = self.quant_conv(z)
else:
N = x.shape[0]
bs = self.max_batch_size
n_batches = int(math.ceil(N / bs))
z = list()
for i_batch in range(n_batches):
z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
z_batch = self.quant_conv(z_batch)
z.append(z_batch)
z = torch.cat(z, 0)
z, reg_log = self.regularization(z)
if return_reg_log:
return z, reg_log
return z
def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
if self.max_batch_size is None:
dec = self.post_quant_conv(z)
dec = self.decoder(dec, **decoder_kwargs)
else:
N = z.shape[0]
bs = self.max_batch_size
n_batches = int(math.ceil(N / bs))
dec = list()
for i_batch in range(n_batches):
dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
dec_batch = self.decoder(dec_batch, **decoder_kwargs)
dec.append(dec_batch)
dec = torch.cat(dec, 0)
def decode(self, z, **decoder_kwargs):
z = self.post_quant_conv(z)
dec = self.decoder(z, **decoder_kwargs)
return dec
class AutoencoderKLInferenceWrapper(AutoencoderKL):
def encode(self, x):
return super().encode(x).sample()
class AutoencoderKL(AutoencodingEngineLegacy):
def __init__(self, **kwargs):
if "lossconfig" in kwargs:
kwargs["loss_config"] = kwargs.pop("lossconfig")
super().__init__(
regularizer_config={
"target": (
"sgm.modules.autoencoding.regularizers"
".DiagonalGaussianRegularizer"
)
},
**kwargs,
)
class AutoencoderLegacyVQ(AutoencodingEngineLegacy):
def __init__(
self,
embed_dim: int,
n_embed: int,
sane_index_shape: bool = False,
**kwargs,
):
if "lossconfig" in kwargs:
logpy.warn(f"Parameter `lossconfig` is deprecated, use `loss_config`.")
kwargs["loss_config"] = kwargs.pop("lossconfig")
super().__init__(
regularizer_config={
"target": (
"sgm.modules.autoencoding.regularizers.quantize" ".VectorQuantizer"
),
"params": {
"n_e": n_embed,
"e_dim": embed_dim,
"sane_index_shape": sane_index_shape,
},
},
**kwargs,
)
class IdentityFirstStage(AbstractAutoencoder):
@@ -333,3 +558,58 @@ class IdentityFirstStage(AbstractAutoencoder):
def decode(self, x: Any, *args, **kwargs) -> Any:
return x
class AEIntegerWrapper(nn.Module):
def __init__(
self,
model: nn.Module,
shape: Union[None, Tuple[int, int], List[int]] = (16, 16),
regularization_key: str = "regularization",
encoder_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__()
self.model = model
assert hasattr(model, "encode") and hasattr(
model, "decode"
), "Need AE interface"
self.regularization = get_nested_attribute(model, regularization_key)
self.shape = shape
self.encoder_kwargs = default(encoder_kwargs, {"return_reg_log": True})
def encode(self, x) -> torch.Tensor:
assert (
not self.training
), f"{self.__class__.__name__} only supports inference currently"
_, log = self.model.encode(x, **self.encoder_kwargs)
assert isinstance(log, dict)
inds = log["min_encoding_indices"]
return rearrange(inds, "b ... -> b (...)")
def decode(
self, inds: torch.Tensor, shape: Union[None, tuple, list] = None
) -> torch.Tensor:
# expect inds shape (b, s) with s = h*w
shape = default(shape, self.shape) # Optional[(h, w)]
if shape is not None:
assert len(shape) == 2, f"Unhandeled shape {shape}"
inds = rearrange(inds, "b (h w) -> b h w", h=shape[0], w=shape[1])
h = self.regularization.get_codebook_entry(inds) # (b, h, w, c)
h = rearrange(h, "b h w c -> b c h w")
return self.model.decode(h)
class AutoencoderKLModeOnly(AutoencodingEngineLegacy):
def __init__(self, **kwargs):
if "lossconfig" in kwargs:
kwargs["loss_config"] = kwargs.pop("lossconfig")
super().__init__(
regularizer_config={
"target": (
"sgm.modules.autoencoding.regularizers"
".DiagonalGaussianRegularizer"
),
"params": {"sample": False},
},
**kwargs,
)