Mirror of https://github.com/Stability-AI/generative-models.git
Stable Video Diffusion
@@ -1,5 +1,6 @@
+import math
 from contextlib import contextmanager
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import pytorch_lightning as pl
 import torch
@@ -8,15 +9,11 @@ from safetensors.torch import load_file as load_safetensors
 from torch.optim.lr_scheduler import LambdaLR
 
 from ..modules import UNCONDITIONAL_CONFIG
+from ..modules.autoencoding.temporal_ae import VideoDecoder
 from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
 from ..modules.ema import LitEma
-from ..util import (
-    default,
-    disabled_train,
-    get_obj_from_str,
-    instantiate_from_config,
-    log_txt_as_img,
-)
+from ..util import (default, disabled_train, get_obj_from_str,
+                    instantiate_from_config, log_txt_as_img)
 
 
 class DiffusionEngine(pl.LightningModule):
@@ -40,6 +37,7 @@ class DiffusionEngine(pl.LightningModule):
         log_keys: Union[List, None] = None,
         no_cond_log: bool = False,
         compile_model: bool = False,
+        en_and_decode_n_samples_a_time: Optional[int] = None,
     ):
         super().__init__()
         self.log_keys = log_keys
@@ -82,6 +80,8 @@ class DiffusionEngine(pl.LightningModule):
         if ckpt_path is not None:
             self.init_from_ckpt(ckpt_path)
 
+        self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
+
     def init_from_ckpt(
         self,
         path: str,
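A note on these two hunks: the new constructor argument is stored on the engine and later read by decode_first_stage/encode_first_stage to cap how many samples pass through the first-stage autoencoder at once; leaving it at None keeps the old whole-batch behaviour. A minimal sketch of how the knob could be supplied through the config machinery this file already imports (instantiate_from_config); the target path and the value 2 are illustrative assumptions, not part of this diff:

# Hypothetical config in the dict form that instantiate_from_config-style
# helpers consume; every other DiffusionEngine parameter is elided.
model_config = {
    "target": "sgm.models.diffusion.DiffusionEngine",  # assumed module path
    "params": {
        # Push at most 2 samples through the first-stage VAE per forward
        # pass; None (the default) encodes/decodes the whole batch at once.
        "en_and_decode_n_samples_a_time": 2,
    },
}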
@@ -117,14 +117,35 @@ class DiffusionEngine(pl.LightningModule):
     @torch.no_grad()
     def decode_first_stage(self, z):
         z = 1.0 / self.scale_factor * z
+        n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
+
+        n_rounds = math.ceil(z.shape[0] / n_samples)
+        all_out = []
         with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
-            out = self.first_stage_model.decode(z)
+            for n in range(n_rounds):
+                if isinstance(self.first_stage_model.decoder, VideoDecoder):
+                    kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
+                else:
+                    kwargs = {}
+                out = self.first_stage_model.decode(
+                    z[n * n_samples : (n + 1) * n_samples], **kwargs
+                )
+                all_out.append(out)
+        out = torch.cat(all_out, dim=0)
         return out
 
     @torch.no_grad()
     def encode_first_stage(self, x):
+        n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
+        n_rounds = math.ceil(x.shape[0] / n_samples)
+        all_out = []
         with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
-            z = self.first_stage_model.encode(x)
+            for n in range(n_rounds):
+                out = self.first_stage_model.encode(
+                    x[n * n_samples : (n + 1) * n_samples]
+                )
+                all_out.append(out)
+        z = torch.cat(all_out, dim=0)
         z = self.scale_factor * z
         return z
 
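The substantive change in this final hunk is micro-batching: rather than a single decode/encode call over the whole batch, the tensors are sliced into n_rounds = ceil(batch_size / n_samples) chunks and the per-chunk outputs concatenated, bounding peak autoencoder memory; when the decoder is a VideoDecoder, timesteps is set to the chunk length so the temporal decoder knows how many frames each slice holds. A self-contained sketch of the same slicing pattern, with the first-stage model stubbed out (FakeVAE and all shapes here are invented for illustration):

import math

import torch


class FakeVAE:
    # Stand-in for the first-stage model: "decoding" just doubles the input.
    def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
        return 2.0 * z


def decode_in_chunks(model, z, n_samples_a_time=None):
    # Mirror of the loop above: fall back to the full batch size, then run
    # ceil(B / n) rounds over contiguous slices and concatenate the outputs.
    n_samples = n_samples_a_time if n_samples_a_time is not None else z.shape[0]
    n_rounds = math.ceil(z.shape[0] / n_samples)
    all_out = []
    with torch.no_grad():
        for n in range(n_rounds):
            chunk = z[n * n_samples : (n + 1) * n_samples]
            # The real code adds {"timesteps": len(chunk)} when the decoder
            # is a VideoDecoder; a plain image VAE gets no extra kwargs.
            all_out.append(model.decode(chunk))
    return torch.cat(all_out, dim=0)


z = torch.randn(10, 4, 8, 8)
out = decode_in_chunks(FakeVAE(), z, n_samples_a_time=4)  # 3 rounds: 4 + 4 + 2
assert out.shape == z.shape and torch.allclose(out, 2.0 * z)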