Stable Video Diffusion

This commit is contained in:
Tim Dockhorn
2023-11-21 10:40:21 -08:00
parent 477d8b9a77
commit 059d8e9cd9
59 changed files with 5463 additions and 1691 deletions

View File

@@ -1,5 +1,6 @@
import math
from contextlib import contextmanager
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import pytorch_lightning as pl
import torch
@@ -8,15 +9,11 @@ from safetensors.torch import load_file as load_safetensors
from torch.optim.lr_scheduler import LambdaLR
from ..modules import UNCONDITIONAL_CONFIG
from ..modules.autoencoding.temporal_ae import VideoDecoder
from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
from ..modules.ema import LitEma
from ..util import (
default,
disabled_train,
get_obj_from_str,
instantiate_from_config,
log_txt_as_img,
)
from ..util import (default, disabled_train, get_obj_from_str,
instantiate_from_config, log_txt_as_img)
class DiffusionEngine(pl.LightningModule):
@@ -40,6 +37,7 @@ class DiffusionEngine(pl.LightningModule):
log_keys: Union[List, None] = None,
no_cond_log: bool = False,
compile_model: bool = False,
en_and_decode_n_samples_a_time: Optional[int] = None,
):
super().__init__()
self.log_keys = log_keys
@@ -82,6 +80,8 @@ class DiffusionEngine(pl.LightningModule):
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path)
self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
def init_from_ckpt(
self,
path: str,
@@ -117,14 +117,35 @@ class DiffusionEngine(pl.LightningModule):
@torch.no_grad()
def decode_first_stage(self, z):
z = 1.0 / self.scale_factor * z
n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
n_rounds = math.ceil(z.shape[0] / n_samples)
all_out = []
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
out = self.first_stage_model.decode(z)
for n in range(n_rounds):
if isinstance(self.first_stage_model.decoder, VideoDecoder):
kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
else:
kwargs = {}
out = self.first_stage_model.decode(
z[n * n_samples : (n + 1) * n_samples], **kwargs
)
all_out.append(out)
out = torch.cat(all_out, dim=0)
return out
@torch.no_grad()
def encode_first_stage(self, x):
n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
n_rounds = math.ceil(x.shape[0] / n_samples)
all_out = []
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
z = self.first_stage_model.encode(x)
for n in range(n_rounds):
out = self.first_stage_model.encode(
x[n * n_samples : (n + 1) * n_samples]
)
all_out.append(out)
z = torch.cat(all_out, dim=0)
z = self.scale_factor * z
return z