Stable Video Diffusion

This commit is contained in:
Tim Dockhorn
2023-11-21 10:40:21 -08:00
parent 477d8b9a77
commit 059d8e9cd9
59 changed files with 5463 additions and 1691 deletions

View File

@@ -1,5 +1,5 @@
"""
adopted from
partially adopted from
https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
and
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
@@ -10,10 +10,11 @@ thanks!
"""
import math
from typing import Optional
import torch
import torch.nn as nn
from einops import repeat
from einops import rearrange, repeat
def make_beta_schedule(
@@ -306,3 +307,63 @@ def avg_pool_nd(dims, *args, **kwargs):
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
class AlphaBlender(nn.Module):
strategies = ["learned", "fixed", "learned_with_images"]
def __init__(
self,
alpha: float,
merge_strategy: str = "learned_with_images",
rearrange_pattern: str = "b t -> (b t) 1 1",
):
super().__init__()
self.merge_strategy = merge_strategy
self.rearrange_pattern = rearrange_pattern
assert (
merge_strategy in self.strategies
), f"merge_strategy needs to be in {self.strategies}"
if self.merge_strategy == "fixed":
self.register_buffer("mix_factor", torch.Tensor([alpha]))
elif (
self.merge_strategy == "learned"
or self.merge_strategy == "learned_with_images"
):
self.register_parameter(
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
)
else:
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:
if self.merge_strategy == "fixed":
alpha = self.mix_factor
elif self.merge_strategy == "learned":
alpha = torch.sigmoid(self.mix_factor)
elif self.merge_strategy == "learned_with_images":
assert image_only_indicator is not None, "need image_only_indicator ..."
alpha = torch.where(
image_only_indicator.bool(),
torch.ones(1, 1, device=image_only_indicator.device),
rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"),
)
alpha = rearrange(alpha, self.rearrange_pattern)
else:
raise NotImplementedError
return alpha
def forward(
self,
x_spatial: torch.Tensor,
x_temporal: torch.Tensor,
image_only_indicator: Optional[torch.Tensor] = None,
) -> torch.Tensor:
alpha = self.get_alpha(image_only_indicator)
x = (
alpha.to(x_spatial.dtype) * x_spatial
+ (1.0 - alpha).to(x_spatial.dtype) * x_temporal
)
return x