Mirror of https://github.com/Stability-AI/generative-models.git (synced 2025-12-20 06:44:22 +01:00)

Commit: Stable Video Diffusion
@@ -1,3 +1,4 @@
+import math
 from contextlib import nullcontext
 from functools import partial
 from typing import Dict, List, Optional, Tuple, Union
@@ -10,27 +11,17 @@ import torch.nn as nn
 from einops import rearrange, repeat
 from omegaconf import ListConfig
 from torch.utils.checkpoint import checkpoint
-from transformers import (
-    ByT5Tokenizer,
-    CLIPTextModel,
-    CLIPTokenizer,
-    T5EncoderModel,
-    T5Tokenizer,
-)
+from transformers import (ByT5Tokenizer, CLIPTextModel, CLIPTokenizer,
+                          T5EncoderModel, T5Tokenizer)

 from ...modules.autoencoding.regularizers import DiagonalGaussianRegularizer
 from ...modules.diffusionmodules.model import Encoder
 from ...modules.diffusionmodules.openaimodel import Timestep
-from ...modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
+from ...modules.diffusionmodules.util import (extract_into_tensor,
+                                              make_beta_schedule)
 from ...modules.distributions.distributions import DiagonalGaussianDistribution
-from ...util import (
-    autocast,
-    count_params,
-    default,
-    disabled_train,
-    expand_dims_like,
-    instantiate_from_config,
-)
+from ...util import (append_dims, autocast, count_params, default,
+                     disabled_train, expand_dims_like, instantiate_from_config)


 class AbstractEmbModel(nn.Module):
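Note on the hunk above: besides the import regrouping, `append_dims` is newly imported from `...util`; the `VideoPredictionEmbedderWithEncoder` added at the end of this diff uses it to broadcast one sigma per sample across a (b, c, h, w) video tensor. A minimal sketch of the helper's usual semantics; the body below is an assumption based on the common k-diffusion-style definition, not code from this commit:

    # Assumed definition (k-diffusion style), not taken from this commit:
    # append trailing singleton dims until x has target_dims dimensions.
    import torch

    def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
        dims_to_append = target_dims - x.ndim
        if dims_to_append < 0:
            raise ValueError(f"cannot reduce {x.ndim} dims to {target_dims}")
        return x[(...,) + (None,) * dims_to_append]

    sigmas = torch.rand(4)           # one noise level per sample, shape (4,)
    vid = torch.randn(4, 3, 64, 64)  # (b*t, c, h, w)
    assert append_dims(sigmas, vid.ndim).shape == (4, 1, 1, 1)
    noised = vid + torch.randn_like(vid) * append_dims(sigmas, vid.ndim)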
@@ -173,7 +164,11 @@ class GeneralConditioner(nn.Module):
         return output

     def get_unconditional_conditioning(
-        self, batch_c, batch_uc=None, force_uc_zero_embeddings=None
+        self,
+        batch_c: Dict,
+        batch_uc: Optional[Dict] = None,
+        force_uc_zero_embeddings: Optional[List[str]] = None,
+        force_cond_zero_embeddings: Optional[List[str]] = None,
     ):
         if force_uc_zero_embeddings is None:
             force_uc_zero_embeddings = []
@@ -181,7 +176,7 @@ class GeneralConditioner(nn.Module):
         for embedder in self.embedders:
             ucg_rates.append(embedder.ucg_rate)
             embedder.ucg_rate = 0.0
-        c = self(batch_c)
+        c = self(batch_c, force_cond_zero_embeddings)
         uc = self(batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings)

         for embedder, rate in zip(self.embedders, ucg_rates):
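The two hunks above extend `get_unconditional_conditioning`: the signature gains type hints plus a `force_cond_zero_embeddings` argument, and the conditional pass now forwards it, so selected embedders can be zeroed on the conditional branch as well as the unconditional one used for classifier-free guidance. A self-contained toy version of the pattern; the real `GeneralConditioner` assembles embedder outputs per output key, and the names below are illustrative only:

    # Toy illustration of the c/uc pattern, not the real GeneralConditioner.
    from typing import Dict, List, Optional

    import torch
    import torch.nn as nn

    class ToyConditioner(nn.Module):
        def forward(self, batch: Dict, force_zero_embeddings: Optional[List[str]] = None) -> Dict:
            force_zero_embeddings = force_zero_embeddings or []
            # Zero out the embeddings named in force_zero_embeddings.
            return {
                k: torch.zeros_like(v) if k in force_zero_embeddings else v
                for k, v in batch.items()
            }

        def get_unconditional_conditioning(
            self,
            batch_c: Dict,
            batch_uc: Optional[Dict] = None,
            force_uc_zero_embeddings: Optional[List[str]] = None,
            force_cond_zero_embeddings: Optional[List[str]] = None,
        ):
            c = self(batch_c, force_cond_zero_embeddings)
            uc = self(batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings)
            return c, uc

    cond = ToyConditioner()
    batch = {"crossattn": torch.randn(2, 77, 1024)}
    c, uc = cond.get_unconditional_conditioning(batch, force_uc_zero_embeddings=["crossattn"])
    assert uc["crossattn"].abs().sum() == 0 and c["crossattn"].equal(batch["crossattn"])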
@@ -201,12 +196,6 @@ class InceptionV3(nn.Module):
         self.model = inception.InceptionV3(normalize_input=normalize_input, **kwargs)

     def forward(self, inp):
-        # inp = kornia.geometry.resize(inp, (299, 299),
-        #                              interpolation='bicubic',
-        #                              align_corners=False,
-        #                              antialias=True)
-        # inp = inp.clamp(min=-1, max=1)
-
         outp = self.model(inp)

         if len(outp) == 1:
@@ -277,7 +266,6 @@ class FrozenT5Embedder(AbstractEmbModel):
         for param in self.parameters():
             param.requires_grad = False

-    # @autocast
     def forward(self, text):
         batch_encoding = self.tokenizer(
             text,
@@ -597,11 +585,12 @@ class FrozenOpenCLIPImageEmbedder(AbstractEmbModel):
         repeat_to_max_len=False,
         num_image_crops=0,
         output_tokens=False,
+        init_device=None,
     ):
         super().__init__()
         model, _, _ = open_clip.create_model_and_transforms(
             arch,
-            device=torch.device("cpu"),
+            device=torch.device(default(init_device, "cpu")),
             pretrained=version,
         )
         del model.transformer
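The hunk above adds an `init_device` argument so the OpenCLIP weights can be created directly on a chosen device instead of always on CPU; passing `None` keeps the old behavior because the value is routed through `default`. A sketch of that helper's usual semantics (the body is an assumption based on the common `sgm.util` definition, not shown in this diff):

    # Assumed semantics of sgm.util.default: return val unless it is None,
    # otherwise fall back to d (calling it first if it is callable).
    from typing import Any, Callable, Optional, Union

    def default(val: Optional[Any], d: Union[Any, Callable[[], Any]]) -> Any:
        if val is not None:
            return val
        return d() if callable(d) else d

    assert default(None, "cpu") == "cpu"     # old behavior preserved
    assert default("cuda", "cpu") == "cuda"  # weights created on GPU directly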
@@ -914,7 +903,6 @@ class LowScaleEncoder(nn.Module):
         z = self.q_sample(z, noise_level)
         if self.out_size is not None:
             z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest")
-        # z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
         return z, noise_level

     def decode(self, z):
@@ -958,3 +946,101 @@ class GaussianEncoder(Encoder, AbstractEmbModel):
         if self.flatten_output:
             z = rearrange(z, "b c h w -> b (h w ) c")
         return log, z
+
+
+class VideoPredictionEmbedderWithEncoder(AbstractEmbModel):
+    def __init__(
+        self,
+        n_cond_frames: int,
+        n_copies: int,
+        encoder_config: dict,
+        sigma_sampler_config: Optional[dict] = None,
+        sigma_cond_config: Optional[dict] = None,
+        is_ae: bool = False,
+        scale_factor: float = 1.0,
+        disable_encoder_autocast: bool = False,
+        en_and_decode_n_samples_a_time: Optional[int] = None,
+    ):
+        super().__init__()
+
+        self.n_cond_frames = n_cond_frames
+        self.n_copies = n_copies
+        self.encoder = instantiate_from_config(encoder_config)
+        self.sigma_sampler = (
+            instantiate_from_config(sigma_sampler_config)
+            if sigma_sampler_config is not None
+            else None
+        )
+        self.sigma_cond = (
+            instantiate_from_config(sigma_cond_config)
+            if sigma_cond_config is not None
+            else None
+        )
+        self.is_ae = is_ae
+        self.scale_factor = scale_factor
+        self.disable_encoder_autocast = disable_encoder_autocast
+        self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
+
+    def forward(
+        self, vid: torch.Tensor
+    ) -> Union[
+        torch.Tensor,
+        Tuple[torch.Tensor, torch.Tensor],
+        Tuple[torch.Tensor, dict],
+        Tuple[Tuple[torch.Tensor, torch.Tensor], dict],
+    ]:
+        if self.sigma_sampler is not None:
+            b = vid.shape[0] // self.n_cond_frames
+            sigmas = self.sigma_sampler(b).to(vid.device)
+            if self.sigma_cond is not None:
+                sigma_cond = self.sigma_cond(sigmas)
+                sigma_cond = repeat(sigma_cond, "b d -> (b t) d", t=self.n_copies)
+            sigmas = repeat(sigmas, "b -> (b t)", t=self.n_cond_frames)
+            noise = torch.randn_like(vid)
+            vid = vid + noise * append_dims(sigmas, vid.ndim)
+
+        with torch.autocast("cuda", enabled=not self.disable_encoder_autocast):
+            n_samples = (
+                self.en_and_decode_n_samples_a_time
+                if self.en_and_decode_n_samples_a_time is not None
+                else vid.shape[0]
+            )
+            n_rounds = math.ceil(vid.shape[0] / n_samples)
+            all_out = []
+            for n in range(n_rounds):
+                if self.is_ae:
+                    out = self.encoder.encode(vid[n * n_samples : (n + 1) * n_samples])
+                else:
+                    out = self.encoder(vid[n * n_samples : (n + 1) * n_samples])
+                all_out.append(out)
+
+        vid = torch.cat(all_out, dim=0)
+        vid *= self.scale_factor
+
+        vid = rearrange(vid, "(b t) c h w -> b () (t c) h w", t=self.n_cond_frames)
+        vid = repeat(vid, "b 1 c h w -> (b t) c h w", t=self.n_copies)
+
+        return_val = (vid, sigma_cond) if self.sigma_cond is not None else vid
+
+        return return_val
+
+
+class FrozenOpenCLIPImagePredictionEmbedder(AbstractEmbModel):
+    def __init__(
+        self,
+        open_clip_embedding_config: Dict,
+        n_cond_frames: int,
+        n_copies: int,
+    ):
+        super().__init__()
+
+        self.n_cond_frames = n_cond_frames
+        self.n_copies = n_copies
+        self.open_clip = instantiate_from_config(open_clip_embedding_config)
+
+    def forward(self, vid):
+        vid = self.open_clip(vid)
+        vid = rearrange(vid, "(b t) d -> b t d", t=self.n_cond_frames)
+        vid = repeat(vid, "b t d -> (b s) t d", s=self.n_copies)
+
+        return vid
+
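For orientation, a shape walkthrough of the two embedders added above, runnable on its own. The concrete sizes are illustrative assumptions (1 conditioning frame, 14 output copies, a 4-channel latent, a 1024-dim CLIP embedding); only the rearrange/repeat patterns are taken from the diff:

    # Shape walkthrough with assumed sizes; the patterns mirror the new classes.
    import torch
    from einops import rearrange, repeat

    b, t, n_copies = 2, 1, 14                 # 2 videos, 1 cond frame, 14 frames out
    c_z, h, w = 4, 72, 128                    # assumed latent shape after the encoder

    # VideoPredictionEmbedderWithEncoder, after encoding:
    z = torch.randn(b * t, c_z, h, w)         # (b*t, c, h, w)
    z = rearrange(z, "(b t) c h w -> b () (t c) h w", t=t)  # (2, 1, 4, 72, 128)
    z = repeat(z, "b 1 c h w -> (b t) c h w", t=n_copies)   # (28, 4, 72, 128)
    assert z.shape == (b * n_copies, t * c_z, h, w)

    # FrozenOpenCLIPImagePredictionEmbedder, after the CLIP image tower:
    d = 1024                                  # assumed OpenCLIP embedding width
    e = torch.randn(b * t, d)                 # per-frame image embeddings
    e = rearrange(e, "(b t) d -> b t d", t=t)               # (2, 1, 1024)
    e = repeat(e, "b t d -> (b s) t d", s=n_copies)         # (28, 1, 1024)
    assert e.shape == (b * n_copies, t, d)

In both cases the `n_copies` repeat replicates the conditioning derived from the input frame once per generated frame, which matches an image-to-video setup where one conditioning frame is broadcast to every target frame.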