mirror of
https://github.com/Stability-AI/generative-models.git
synced 2025-12-20 06:44:22 +01:00
Stable Video Diffusion
This commit is contained in:
@@ -1,31 +1,33 @@
|
||||
from functools import partial
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from ...util import default, instantiate_from_config
|
||||
from ...util import append_dims, default
|
||||
|
||||
logpy = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VanillaCFG:
|
||||
"""
|
||||
implements parallelized CFG
|
||||
"""
|
||||
class Guider(ABC):
|
||||
@abstractmethod
|
||||
def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
|
||||
pass
|
||||
|
||||
def __init__(self, scale, dyn_thresh_config=None):
|
||||
scale_schedule = lambda scale, sigma: scale # independent of step
|
||||
self.scale_schedule = partial(scale_schedule, scale)
|
||||
self.dyn_thresh = instantiate_from_config(
|
||||
default(
|
||||
dyn_thresh_config,
|
||||
{
|
||||
"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
|
||||
},
|
||||
)
|
||||
)
|
||||
def prepare_inputs(
|
||||
self, x: torch.Tensor, s: float, c: Dict, uc: Dict
|
||||
) -> Tuple[torch.Tensor, float, Dict]:
|
||||
pass
|
||||
|
||||
def __call__(self, x, sigma):
|
||||
|
||||
class VanillaCFG(Guider):
|
||||
def __init__(self, scale: float):
|
||||
self.scale = scale
|
||||
|
||||
def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
|
||||
x_u, x_c = x.chunk(2)
|
||||
scale_value = self.scale_schedule(sigma)
|
||||
x_pred = self.dyn_thresh(x_u, x_c, scale_value)
|
||||
x_pred = x_u + self.scale * (x_c - x_u)
|
||||
return x_pred
|
||||
|
||||
def prepare_inputs(self, x, s, c, uc):
|
||||
@@ -40,14 +42,58 @@ class VanillaCFG:
|
||||
return torch.cat([x] * 2), torch.cat([s] * 2), c_out
|
||||
|
||||
|
||||
class IdentityGuider:
|
||||
def __call__(self, x, sigma):
|
||||
class IdentityGuider(Guider):
|
||||
def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
|
||||
return x
|
||||
|
||||
def prepare_inputs(self, x, s, c, uc):
|
||||
def prepare_inputs(
|
||||
self, x: torch.Tensor, s: float, c: Dict, uc: Dict
|
||||
) -> Tuple[torch.Tensor, float, Dict]:
|
||||
c_out = dict()
|
||||
|
||||
for k in c:
|
||||
c_out[k] = c[k]
|
||||
|
||||
return x, s, c_out
|
||||
|
||||
|
||||
class LinearPredictionGuider(Guider):
|
||||
def __init__(
|
||||
self,
|
||||
max_scale: float,
|
||||
num_frames: int,
|
||||
min_scale: float = 1.0,
|
||||
additional_cond_keys: Optional[Union[List[str], str]] = None,
|
||||
):
|
||||
self.min_scale = min_scale
|
||||
self.max_scale = max_scale
|
||||
self.num_frames = num_frames
|
||||
self.scale = torch.linspace(min_scale, max_scale, num_frames).unsqueeze(0)
|
||||
|
||||
additional_cond_keys = default(additional_cond_keys, [])
|
||||
if isinstance(additional_cond_keys, str):
|
||||
additional_cond_keys = [additional_cond_keys]
|
||||
self.additional_cond_keys = additional_cond_keys
|
||||
|
||||
def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
|
||||
x_u, x_c = x.chunk(2)
|
||||
|
||||
x_u = rearrange(x_u, "(b t) ... -> b t ...", t=self.num_frames)
|
||||
x_c = rearrange(x_c, "(b t) ... -> b t ...", t=self.num_frames)
|
||||
scale = repeat(self.scale, "1 t -> b t", b=x_u.shape[0])
|
||||
scale = append_dims(scale, x_u.ndim).to(x_u.device)
|
||||
|
||||
return rearrange(x_u + scale * (x_c - x_u), "b t ... -> (b t) ...")
|
||||
|
||||
def prepare_inputs(
|
||||
self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
|
||||
c_out = dict()
|
||||
|
||||
for k in c:
|
||||
if k in ["vector", "crossattn", "concat"] + self.additional_cond_keys:
|
||||
c_out[k] = torch.cat((uc[k], c[k]), 0)
|
||||
else:
|
||||
assert c[k] == uc[k]
|
||||
c_out[k] = c[k]
|
||||
return torch.cat([x] * 2), torch.cat([s] * 2), c_out
|
||||
|
||||
Reference in New Issue
Block a user