Mirror of https://github.com/Stability-AI/generative-models.git (synced 2025-12-19 22:34:22 +01:00)

Commit: Stable Video Diffusion
@@ -1,18 +1,22 @@
+import logging
+import math
 import re
 from abc import abstractmethod
 from contextlib import contextmanager
-from typing import Any, Dict, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import pytorch_lightning as pl
 import torch
-from omegaconf import ListConfig
+import torch.nn as nn
+from einops import rearrange
 from packaging import version
-from safetensors.torch import load_file as load_safetensors
 
-from ..modules.diffusionmodules.model import Decoder, Encoder
-from ..modules.distributions.distributions import DiagonalGaussianDistribution
+from ..modules.autoencoding.regularizers import AbstractRegularizer
 from ..modules.ema import LitEma
-from ..util import default, get_obj_from_str, instantiate_from_config
+from ..util import (default, get_nested_attribute, get_obj_from_str,
+                    instantiate_from_config)
 
+logpy = logging.getLogger(__name__)
+
 
 class AbstractAutoencoder(pl.LightningModule):
@@ -27,10 +31,9 @@ class AbstractAutoencoder(pl.LightningModule):
         ema_decay: Union[None, float] = None,
         monitor: Union[None, str] = None,
         input_key: str = "jpg",
-        ckpt_path: Union[None, str] = None,
-        ignore_keys: Union[Tuple, list, ListConfig] = (),
     ):
         super().__init__()
+
         self.input_key = input_key
         self.use_ema = ema_decay is not None
         if monitor is not None:
@@ -38,38 +41,21 @@ class AbstractAutoencoder(pl.LightningModule):
 
         if self.use_ema:
             self.model_ema = LitEma(self, decay=ema_decay)
-            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
-
-        if ckpt_path is not None:
-            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+            logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
 
         if version.parse(torch.__version__) >= version.parse("2.0.0"):
            self.automatic_optimization = False
 
-    def init_from_ckpt(
-        self, path: str, ignore_keys: Union[Tuple, list, ListConfig] = tuple()
-    ) -> None:
-        if path.endswith("ckpt"):
-            sd = torch.load(path, map_location="cpu")["state_dict"]
-        elif path.endswith("safetensors"):
-            sd = load_safetensors(path)
-        else:
-            raise NotImplementedError
-
-        keys = list(sd.keys())
-        for k in keys:
-            for ik in ignore_keys:
-                if re.match(ik, k):
-                    print("Deleting key {} from state_dict.".format(k))
-                    del sd[k]
-        missing, unexpected = self.load_state_dict(sd, strict=False)
-        print(
-            f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
-        )
-        if len(missing) > 0:
-            print(f"Missing Keys: {missing}")
-        if len(unexpected) > 0:
-            print(f"Unexpected Keys: {unexpected}")
+    def apply_ckpt(self, ckpt: Union[None, str, dict]):
+        if ckpt is None:
+            return
+        if isinstance(ckpt, str):
+            ckpt = {
+                "target": "sgm.modules.checkpoint.CheckpointEngine",
+                "params": {"ckpt_path": ckpt},
+            }
+        engine = instantiate_from_config(ckpt)
+        engine(self)
 
     @abstractmethod
     def get_input(self, batch) -> Any:
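
# Editor's sketch (not part of the commit): apply_ckpt above accepts either a
# config dict or a bare checkpoint path. A minimal, standalone illustration of
# the normalization it performs; the path below is a hypothetical placeholder.
from typing import Union

def normalize_ckpt(ckpt: Union[None, str, dict]) -> Union[None, dict]:
    if ckpt is None:
        return None
    if isinstance(ckpt, str):
        # a bare path is wrapped into an instantiate_from_config-style dict
        ckpt = {
            "target": "sgm.modules.checkpoint.CheckpointEngine",
            "params": {"ckpt_path": ckpt},
        }
    return ckpt

print(normalize_ckpt("checkpoints/example.safetensors"))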
@@ -86,14 +72,14 @@ class AbstractAutoencoder(pl.LightningModule):
             self.model_ema.store(self.parameters())
             self.model_ema.copy_to(self)
             if context is not None:
-                print(f"{context}: Switched to EMA weights")
+                logpy.info(f"{context}: Switched to EMA weights")
         try:
             yield None
         finally:
             if self.use_ema:
                 self.model_ema.restore(self.parameters())
                 if context is not None:
-                    print(f"{context}: Restored training weights")
+                    logpy.info(f"{context}: Restored training weights")
 
     @abstractmethod
     def encode(self, *args, **kwargs) -> torch.Tensor:
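
# Editor's sketch (not part of the commit): a standalone version of the
# store / copy_to / restore pattern that ema_scope relies on. LitEma provides
# the real implementation in this repository; TinyEma is a toy stand-in.
from contextlib import contextmanager

import torch
import torch.nn as nn

class TinyEma:
    def __init__(self, model: nn.Module, decay: float = 0.999):
        self.decay = decay
        self.shadow = {k: p.detach().clone() for k, p in model.named_parameters()}
        self.backup = {}

    @torch.no_grad()
    def update(self, model: nn.Module):
        # exponential moving average of the training weights
        for k, p in model.named_parameters():
            self.shadow[k].mul_(self.decay).add_(p, alpha=1 - self.decay)

    def store(self, model):
        # remember the current training weights
        self.backup = {k: p.detach().clone() for k, p in model.named_parameters()}

    @torch.no_grad()
    def copy_to(self, model):
        # swap the EMA weights in for evaluation
        for k, p in model.named_parameters():
            p.copy_(self.shadow[k])

    @torch.no_grad()
    def restore(self, model):
        # swap the training weights back
        for k, p in model.named_parameters():
            p.copy_(self.backup[k])

@contextmanager
def ema_scope(model, ema):
    ema.store(model)
    ema.copy_to(model)
    try:
        yield
    finally:
        ema.restore(model)

model = nn.Linear(4, 4)
ema = TinyEma(model)
with ema_scope(model, ema):
    pass  # evaluate with EMA weights here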
@@ -104,7 +90,7 @@ class AbstractAutoencoder(pl.LightningModule):
         raise NotImplementedError("decode()-method of abstract base class called")
 
     def instantiate_optimizer_from_config(self, params, lr, cfg):
-        print(f"loading >>> {cfg['target']} <<< optimizer from config")
+        logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
         return get_obj_from_str(cfg["target"])(
             params, lr=lr, **cfg.get("params", dict())
         )
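
# Editor's sketch (not part of the commit): instantiate_optimizer_from_config
# resolves cfg["target"] to a class and calls it. get_obj_from_str in sgm.util
# does essentially the import-and-resolve shown here.
import importlib

import torch
import torch.nn as nn

def resolve_target(string: str):
    module, cls = string.rsplit(".", 1)
    return getattr(importlib.import_module(module), cls)

cfg = {"target": "torch.optim.Adam", "params": {"betas": (0.5, 0.9)}}
params = nn.Linear(4, 4).parameters()
opt = resolve_target(cfg["target"])(params, lr=1e-4, **cfg.get("params", dict()))
print(type(opt).__name__)  # Adam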
@@ -129,196 +115,435 @@ class AutoencodingEngine(AbstractAutoencoder):
         regularizer_config: Dict,
         optimizer_config: Union[Dict, None] = None,
         lr_g_factor: float = 1.0,
+        trainable_ae_params: Optional[List[List[str]]] = None,
+        ae_optimizer_args: Optional[List[dict]] = None,
+        trainable_disc_params: Optional[List[List[str]]] = None,
+        disc_optimizer_args: Optional[List[dict]] = None,
+        disc_start_iter: int = 0,
+        diff_boost_factor: float = 3.0,
+        ckpt_engine: Union[None, str, dict] = None,
+        ckpt_path: Optional[str] = None,
+        additional_decode_keys: Optional[List[str]] = None,
         **kwargs,
     ):
         super().__init__(*args, **kwargs)
-        # todo: add options to freeze encoder/decoder
-        self.encoder = instantiate_from_config(encoder_config)
-        self.decoder = instantiate_from_config(decoder_config)
-        self.loss = instantiate_from_config(loss_config)
-        self.regularization = instantiate_from_config(regularizer_config)
+        self.automatic_optimization = False  # pytorch lightning
+
+        self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
+        self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
+        self.loss: torch.nn.Module = instantiate_from_config(loss_config)
+        self.regularization: AbstractRegularizer = instantiate_from_config(
+            regularizer_config
+        )
         self.optimizer_config = default(
             optimizer_config, {"target": "torch.optim.Adam"}
         )
+        self.diff_boost_factor = diff_boost_factor
+        self.disc_start_iter = disc_start_iter
         self.lr_g_factor = lr_g_factor
+        self.trainable_ae_params = trainable_ae_params
+        if self.trainable_ae_params is not None:
+            self.ae_optimizer_args = default(
+                ae_optimizer_args,
+                [{} for _ in range(len(self.trainable_ae_params))],
+            )
+            assert len(self.ae_optimizer_args) == len(self.trainable_ae_params)
+        else:
+            self.ae_optimizer_args = [{}]  # makes type consistent
+
+        self.trainable_disc_params = trainable_disc_params
+        if self.trainable_disc_params is not None:
+            self.disc_optimizer_args = default(
+                disc_optimizer_args,
+                [{} for _ in range(len(self.trainable_disc_params))],
+            )
+            assert len(self.disc_optimizer_args) == len(self.trainable_disc_params)
+        else:
+            self.disc_optimizer_args = [{}]  # makes type consistent
+
+        if ckpt_path is not None:
+            assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path"
+            logpy.warn("Checkpoint path is deprecated, use `ckpt_engine` instead")
+        self.apply_ckpt(default(ckpt_path, ckpt_engine))
+        self.additional_decode_keys = set(default(additional_decode_keys, []))
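
# Editor's sketch (not part of the commit): example values for the new
# constructor arguments above. Only the keyword names come from the signature;
# the regex patterns, optimizer arguments and decode key are made up.
engine_kwargs = dict(
    trainable_ae_params=[[r"encoder\..*"], [r"decoder\..*"]],  # two regex groups
    ae_optimizer_args=[{"weight_decay": 0.0}, {"weight_decay": 1e-2}],
    disc_start_iter=50_000,                # discriminator optimizer starts later
    diff_boost_factor=3.0,                 # used by log_images further down
    additional_decode_keys=["timesteps"],  # hypothetical batch key passed to decode
)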
 
     def get_input(self, batch: Dict) -> torch.Tensor:
         # assuming unified data format, dataloader returns a dict.
-        # image tensors should be scaled to -1 ... 1 and in channels-first format (e.g., bchw instead if bhwc)
+        # image tensors should be scaled to -1 ... 1 and in channels-first
+        # format (e.g., bchw instead of bhwc)
         return batch[self.input_key]
 
     def get_autoencoder_params(self) -> list:
-        params = (
-            list(self.encoder.parameters())
-            + list(self.decoder.parameters())
-            + list(self.regularization.get_trainable_parameters())
-            + list(self.loss.get_trainable_autoencoder_parameters())
-        )
+        params = []
+        if hasattr(self.loss, "get_trainable_autoencoder_parameters"):
+            params += list(self.loss.get_trainable_autoencoder_parameters())
+        if hasattr(self.regularization, "get_trainable_parameters"):
+            params += list(self.regularization.get_trainable_parameters())
+        params = params + list(self.encoder.parameters())
+        params = params + list(self.decoder.parameters())
         return params
 
     def get_discriminator_params(self) -> list:
-        params = list(self.loss.get_trainable_parameters())  # e.g., discriminator
+        if hasattr(self.loss, "get_trainable_parameters"):
+            params = list(self.loss.get_trainable_parameters())  # e.g., discriminator
+        else:
+            params = []
         return params
 
     def get_last_layer(self):
         return self.decoder.get_last_layer()
 
-    def encode(self, x: Any, return_reg_log: bool = False) -> Any:
+    def encode(
+        self,
+        x: torch.Tensor,
+        return_reg_log: bool = False,
+        unregularized: bool = False,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
         z = self.encoder(x)
+        if unregularized:
+            return z, dict()
         z, reg_log = self.regularization(z)
         if return_reg_log:
             return z, reg_log
         return z
 
-    def decode(self, z: Any) -> torch.Tensor:
-        x = self.decoder(z)
+    def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
+        x = self.decoder(z, **kwargs)
         return x
 
-    def forward(self, x: Any) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def forward(
+        self, x: torch.Tensor, **additional_decode_kwargs
+    ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
         z, reg_log = self.encode(x, return_reg_log=True)
-        dec = self.decode(z)
+        dec = self.decode(z, **additional_decode_kwargs)
         return z, dec, reg_log
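
# Editor's sketch (not part of the commit): the encode -> regularize -> decode
# flow above, with a toy diagonal-Gaussian regularizer standing in for the
# project's DiagonalGaussianRegularizer.
import torch
import torch.nn as nn

class ToyGaussianRegularizer(nn.Module):
    def forward(self, z):
        # split channels into mean / logvar, sample z, and report a KL term
        mean, logvar = z.chunk(2, dim=1)
        sample = mean + torch.exp(0.5 * logvar) * torch.randn_like(mean)
        kl = 0.5 * (mean**2 + logvar.exp() - 1.0 - logvar).mean()
        return sample, {"kl_loss": kl}

encoder = nn.Conv2d(3, 8, 3, padding=1)   # toy encoder: 8 = 2 * latent channels
decoder = nn.Conv2d(4, 3, 3, padding=1)   # toy decoder
regularization = ToyGaussianRegularizer()

x = torch.randn(2, 3, 16, 16)             # bchw input scaled to [-1, 1]
z, reg_log = regularization(encoder(x))   # encode + regularize, keep the log
dec = decoder(z)                          # extra decode kwargs would be passed here
print(z.shape, dec.shape, reg_log["kl_loss"].item())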
 
-    def training_step(self, batch, batch_idx, optimizer_idx) -> Any:
+    def inner_training_step(
+        self, batch: dict, batch_idx: int, optimizer_idx: int = 0
+    ) -> torch.Tensor:
         x = self.get_input(batch)
-        z, xrec, regularization_log = self(x)
+        additional_decode_kwargs = {
+            key: batch[key] for key in self.additional_decode_keys.intersection(batch)
+        }
+        z, xrec, regularization_log = self(x, **additional_decode_kwargs)
+        if hasattr(self.loss, "forward_keys"):
+            extra_info = {
+                "z": z,
+                "optimizer_idx": optimizer_idx,
+                "global_step": self.global_step,
+                "last_layer": self.get_last_layer(),
+                "split": "train",
+                "regularization_log": regularization_log,
+                "autoencoder": self,
+            }
+            extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
+        else:
+            extra_info = dict()
 
         if optimizer_idx == 0:
             # autoencode
-            aeloss, log_dict_ae = self.loss(
-                regularization_log,
-                x,
-                xrec,
-                optimizer_idx,
-                self.global_step,
-                last_layer=self.get_last_layer(),
-                split="train",
-            )
+            out_loss = self.loss(x, xrec, **extra_info)
+            if isinstance(out_loss, tuple):
+                aeloss, log_dict_ae = out_loss
+            else:
+                # simple loss function
+                aeloss = out_loss
+                log_dict_ae = {"train/loss/rec": aeloss.detach()}
 
             self.log_dict(
-                log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True
+                log_dict_ae,
+                prog_bar=False,
+                logger=True,
+                on_step=True,
+                on_epoch=True,
+                sync_dist=False,
             )
+            self.log(
+                "loss",
+                aeloss.mean().detach(),
+                prog_bar=True,
+                logger=False,
+                on_epoch=False,
+                on_step=True,
+            )
             return aeloss
-
-        if optimizer_idx == 1:
+        elif optimizer_idx == 1:
             # discriminator
-            discloss, log_dict_disc = self.loss(
-                regularization_log,
-                x,
-                xrec,
-                optimizer_idx,
-                self.global_step,
-                last_layer=self.get_last_layer(),
-                split="train",
-            )
+            discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
+            # -> discriminator always needs to return a tuple
             self.log_dict(
                 log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
             )
             return discloss
+        else:
+            raise NotImplementedError(f"Unknown optimizer {optimizer_idx}")
 
-    def validation_step(self, batch, batch_idx) -> Dict:
+    def training_step(self, batch: dict, batch_idx: int):
+        opts = self.optimizers()
+        if not isinstance(opts, list):
+            # Non-adversarial case
+            opts = [opts]
+        optimizer_idx = batch_idx % len(opts)
+        if self.global_step < self.disc_start_iter:
+            optimizer_idx = 0
+        opt = opts[optimizer_idx]
+        opt.zero_grad()
+        with opt.toggle_model():
+            loss = self.inner_training_step(
+                batch, batch_idx, optimizer_idx=optimizer_idx
+            )
+            self.manual_backward(loss)
+        opt.step()
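
# Editor's sketch (not part of the commit): the manual-optimization pattern used
# by training_step above, reduced to a toy LightningModule with two optimizers
# alternated by batch parity.
import pytorch_lightning as pl
import torch
import torch.nn as nn

class TinyAdversarialModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # required for manual optimization
        self.gen = nn.Linear(8, 8)
        self.disc = nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        opts = self.optimizers()
        if not isinstance(opts, list):
            opts = [opts]
        optimizer_idx = batch_idx % len(opts)  # alternate generator / discriminator
        opt = opts[optimizer_idx]
        opt.zero_grad()
        with opt.toggle_model():  # only this optimizer's parameters receive gradients
            x = batch
            out = self.gen(x) if optimizer_idx == 0 else self.disc(x)
            loss = out.pow(2).mean()  # placeholder losses
            self.manual_backward(loss)
        opt.step()

    def configure_optimizers(self):
        return [
            torch.optim.Adam(self.gen.parameters(), lr=1e-4),
            torch.optim.Adam(self.disc.parameters(), lr=1e-4),
        ]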
 
+    def validation_step(self, batch: dict, batch_idx: int) -> Dict:
         log_dict = self._validation_step(batch, batch_idx)
         with self.ema_scope():
             log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
             log_dict.update(log_dict_ema)
         return log_dict
 
-    def _validation_step(self, batch, batch_idx, postfix="") -> Dict:
+    def _validation_step(self, batch: dict, batch_idx: int, postfix: str = "") -> Dict:
         x = self.get_input(batch)
 
         z, xrec, regularization_log = self(x)
-        aeloss, log_dict_ae = self.loss(
-            regularization_log,
-            x,
-            xrec,
-            0,
-            self.global_step,
-            last_layer=self.get_last_layer(),
-            split="val" + postfix,
-        )
+        if hasattr(self.loss, "forward_keys"):
+            extra_info = {
+                "z": z,
+                "optimizer_idx": 0,
+                "global_step": self.global_step,
+                "last_layer": self.get_last_layer(),
+                "split": "val" + postfix,
+                "regularization_log": regularization_log,
+                "autoencoder": self,
+            }
+            extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
+        else:
+            extra_info = dict()
+        out_loss = self.loss(x, xrec, **extra_info)
+        if isinstance(out_loss, tuple):
+            aeloss, log_dict_ae = out_loss
+        else:
+            # simple loss function
+            aeloss = out_loss
+            log_dict_ae = {f"val{postfix}/loss/rec": aeloss.detach()}
+        full_log_dict = log_dict_ae
+
+        if "optimizer_idx" in extra_info:
+            extra_info["optimizer_idx"] = 1
+            discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
+            full_log_dict.update(log_dict_disc)
+        self.log(
+            f"val{postfix}/loss/rec",
+            log_dict_ae[f"val{postfix}/loss/rec"],
+            sync_dist=True,
+        )
+        self.log_dict(full_log_dict, sync_dist=True)
+        return full_log_dict
-
-        discloss, log_dict_disc = self.loss(
-            regularization_log,
-            x,
-            xrec,
-            1,
-            self.global_step,
-            last_layer=self.get_last_layer(),
-            split="val" + postfix,
-        )
-        self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
-        log_dict_ae.update(log_dict_disc)
-        self.log_dict(log_dict_ae)
-        return log_dict_ae
 
-    def configure_optimizers(self) -> Any:
-        ae_params = self.get_autoencoder_params()
-        disc_params = self.get_discriminator_params()
+    def get_param_groups(
+        self, parameter_names: List[List[str]], optimizer_args: List[dict]
+    ) -> Tuple[List[Dict[str, Any]], int]:
+        groups = []
+        num_params = 0
+        for names, args in zip(parameter_names, optimizer_args):
+            params = []
+            for pattern_ in names:
+                pattern_params = []
+                pattern = re.compile(pattern_)
+                for p_name, param in self.named_parameters():
+                    if re.match(pattern, p_name):
+                        pattern_params.append(param)
+                        num_params += param.numel()
+                if len(pattern_params) == 0:
+                    logpy.warn(f"Did not find parameters for pattern {pattern_}")
+                params.extend(pattern_params)
+            groups.append({"params": params, **args})
+        return groups, num_params
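
# Editor's sketch (not part of the commit): regex-based parameter grouping as in
# get_param_groups above; each pattern list becomes one optimizer param group,
# optionally with its own optimizer arguments.
import re

import torch
import torch.nn as nn

model = nn.Sequential()
model.add_module("encoder", nn.Linear(4, 4))
model.add_module("decoder", nn.Linear(4, 4))

parameter_names = [[r"encoder\..*"], [r"decoder\..*"]]   # example patterns
optimizer_args = [{"weight_decay": 0.0}, {"weight_decay": 1e-2}]

groups = []
for names, args in zip(parameter_names, optimizer_args):
    params = []
    for pattern_ in names:
        pattern = re.compile(pattern_)
        params.extend(p for n, p in model.named_parameters() if pattern.match(n))
    groups.append({"params": params, **args})

opt = torch.optim.Adam(groups, lr=1e-4)
print([len(g["params"]) for g in groups])  # two tensors (weight, bias) per group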
 
+    def configure_optimizers(self) -> List[torch.optim.Optimizer]:
+        if self.trainable_ae_params is None:
+            ae_params = self.get_autoencoder_params()
+        else:
+            ae_params, num_ae_params = self.get_param_groups(
+                self.trainable_ae_params, self.ae_optimizer_args
+            )
+            logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
+        if self.trainable_disc_params is None:
+            disc_params = self.get_discriminator_params()
+        else:
+            disc_params, num_disc_params = self.get_param_groups(
+                self.trainable_disc_params, self.disc_optimizer_args
+            )
+            logpy.info(
+                f"Number of trainable discriminator parameters: {num_disc_params:,}"
+            )
         opt_ae = self.instantiate_optimizer_from_config(
             ae_params,
             default(self.lr_g_factor, 1.0) * self.learning_rate,
             self.optimizer_config,
         )
-        opt_disc = self.instantiate_optimizer_from_config(
-            disc_params, self.learning_rate, self.optimizer_config
-        )
+        opts = [opt_ae]
+        if len(disc_params) > 0:
+            opt_disc = self.instantiate_optimizer_from_config(
+                disc_params, self.learning_rate, self.optimizer_config
+            )
+            opts.append(opt_disc)
 
-        return [opt_ae, opt_disc], []
+        return opts
 
     @torch.no_grad()
-    def log_images(self, batch: Dict, **kwargs) -> Dict:
+    def log_images(
+        self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
+    ) -> dict:
         log = dict()
+        additional_decode_kwargs = {}
         x = self.get_input(batch)
-        _, xrec, _ = self(x)
+        additional_decode_kwargs.update(
+            {key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
+        )
+
+        _, xrec, _ = self(x, **additional_decode_kwargs)
         log["inputs"] = x
         log["reconstructions"] = xrec
+        diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x)
+        diff.clamp_(0, 1.0)
+        log["diff"] = 2.0 * diff - 1.0
+        # diff_boost shows location of small errors, by boosting their
+        # brightness.
+        log["diff_boost"] = (
+            2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1
+        )
+        if hasattr(self.loss, "log_images"):
+            log.update(self.loss.log_images(x, xrec))
         with self.ema_scope():
-            _, xrec_ema, _ = self(x)
+            _, xrec_ema, _ = self(x, **additional_decode_kwargs)
             log["reconstructions_ema"] = xrec_ema
+            diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
+            diff_ema.clamp_(0, 1.0)
+            log["diff_ema"] = 2.0 * diff_ema - 1.0
+            log["diff_boost_ema"] = (
+                2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
+            )
+        if additional_log_kwargs:
+            additional_decode_kwargs.update(additional_log_kwargs)
+            _, xrec_add, _ = self(x, **additional_decode_kwargs)
+            log_str = "reconstructions-" + "-".join(
+                [f"{key}={additional_log_kwargs[key]}" for key in additional_log_kwargs]
+            )
+            log[log_str] = xrec_add
         return log
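
# Editor's sketch (not part of the commit): the diff / diff_boost images computed
# in log_images above. Small reconstruction errors are amplified so they stay
# visible once images are mapped from [-1, 1] back to display range.
import torch

diff_boost_factor = 3.0
x = torch.rand(1, 3, 8, 8) * 2 - 1         # "input" in [-1, 1]
xrec = x + 0.05 * torch.randn_like(x)      # "reconstruction" with small errors

diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x)  # error mapped to [0, 1]
diff.clamp_(0, 1.0)
diff_img = 2.0 * diff - 1.0                               # back to [-1, 1]
diff_boost = 2.0 * torch.clamp(diff_boost_factor * diff, 0.0, 1.0) - 1

print(diff_img.max().item(), diff_boost.max().item())     # boosted errors are brighter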
 
 
-class AutoencoderKL(AutoencodingEngine):
+class AutoencodingEngineLegacy(AutoencodingEngine):
     def __init__(self, embed_dim: int, **kwargs):
+        self.max_batch_size = kwargs.pop("max_batch_size", None)
         ddconfig = kwargs.pop("ddconfig")
         ckpt_path = kwargs.pop("ckpt_path", None)
-        ignore_keys = kwargs.pop("ignore_keys", ())
+        ckpt_engine = kwargs.pop("ckpt_engine", None)
         super().__init__(
-            encoder_config={"target": "torch.nn.Identity"},
-            decoder_config={"target": "torch.nn.Identity"},
-            regularizer_config={"target": "torch.nn.Identity"},
-            loss_config=kwargs.pop("lossconfig"),
+            encoder_config={
+                "target": "sgm.modules.diffusionmodules.model.Encoder",
+                "params": ddconfig,
+            },
+            decoder_config={
+                "target": "sgm.modules.diffusionmodules.model.Decoder",
+                "params": ddconfig,
+            },
             **kwargs,
         )
-        assert ddconfig["double_z"]
-        self.encoder = Encoder(**ddconfig)
-        self.decoder = Decoder(**ddconfig)
-        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
+        self.quant_conv = torch.nn.Conv2d(
+            (1 + ddconfig["double_z"]) * ddconfig["z_channels"],
+            (1 + ddconfig["double_z"]) * embed_dim,
+            1,
+        )
         self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
         self.embed_dim = embed_dim
 
-        if ckpt_path is not None:
-            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+        self.apply_ckpt(default(ckpt_path, ckpt_engine))
 
-    def encode(self, x):
-        assert (
-            not self.training
-        ), f"{self.__class__.__name__} only supports inference currently"
-        h = self.encoder(x)
-        moments = self.quant_conv(h)
-        posterior = DiagonalGaussianDistribution(moments)
-        return posterior
+    def get_autoencoder_params(self) -> list:
+        params = super().get_autoencoder_params()
+        return params
+
+    def encode(
+        self, x: torch.Tensor, return_reg_log: bool = False
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
+        if self.max_batch_size is None:
+            z = self.encoder(x)
+            z = self.quant_conv(z)
+        else:
+            N = x.shape[0]
+            bs = self.max_batch_size
+            n_batches = int(math.ceil(N / bs))
+            z = list()
+            for i_batch in range(n_batches):
+                z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
+                z_batch = self.quant_conv(z_batch)
+                z.append(z_batch)
+            z = torch.cat(z, 0)
+
+        z, reg_log = self.regularization(z)
+        if return_reg_log:
+            return z, reg_log
+        return z
+
+    def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
+        if self.max_batch_size is None:
+            dec = self.post_quant_conv(z)
+            dec = self.decoder(dec, **decoder_kwargs)
+        else:
+            N = z.shape[0]
+            bs = self.max_batch_size
+            n_batches = int(math.ceil(N / bs))
+            dec = list()
+            for i_batch in range(n_batches):
+                dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
+                dec_batch = self.decoder(dec_batch, **decoder_kwargs)
+                dec.append(dec_batch)
+            dec = torch.cat(dec, 0)
+
-    def decode(self, z, **decoder_kwargs):
-        z = self.post_quant_conv(z)
-        dec = self.decoder(z, **decoder_kwargs)
         return dec
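
# Editor's sketch (not part of the commit): the max_batch_size chunking used by
# AutoencodingEngineLegacy.encode/decode above; a large batch is processed in
# slices to bound peak memory, then the results are concatenated.
import math

import torch
import torch.nn as nn

encoder = nn.Conv2d(3, 4, 3, padding=1)  # toy stand-in for the real encoder
x = torch.randn(10, 3, 16, 16)
max_batch_size = 4

n_batches = int(math.ceil(x.shape[0] / max_batch_size))
chunks = []
for i_batch in range(n_batches):
    chunk = x[i_batch * max_batch_size : (i_batch + 1) * max_batch_size]
    chunks.append(encoder(chunk))
z = torch.cat(chunks, 0)
print(z.shape)  # torch.Size([10, 4, 16, 16])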
 
 
-class AutoencoderKLInferenceWrapper(AutoencoderKL):
-    def encode(self, x):
-        return super().encode(x).sample()
+class AutoencoderKL(AutoencodingEngineLegacy):
+    def __init__(self, **kwargs):
+        if "lossconfig" in kwargs:
+            kwargs["loss_config"] = kwargs.pop("lossconfig")
+        super().__init__(
+            regularizer_config={
+                "target": (
+                    "sgm.modules.autoencoding.regularizers"
+                    ".DiagonalGaussianRegularizer"
+                )
+            },
+            **kwargs,
+        )
+
+
+class AutoencoderLegacyVQ(AutoencodingEngineLegacy):
+    def __init__(
+        self,
+        embed_dim: int,
+        n_embed: int,
+        sane_index_shape: bool = False,
+        **kwargs,
+    ):
+        if "lossconfig" in kwargs:
+            logpy.warn(f"Parameter `lossconfig` is deprecated, use `loss_config`.")
+            kwargs["loss_config"] = kwargs.pop("lossconfig")
+        super().__init__(
+            regularizer_config={
+                "target": (
+                    "sgm.modules.autoencoding.regularizers.quantize" ".VectorQuantizer"
+                ),
+                "params": {
+                    "n_e": n_embed,
+                    "e_dim": embed_dim,
+                    "sane_index_shape": sane_index_shape,
+                },
+            },
+            **kwargs,
+        )
 
 
 class IdentityFirstStage(AbstractAutoencoder):
@@ -333,3 +558,58 @@ class IdentityFirstStage(AbstractAutoencoder):
 
     def decode(self, x: Any, *args, **kwargs) -> Any:
         return x
+
+
+class AEIntegerWrapper(nn.Module):
+    def __init__(
+        self,
+        model: nn.Module,
+        shape: Union[None, Tuple[int, int], List[int]] = (16, 16),
+        regularization_key: str = "regularization",
+        encoder_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        super().__init__()
+        self.model = model
+        assert hasattr(model, "encode") and hasattr(
+            model, "decode"
+        ), "Need AE interface"
+        self.regularization = get_nested_attribute(model, regularization_key)
+        self.shape = shape
+        self.encoder_kwargs = default(encoder_kwargs, {"return_reg_log": True})
+
+    def encode(self, x) -> torch.Tensor:
+        assert (
+            not self.training
+        ), f"{self.__class__.__name__} only supports inference currently"
+        _, log = self.model.encode(x, **self.encoder_kwargs)
+        assert isinstance(log, dict)
+        inds = log["min_encoding_indices"]
+        return rearrange(inds, "b ... -> b (...)")
+
+    def decode(
+        self, inds: torch.Tensor, shape: Union[None, tuple, list] = None
+    ) -> torch.Tensor:
+        # expect inds shape (b, s) with s = h*w
+        shape = default(shape, self.shape)  # Optional[(h, w)]
+        if shape is not None:
+            assert len(shape) == 2, f"Unhandled shape {shape}"
+            inds = rearrange(inds, "b (h w) -> b h w", h=shape[0], w=shape[1])
+        h = self.regularization.get_codebook_entry(inds)  # (b, h, w, c)
+        h = rearrange(h, "b h w c -> b c h w")
+        return self.model.decode(h)
+
+
+class AutoencoderKLModeOnly(AutoencodingEngineLegacy):
+    def __init__(self, **kwargs):
+        if "lossconfig" in kwargs:
+            kwargs["loss_config"] = kwargs.pop("lossconfig")
+        super().__init__(
+            regularizer_config={
+                "target": (
+                    "sgm.modules.autoencoding.regularizers"
+                    ".DiagonalGaussianRegularizer"
+                ),
+                "params": {"sample": False},
+            },
+            **kwargs,
+        )
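
# Editor's sketch (not part of the commit): the index reshaping AEIntegerWrapper
# performs around a VQ codebook; indices are flattened by encode() and restored
# to a spatial grid before the codebook lookup in decode(). The codebook here is
# a toy stand-in for the regularizer's get_codebook_entry.
import torch
from einops import rearrange

b, h, w, c = 2, 16, 16, 8
inds_bhw = torch.randint(0, 1024, (b, h, w))         # codebook index per position
codebook = torch.randn(1024, c)                      # toy codebook

inds_flat = rearrange(inds_bhw, "b ... -> b (...)")  # (b, h*w), as encode() returns
inds = rearrange(inds_flat, "b (h w) -> b h w", h=h, w=w)
entries = codebook[inds]                             # (b, h, w, c) lookup
entries = rearrange(entries, "b h w c -> b c h w")   # channels-first for the decoder
print(entries.shape)                                 # torch.Size([2, 8, 16, 16])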