Mirror of https://github.com/Stability-AI/generative-models.git
Synced 2025-12-31 20:24:24 +01:00

Commit: Stable Video Diffusion
@@ -1,7 +0,0 @@
from .denoiser import Denoiser
from .discretizer import Discretization
from .loss import StandardDiffusionLoss
from .model import Decoder, Encoder, Model
from .openaimodel import UNetModel
from .sampling import BaseDiffusionSampler
from .wrappers import OpenAIWrapper
@@ -1,62 +1,74 @@
from typing import Dict, Union

import torch
import torch.nn as nn

from ...util import append_dims, instantiate_from_config
from .denoiser_scaling import DenoiserScaling
from .discretizer import Discretization


class Denoiser(nn.Module):
    def __init__(self, weighting_config, scaling_config):
    def __init__(self, scaling_config: Dict):
        super().__init__()

        self.weighting = instantiate_from_config(weighting_config)
        self.scaling = instantiate_from_config(scaling_config)
        self.scaling: DenoiserScaling = instantiate_from_config(scaling_config)

    def possibly_quantize_sigma(self, sigma):
    def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor:
        return sigma

    def possibly_quantize_c_noise(self, c_noise):
    def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor:
        return c_noise

    def w(self, sigma):
        return self.weighting(sigma)

    def __call__(self, network, input, sigma, cond):
    def forward(
        self,
        network: nn.Module,
        input: torch.Tensor,
        sigma: torch.Tensor,
        cond: Dict,
        **additional_model_inputs,
    ) -> torch.Tensor:
        sigma = self.possibly_quantize_sigma(sigma)
        sigma_shape = sigma.shape
        sigma = append_dims(sigma, input.ndim)
        c_skip, c_out, c_in, c_noise = self.scaling(sigma)
        c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
        return network(input * c_in, c_noise, cond) * c_out + input * c_skip
        return (
            network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out
            + input * c_skip
        )


class DiscreteDenoiser(Denoiser):
    def __init__(
        self,
        weighting_config,
        scaling_config,
        num_idx,
        discretization_config,
        do_append_zero=False,
        quantize_c_noise=True,
        flip=True,
        scaling_config: Dict,
        num_idx: int,
        discretization_config: Dict,
        do_append_zero: bool = False,
        quantize_c_noise: bool = True,
        flip: bool = True,
    ):
        super().__init__(weighting_config, scaling_config)
        sigmas = instantiate_from_config(discretization_config)(
            num_idx, do_append_zero=do_append_zero, flip=flip
        super().__init__(scaling_config)
        self.discretization: Discretization = instantiate_from_config(
            discretization_config
        )
        sigmas = self.discretization(num_idx, do_append_zero=do_append_zero, flip=flip)
        self.register_buffer("sigmas", sigmas)
        self.quantize_c_noise = quantize_c_noise
        self.num_idx = num_idx

    def sigma_to_idx(self, sigma):
    def sigma_to_idx(self, sigma: torch.Tensor) -> torch.Tensor:
        dists = sigma - self.sigmas[:, None]
        return dists.abs().argmin(dim=0).view(sigma.shape)

    def idx_to_sigma(self, idx):
    def idx_to_sigma(self, idx: Union[torch.Tensor, int]) -> torch.Tensor:
        return self.sigmas[idx]

    def possibly_quantize_sigma(self, sigma):
    def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor:
        return self.idx_to_sigma(self.sigma_to_idx(sigma))

    def possibly_quantize_c_noise(self, c_noise):
    def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor:
        if self.quantize_c_noise:
            return self.sigma_to_idx(c_noise)
        else:
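Aside, not part of the diff: a minimal standalone sketch of the preconditioning that the refactored Denoiser.forward applies. The toy network, the local append_dims helper, and the eps_scaling stand-in are assumptions for illustration; only the c_skip/c_out/c_in/c_noise wiring mirrors the code above.

# Hedged sketch of the Denoiser.forward formula with placeholder pieces.
import torch
import torch.nn as nn

def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
    # right-pad with singleton dims so x broadcasts against the input
    return x[(...,) + (None,) * (target_dims - x.ndim)]

def eps_scaling(sigma: torch.Tensor):
    # EpsScaling-style factors (compare denoiser_scaling.py below)
    return torch.ones_like(sigma), -sigma, 1 / (sigma**2 + 1.0) ** 0.5, sigma.clone()

network = lambda x, c_noise: x            # stand-in for the UNet call network(x, c_noise, cond, ...)
x = torch.randn(2, 4, 8, 8)               # noised latents, shape (b, c, h, w)
sigma = torch.rand(2)                     # one noise level per batch element
c_skip, c_out, c_in, c_noise = eps_scaling(sigma)
c_skip, c_out, c_in = (append_dims(c, x.ndim) for c in (c_skip, c_out, c_in))
denoised = network(x * c_in, c_noise) * c_out + x * c_skip   # same combination as forward()
print(denoised.shape)                     # torch.Size([2, 4, 8, 8])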
@@ -1,11 +1,24 @@
from abc import ABC, abstractmethod
from typing import Tuple

import torch


class DenoiserScaling(ABC):
    @abstractmethod
    def __call__(
        self, sigma: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        pass


class EDMScaling:
    def __init__(self, sigma_data=0.5):
    def __init__(self, sigma_data: float = 0.5):
        self.sigma_data = sigma_data

    def __call__(self, sigma):
    def __call__(
        self, sigma: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
        c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
        c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
@@ -14,7 +27,9 @@ class EDMScaling:


class EpsScaling:
    def __call__(self, sigma):
    def __call__(
        self, sigma: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        c_skip = torch.ones_like(sigma, device=sigma.device)
        c_out = -sigma
        c_in = 1 / (sigma**2 + 1.0) ** 0.5
@@ -23,9 +38,22 @@ class EpsScaling:


class VScaling:
    def __call__(self, sigma):
    def __call__(
        self, sigma: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        c_skip = 1.0 / (sigma**2 + 1.0)
        c_out = -sigma / (sigma**2 + 1.0) ** 0.5
        c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
        c_noise = sigma.clone()
        return c_skip, c_out, c_in, c_noise


class VScalingWithEDMcNoise(DenoiserScaling):
    def __call__(
        self, sigma: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        c_skip = 1.0 / (sigma**2 + 1.0)
        c_out = -sigma / (sigma**2 + 1.0) ** 0.5
        c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
        c_noise = 0.25 * sigma.log()
        return c_skip, c_out, c_in, c_noise
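Aside, not part of the diff: the v-prediction factors above tend toward an identity pass-through as sigma goes to 0 (c_skip -> 1, c_out -> 0, c_in -> 1). A tiny numeric check, with an illustrative sigma value:

# Hedged sketch: evaluate the VScaling/VScalingWithEDMcNoise factors at sigma = 0.5.
import torch

sigma = torch.tensor([0.5])
c_skip = 1.0 / (sigma**2 + 1.0)              # 0.8
c_out = -sigma / (sigma**2 + 1.0) ** 0.5     # ~ -0.4472
c_in = 1.0 / (sigma**2 + 1.0) ** 0.5         # ~ 0.8944
c_noise = 0.25 * sigma.log()                 # ~ -0.1733, the EDM-style c_noise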
@@ -1,31 +1,33 @@
from functools import partial
import logging
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, Union

import torch
from einops import rearrange, repeat

from ...util import default, instantiate_from_config
from ...util import append_dims, default

logpy = logging.getLogger(__name__)


class VanillaCFG:
    """
    implements parallelized CFG
    """
class Guider(ABC):
    @abstractmethod
    def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
        pass

    def __init__(self, scale, dyn_thresh_config=None):
        scale_schedule = lambda scale, sigma: scale  # independent of step
        self.scale_schedule = partial(scale_schedule, scale)
        self.dyn_thresh = instantiate_from_config(
            default(
                dyn_thresh_config,
                {
                    "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
                },
            )
        )
    def prepare_inputs(
        self, x: torch.Tensor, s: float, c: Dict, uc: Dict
    ) -> Tuple[torch.Tensor, float, Dict]:
        pass

    def __call__(self, x, sigma):

class VanillaCFG(Guider):
    def __init__(self, scale: float):
        self.scale = scale

    def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        x_u, x_c = x.chunk(2)
        scale_value = self.scale_schedule(sigma)
        x_pred = self.dyn_thresh(x_u, x_c, scale_value)
        x_pred = x_u + self.scale * (x_c - x_u)
        return x_pred

    def prepare_inputs(self, x, s, c, uc):
@@ -40,14 +42,58 @@ class VanillaCFG:
        return torch.cat([x] * 2), torch.cat([s] * 2), c_out


class IdentityGuider:
    def __call__(self, x, sigma):
class IdentityGuider(Guider):
    def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
        return x

    def prepare_inputs(self, x, s, c, uc):
    def prepare_inputs(
        self, x: torch.Tensor, s: float, c: Dict, uc: Dict
    ) -> Tuple[torch.Tensor, float, Dict]:
        c_out = dict()

        for k in c:
            c_out[k] = c[k]

        return x, s, c_out


class LinearPredictionGuider(Guider):
    def __init__(
        self,
        max_scale: float,
        num_frames: int,
        min_scale: float = 1.0,
        additional_cond_keys: Optional[Union[List[str], str]] = None,
    ):
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.num_frames = num_frames
        self.scale = torch.linspace(min_scale, max_scale, num_frames).unsqueeze(0)

        additional_cond_keys = default(additional_cond_keys, [])
        if isinstance(additional_cond_keys, str):
            additional_cond_keys = [additional_cond_keys]
        self.additional_cond_keys = additional_cond_keys

    def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        x_u, x_c = x.chunk(2)

        x_u = rearrange(x_u, "(b t) ... -> b t ...", t=self.num_frames)
        x_c = rearrange(x_c, "(b t) ... -> b t ...", t=self.num_frames)
        scale = repeat(self.scale, "1 t -> b t", b=x_u.shape[0])
        scale = append_dims(scale, x_u.ndim).to(x_u.device)

        return rearrange(x_u + scale * (x_c - x_u), "b t ... -> (b t) ...")

    def prepare_inputs(
        self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict
    ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        c_out = dict()

        for k in c:
            if k in ["vector", "crossattn", "concat"] + self.additional_cond_keys:
                c_out[k] = torch.cat((uc[k], c[k]), 0)
            else:
                assert c[k] == uc[k]
                c_out[k] = c[k]
        return torch.cat([x] * 2), torch.cat([s] * 2), c_out
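Aside, not part of the diff: a standalone sketch of the per-frame guidance the new LinearPredictionGuider performs. Batch and frame sizes below are illustrative assumptions; the tensor manipulation follows the __call__ method above.

# Hedged sketch of frame-wise CFG with a scale that ramps from min_scale to max_scale.
import torch
from einops import rearrange, repeat

b, t, c, h, w = 1, 4, 3, 8, 8
num_frames, min_scale, max_scale = t, 1.0, 2.5
x = torch.randn(2 * b * t, c, h, w)          # unconditional half stacked on conditional half

x_u, x_c = x.chunk(2)
x_u = rearrange(x_u, "(b t) ... -> b t ...", t=num_frames)
x_c = rearrange(x_c, "(b t) ... -> b t ...", t=num_frames)
scale = torch.linspace(min_scale, max_scale, num_frames).unsqueeze(0)   # (1, t)
scale = repeat(scale, "1 t -> b t", b=b)
scale = scale[(...,) + (None,) * (x_u.ndim - scale.ndim)]               # append_dims
out = rearrange(x_u + scale * (x_c - x_u), "b t ... -> (b t) ...")
print(out.shape)                              # torch.Size([4, 3, 8, 8])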
@@ -1,31 +1,34 @@
from typing import List, Optional, Union
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from omegaconf import ListConfig

from ...util import append_dims, instantiate_from_config
from ...modules.autoencoding.lpips.loss.lpips import LPIPS
from ...modules.encoders.modules import GeneralConditioner
from ...util import append_dims, instantiate_from_config
from .denoiser import Denoiser


class StandardDiffusionLoss(nn.Module):
    def __init__(
        self,
        sigma_sampler_config,
        type="l2",
        offset_noise_level=0.0,
        batch2model_keys: Optional[Union[str, List[str], ListConfig]] = None,
        sigma_sampler_config: dict,
        loss_weighting_config: dict,
        loss_type: str = "l2",
        offset_noise_level: float = 0.0,
        batch2model_keys: Optional[Union[str, List[str]]] = None,
    ):
        super().__init__()

        assert type in ["l2", "l1", "lpips"]
        assert loss_type in ["l2", "l1", "lpips"]

        self.sigma_sampler = instantiate_from_config(sigma_sampler_config)
        self.loss_weighting = instantiate_from_config(loss_weighting_config)

        self.type = type
        self.loss_type = loss_type
        self.offset_noise_level = offset_noise_level

        if type == "lpips":
        if loss_type == "lpips":
            self.lpips = LPIPS().eval()

        if not batch2model_keys:
@@ -36,34 +39,67 @@ class StandardDiffusionLoss(nn.Module):

        self.batch2model_keys = set(batch2model_keys)

    def __call__(self, network, denoiser, conditioner, input, batch):
    def get_noised_input(
        self, sigmas_bc: torch.Tensor, noise: torch.Tensor, input: torch.Tensor
    ) -> torch.Tensor:
        noised_input = input + noise * sigmas_bc
        return noised_input

    def forward(
        self,
        network: nn.Module,
        denoiser: Denoiser,
        conditioner: GeneralConditioner,
        input: torch.Tensor,
        batch: Dict,
    ) -> torch.Tensor:
        cond = conditioner(batch)
        return self._forward(network, denoiser, cond, input, batch)

    def _forward(
        self,
        network: nn.Module,
        denoiser: Denoiser,
        cond: Dict,
        input: torch.Tensor,
        batch: Dict,
    ) -> Tuple[torch.Tensor, Dict]:
        additional_model_inputs = {
            key: batch[key] for key in self.batch2model_keys.intersection(batch)
        }
        sigmas = self.sigma_sampler(input.shape[0]).to(input)

        sigmas = self.sigma_sampler(input.shape[0]).to(input.device)
        noise = torch.randn_like(input)
        if self.offset_noise_level > 0.0:
            noise = noise + self.offset_noise_level * append_dims(
                torch.randn(input.shape[0], device=input.device), input.ndim
            offset_shape = (
                (input.shape[0], 1, input.shape[2])
                if self.n_frames is not None
                else (input.shape[0], input.shape[1])
            )
        noised_input = input + noise * append_dims(sigmas, input.ndim)
            noise = noise + self.offset_noise_level * append_dims(
                torch.randn(offset_shape, device=input.device),
                input.ndim,
            )
        sigmas_bc = append_dims(sigmas, input.ndim)
        noised_input = self.get_noised_input(sigmas_bc, noise, input)

        model_output = denoiser(
            network, noised_input, sigmas, cond, **additional_model_inputs
        )
        w = append_dims(denoiser.w(sigmas), input.ndim)
        w = append_dims(self.loss_weighting(sigmas), input.ndim)
        return self.get_loss(model_output, input, w)

    def get_loss(self, model_output, target, w):
        if self.type == "l2":
        if self.loss_type == "l2":
            return torch.mean(
                (w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1
            )
        elif self.type == "l1":
        elif self.loss_type == "l1":
            return torch.mean(
                (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1
            )
        elif self.type == "lpips":
        elif self.loss_type == "lpips":
            loss = self.lpips(model_output, target).reshape(-1)
            return loss
        else:
            raise NotImplementedError(f"Unknown loss type {self.loss_type}")
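Aside, not part of the diff: a compressed sketch of the training-loss path this refactor keeps, sample sigmas, noise the input, denoise, and weight the squared error. All names here are local stand-ins, not the module's API, and the "perfect denoiser" is a deliberate simplification to show the l2 branch reducing to zero.

# Hedged sketch of the l2 branch of StandardDiffusionLoss with stand-in pieces.
import torch

def append_dims(x, target_dims):
    return x[(...,) + (None,) * (target_dims - x.ndim)]

input = torch.randn(4, 3, 16, 16)                    # clean latents
sigmas = torch.rand(input.shape[0]) * 2 + 0.1        # stand-in for sigma_sampler(batch_size)
noise = torch.randn_like(input)
sigmas_bc = append_dims(sigmas, input.ndim)
noised_input = input + noise * sigmas_bc             # get_noised_input

model_output = noised_input - noise * sigmas_bc      # pretend the denoiser is perfect
w = append_dims(1.0 / sigmas**2, input.ndim)         # e.g. EpsWeighting: sigma**-2
loss = torch.mean((w * (model_output - input) ** 2).reshape(input.shape[0], -1), 1)
print(loss)                                          # zeros for the perfect denoiser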
sgm/modules/diffusionmodules/loss_weighting.py (new file, 32 lines)
@@ -0,0 +1,32 @@
from abc import ABC, abstractmethod

import torch


class DiffusionLossWeighting(ABC):
    @abstractmethod
    def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
        pass


class UnitWeighting(DiffusionLossWeighting):
    def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
        return torch.ones_like(sigma, device=sigma.device)


class EDMWeighting(DiffusionLossWeighting):
    def __init__(self, sigma_data: float = 0.5):
        self.sigma_data = sigma_data

    def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
        return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2


class VWeighting(EDMWeighting):
    def __init__(self):
        super().__init__(sigma_data=1.0)


class EpsWeighting(DiffusionLossWeighting):
    def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
        return sigma**-2.0
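Aside, not part of the diff: a quick numeric check of the relation the subclassing above encodes, VWeighting is EDMWeighting evaluated at sigma_data = 1.0. The sigma values are illustrative.

# Hedged check: EDM weighting at sigma_data=1.0 equals (sigma**2 + 1) / sigma**2.
import torch

sigma = torch.tensor([0.25, 1.0, 4.0])
sigma_data = 1.0
edm_w = (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2   # EDMWeighting.__call__
v_w = (sigma**2 + 1.0) / sigma**2                                # what VWeighting reduces to
assert torch.allclose(edm_w, v_w)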
@@ -1,4 +1,5 @@
# pytorch_diffusion + derived encoder decoder
import logging
import math
from typing import Any, Callable, Optional

@@ -8,6 +9,8 @@ import torch.nn as nn
from einops import rearrange
from packaging import version

logpy = logging.getLogger(__name__)

try:
    import xformers
    import xformers.ops
@@ -15,7 +18,7 @@ try:
    XFORMERS_IS_AVAILABLE = True
except:
    XFORMERS_IS_AVAILABLE = False
    print("no module 'xformers'. Processing without...")
    logpy.warning("no module 'xformers'. Processing without...")

from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention

@@ -288,12 +291,14 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
            f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
        )
        attn_type = "vanilla-xformers"
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    logpy.info(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "vanilla":
        assert attn_kwargs is None
        return AttnBlock(in_channels)
    elif attn_type == "vanilla-xformers":
        print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
        logpy.info(
            f"building MemoryEfficientAttnBlock with {in_channels} in_channels..."
        )
        return MemoryEfficientAttnBlock(in_channels)
    elif type == "memory-efficient-cross-attn":
        attn_kwargs["query_dim"] = in_channels
@@ -633,7 +638,7 @@ class Decoder(nn.Module):
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        print(
        logpy.info(
            "Working with z of shape {} = {} dimensions.".format(
                self.z_shape, np.prod(self.z_shape)
            )
File diff suppressed because it is too large
@@ -9,13 +9,10 @@ import torch
from omegaconf import ListConfig, OmegaConf
from tqdm import tqdm

from ...modules.diffusionmodules.sampling_utils import (
    get_ancestral_step,
    linear_multistep_coeff,
    to_d,
    to_neg_log_sigma,
    to_sigma,
)
from ...modules.diffusionmodules.sampling_utils import (get_ancestral_step,
                                                         linear_multistep_coeff,
                                                         to_d, to_neg_log_sigma,
                                                         to_sigma)
from ...util import append_dims, default, instantiate_from_config

DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}
@@ -4,11 +4,6 @@ from scipy import integrate
from ...util import append_dims


class NoDynamicThresholding:
    def __call__(self, uncond, cond, scale):
        return uncond + scale * (cond - uncond)


def linear_multistep_coeff(order, t, i, j, epsrel=1e-4):
    if order - 1 > i:
        raise ValueError(f"Order {order} too high for step {i}")
@@ -1,5 +1,5 @@
"""
adopted from
partially adopted from
https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
and
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
@@ -10,10 +10,11 @@ thanks!
"""

import math
from typing import Optional

import torch
import torch.nn as nn
from einops import repeat
from einops import rearrange, repeat


def make_beta_schedule(
@@ -306,3 +307,63 @@ def avg_pool_nd(dims, *args, **kwargs):
    elif dims == 3:
        return nn.AvgPool3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


class AlphaBlender(nn.Module):
    strategies = ["learned", "fixed", "learned_with_images"]

    def __init__(
        self,
        alpha: float,
        merge_strategy: str = "learned_with_images",
        rearrange_pattern: str = "b t -> (b t) 1 1",
    ):
        super().__init__()
        self.merge_strategy = merge_strategy
        self.rearrange_pattern = rearrange_pattern

        assert (
            merge_strategy in self.strategies
        ), f"merge_strategy needs to be in {self.strategies}"

        if self.merge_strategy == "fixed":
            self.register_buffer("mix_factor", torch.Tensor([alpha]))
        elif (
            self.merge_strategy == "learned"
            or self.merge_strategy == "learned_with_images"
        ):
            self.register_parameter(
                "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
            )
        else:
            raise ValueError(f"unknown merge strategy {self.merge_strategy}")

    def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:
        if self.merge_strategy == "fixed":
            alpha = self.mix_factor
        elif self.merge_strategy == "learned":
            alpha = torch.sigmoid(self.mix_factor)
        elif self.merge_strategy == "learned_with_images":
            assert image_only_indicator is not None, "need image_only_indicator ..."
            alpha = torch.where(
                image_only_indicator.bool(),
                torch.ones(1, 1, device=image_only_indicator.device),
                rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"),
            )
            alpha = rearrange(alpha, self.rearrange_pattern)
        else:
            raise NotImplementedError
        return alpha

    def forward(
        self,
        x_spatial: torch.Tensor,
        x_temporal: torch.Tensor,
        image_only_indicator: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        alpha = self.get_alpha(image_only_indicator)
        x = (
            alpha.to(x_spatial.dtype) * x_spatial
            + (1.0 - alpha).to(x_spatial.dtype) * x_temporal
        )
        return x
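Aside, not part of the diff: a hedged usage sketch of the new AlphaBlender idea. Shapes are illustrative, and the broadcasting details (rearrange_pattern, image_only_indicator handling) are simplified away; only the sigmoid-gated convex blend follows the forward() above.

# Hedged sketch: blending a spatial branch and a temporal branch with a learned mix factor.
import torch

b_t, c, h, w = 8, 8, 16, 16
x_spatial = torch.randn(b_t, c, h, w)     # per-frame (image) branch
x_temporal = torch.randn(b_t, c, h, w)    # temporal-attention branch

mix_factor = torch.tensor([0.5])                        # the registered buffer/parameter
alpha = torch.sigmoid(mix_factor).view(1, 1, 1, 1)      # "learned" strategy squashes it into (0, 1)
x = alpha * x_spatial + (1.0 - alpha) * x_temporal      # convex blend, as in AlphaBlender.forward
print(x.shape)                                          # torch.Size([8, 8, 16, 16])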
sgm/modules/diffusionmodules/video_model.py (new file, 493 lines)
@@ -0,0 +1,493 @@
from functools import partial
from typing import List, Optional, Union

from einops import rearrange

from ...modules.diffusionmodules.openaimodel import *
from ...modules.video_attention import SpatialVideoTransformer
from ...util import default
from .util import AlphaBlender


class VideoResBlock(ResBlock):
    def __init__(
        self,
        channels: int,
        emb_channels: int,
        dropout: float,
        video_kernel_size: Union[int, List[int]] = 3,
        merge_strategy: str = "fixed",
        merge_factor: float = 0.5,
        out_channels: Optional[int] = None,
        use_conv: bool = False,
        use_scale_shift_norm: bool = False,
        dims: int = 2,
        use_checkpoint: bool = False,
        up: bool = False,
        down: bool = False,
    ):
        super().__init__(
            channels,
            emb_channels,
            dropout,
            out_channels=out_channels,
            use_conv=use_conv,
            use_scale_shift_norm=use_scale_shift_norm,
            dims=dims,
            use_checkpoint=use_checkpoint,
            up=up,
            down=down,
        )

        self.time_stack = ResBlock(
            default(out_channels, channels),
            emb_channels,
            dropout=dropout,
            dims=3,
            out_channels=default(out_channels, channels),
            use_scale_shift_norm=False,
            use_conv=False,
            up=False,
            down=False,
            kernel_size=video_kernel_size,
            use_checkpoint=use_checkpoint,
            exchange_temb_dims=True,
        )
        self.time_mixer = AlphaBlender(
            alpha=merge_factor,
            merge_strategy=merge_strategy,
            rearrange_pattern="b t -> b 1 t 1 1",
        )

    def forward(
        self,
        x: th.Tensor,
        emb: th.Tensor,
        num_video_frames: int,
        image_only_indicator: Optional[th.Tensor] = None,
    ) -> th.Tensor:
        x = super().forward(x, emb)

        x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
        x = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)

        x = self.time_stack(
            x, rearrange(emb, "(b t) ... -> b t ...", t=num_video_frames)
        )
        x = self.time_mixer(
            x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator
        )
        x = rearrange(x, "b c t h w -> (b t) c h w")
        return x


class VideoUNet(nn.Module):
    def __init__(
        self,
        in_channels: int,
        model_channels: int,
        out_channels: int,
        num_res_blocks: int,
        attention_resolutions: int,
        dropout: float = 0.0,
        channel_mult: List[int] = (1, 2, 4, 8),
        conv_resample: bool = True,
        dims: int = 2,
        num_classes: Optional[int] = None,
        use_checkpoint: bool = False,
        num_heads: int = -1,
        num_head_channels: int = -1,
        num_heads_upsample: int = -1,
        use_scale_shift_norm: bool = False,
        resblock_updown: bool = False,
        transformer_depth: Union[List[int], int] = 1,
        transformer_depth_middle: Optional[int] = None,
        context_dim: Optional[int] = None,
        time_downup: bool = False,
        time_context_dim: Optional[int] = None,
        extra_ff_mix_layer: bool = False,
        use_spatial_context: bool = False,
        merge_strategy: str = "fixed",
        merge_factor: float = 0.5,
        spatial_transformer_attn_type: str = "softmax",
        video_kernel_size: Union[int, List[int]] = 3,
        use_linear_in_transformer: bool = False,
        adm_in_channels: Optional[int] = None,
        disable_temporal_crossattention: bool = False,
        max_ddpm_temb_period: int = 10000,
    ):
        super().__init__()
        assert context_dim is not None

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert num_head_channels != -1

        if num_head_channels == -1:
            assert num_heads != -1

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        if isinstance(transformer_depth, int):
            transformer_depth = len(channel_mult) * [transformer_depth]
        transformer_depth_middle = default(
            transformer_depth_middle, transformer_depth[-1]
        )

        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            if isinstance(self.num_classes, int):
                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
            elif self.num_classes == "continuous":
                print("setting up linear c_adm embedding layer")
                self.label_emb = nn.Linear(1, time_embed_dim)
            elif self.num_classes == "timestep":
                self.label_emb = nn.Sequential(
                    Timestep(model_channels),
                    nn.Sequential(
                        linear(model_channels, time_embed_dim),
                        nn.SiLU(),
                        linear(time_embed_dim, time_embed_dim),
                    ),
                )

            elif self.num_classes == "sequential":
                assert adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
                        linear(adm_in_channels, time_embed_dim),
                        nn.SiLU(),
                        linear(time_embed_dim, time_embed_dim),
                    )
                )
            else:
                raise ValueError()

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1

        def get_attention_layer(
            ch,
            num_heads,
            dim_head,
            depth=1,
            context_dim=None,
            use_checkpoint=False,
            disabled_sa=False,
        ):
            return SpatialVideoTransformer(
                ch,
                num_heads,
                dim_head,
                depth=depth,
                context_dim=context_dim,
                time_context_dim=time_context_dim,
                dropout=dropout,
                ff_in=extra_ff_mix_layer,
                use_spatial_context=use_spatial_context,
                merge_strategy=merge_strategy,
                merge_factor=merge_factor,
                checkpoint=use_checkpoint,
                use_linear=use_linear_in_transformer,
                attn_mode=spatial_transformer_attn_type,
                disable_self_attn=disabled_sa,
                disable_temporal_crossattention=disable_temporal_crossattention,
                max_time_embed_period=max_ddpm_temb_period,
            )

        def get_resblock(
            merge_factor,
            merge_strategy,
            video_kernel_size,
            ch,
            time_embed_dim,
            dropout,
            out_ch,
            dims,
            use_checkpoint,
            use_scale_shift_norm,
            down=False,
            up=False,
        ):
            return VideoResBlock(
                merge_factor=merge_factor,
                merge_strategy=merge_strategy,
                video_kernel_size=video_kernel_size,
                channels=ch,
                emb_channels=time_embed_dim,
                dropout=dropout,
                out_channels=out_ch,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
                down=down,
                up=up,
            )

        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    get_resblock(
                        merge_factor=merge_factor,
                        merge_strategy=merge_strategy,
                        video_kernel_size=video_kernel_size,
                        ch=ch,
                        time_embed_dim=time_embed_dim,
                        dropout=dropout,
                        out_ch=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels

                    layers.append(
                        get_attention_layer(
                            ch,
                            num_heads,
                            dim_head,
                            depth=transformer_depth[level],
                            context_dim=context_dim,
                            use_checkpoint=use_checkpoint,
                            disabled_sa=False,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                ds *= 2
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        get_resblock(
                            merge_factor=merge_factor,
                            merge_strategy=merge_strategy,
                            video_kernel_size=video_kernel_size,
                            ch=ch,
                            time_embed_dim=time_embed_dim,
                            dropout=dropout,
                            out_ch=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch,
                            conv_resample,
                            dims=dims,
                            out_channels=out_ch,
                            third_down=time_downup,
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)

                self._feature_size += ch

        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels

        self.middle_block = TimestepEmbedSequential(
            get_resblock(
                merge_factor=merge_factor,
                merge_strategy=merge_strategy,
                video_kernel_size=video_kernel_size,
                ch=ch,
                time_embed_dim=time_embed_dim,
                out_ch=None,
                dropout=dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            get_attention_layer(
                ch,
                num_heads,
                dim_head,
                depth=transformer_depth_middle,
                context_dim=context_dim,
                use_checkpoint=use_checkpoint,
            ),
            get_resblock(
                merge_factor=merge_factor,
                merge_strategy=merge_strategy,
                video_kernel_size=video_kernel_size,
                ch=ch,
                out_ch=None,
                time_embed_dim=time_embed_dim,
                dropout=dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                ich = input_block_chans.pop()
                layers = [
                    get_resblock(
                        merge_factor=merge_factor,
                        merge_strategy=merge_strategy,
                        video_kernel_size=video_kernel_size,
                        ch=ch + ich,
                        time_embed_dim=time_embed_dim,
                        dropout=dropout,
                        out_ch=model_channels * mult,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels

                    layers.append(
                        get_attention_layer(
                            ch,
                            num_heads,
                            dim_head,
                            depth=transformer_depth[level],
                            context_dim=context_dim,
                            use_checkpoint=use_checkpoint,
                            disabled_sa=False,
                        )
                    )
                if level and i == num_res_blocks:
                    out_ch = ch
                    ds //= 2
                    layers.append(
                        get_resblock(
                            merge_factor=merge_factor,
                            merge_strategy=merge_strategy,
                            video_kernel_size=video_kernel_size,
                            ch=ch,
                            time_embed_dim=time_embed_dim,
                            dropout=dropout,
                            out_ch=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                        )
                        if resblock_updown
                        else Upsample(
                            ch,
                            conv_resample,
                            dims=dims,
                            out_channels=out_ch,
                            third_up=time_downup,
                        )
                    )

                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
        )

    def forward(
        self,
        x: th.Tensor,
        timesteps: th.Tensor,
        context: Optional[th.Tensor] = None,
        y: Optional[th.Tensor] = None,
        time_context: Optional[th.Tensor] = None,
        num_video_frames: Optional[int] = None,
        image_only_indicator: Optional[th.Tensor] = None,
    ):
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional -> no, relax this TODO"
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)

        h = x
        for module in self.input_blocks:
            h = module(
                h,
                emb,
                context=context,
                image_only_indicator=image_only_indicator,
                time_context=time_context,
                num_video_frames=num_video_frames,
            )
            hs.append(h)
        h = self.middle_block(
            h,
            emb,
            context=context,
            image_only_indicator=image_only_indicator,
            time_context=time_context,
            num_video_frames=num_video_frames,
        )
        for module in self.output_blocks:
            h = th.cat([h, hs.pop()], dim=1)
            h = module(
                h,
                emb,
                context=context,
                image_only_indicator=image_only_indicator,
                time_context=time_context,
                num_video_frames=num_video_frames,
            )
        h = h.type(x.dtype)
        return self.out(h)
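Aside, not part of the diff: a hedged sketch of the tensor shapes VideoUNet.forward works with. No real model is built here, the dimensions and context size are illustrative assumptions, and the actual call is left commented since constructing the network requires the rest of the sgm package.

# Hedged sketch: frames are flattened into the batch dimension as (b t).
import torch

b, t = 2, 14                 # batch of videos, t frames each
c, h, w = 4, 32, 32          # latent channels and spatial size (illustrative)

x = torch.randn(b * t, c, h, w)                   # input latents, shape ((b t), c, h, w)
timesteps = torch.randint(0, 1000, (b * t,))      # one diffusion timestep per frame
context = torch.randn(b * t, 1, 1024)             # cross-attention context, illustrative dim
image_only_indicator = torch.zeros(b, t)          # 1.0 marks still-image samples, 0.0 video frames
num_video_frames = t                              # lets VideoResBlock/SpatialVideoTransformer unflatten (b t)

# out = video_unet(x, timesteps, context=context, y=vector_cond,
#                  image_only_indicator=image_only_indicator,
#                  num_video_frames=num_video_frames)   # -> ((b t), out_channels, h, w)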