mirror of
https://github.com/Stability-AI/generative-models.git
synced 2026-01-26 17:04:27 +01:00
Stable Video Diffusion
This commit is contained in:
@@ -1,62 +1,74 @@
|
||||
from typing import Dict, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...util import append_dims, instantiate_from_config
|
||||
from .denoiser_scaling import DenoiserScaling
|
||||
from .discretizer import Discretization
|
||||
|
||||
|
||||
class Denoiser(nn.Module):
    """Wrap a network with sigma-dependent input/output scaling.

    The scaling object is built from ``scaling_config`` and, when called with
    a noise level, yields the four coefficients ``(c_skip, c_out, c_in,
    c_noise)`` that ``forward`` combines with the network output.

    Subclasses can override ``possibly_quantize_sigma`` and
    ``possibly_quantize_c_noise`` to snap values to a discrete set; the base
    implementations return their argument unchanged.
    """

    def __init__(self, scaling_config: Dict):
        """Build the scaling object described by ``scaling_config``.

        Args:
            scaling_config: Config passed to ``instantiate_from_config``.
                NOTE(review): the result is annotated as ``DenoiserScaling``;
                the config schema is not visible here -- confirm against
                ``instantiate_from_config``.
        """
        super().__init__()

        self.scaling: DenoiserScaling = instantiate_from_config(scaling_config)

    def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor:
        """Return ``sigma`` unchanged (hook for subclasses)."""
        return sigma

    def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor:
        """Return ``c_noise`` unchanged (hook for subclasses)."""
        return c_noise

    # Implemented as ``forward`` rather than ``__call__`` so that invoking the
    # module still goes through ``nn.Module.__call__``.
    def forward(
        self,
        network: nn.Module,
        input: torch.Tensor,
        sigma: torch.Tensor,
        cond: Dict,
        **additional_model_inputs,
    ) -> torch.Tensor:
        """Run ``network`` on the scaled input and blend with a skip term.

        Args:
            network: Callable invoked as
                ``network(input * c_in, c_noise, cond, **additional_model_inputs)``.
            input: Tensor to denoise.
            sigma: Noise level(s) for ``input``.
            cond: Conditioning passed through to ``network`` untouched.
            **additional_model_inputs: Extra keyword arguments forwarded to
                ``network`` as-is.

        Returns:
            ``network(...) * c_out + input * c_skip``.
        """
        sigma = self.possibly_quantize_sigma(sigma)
        # Remember the caller's sigma shape; c_noise is restored to it below.
        sigma_shape = sigma.shape
        # Presumably pads sigma with trailing singleton dims up to input.ndim
        # so the coefficients broadcast against ``input`` -- confirm against
        # ``append_dims`` in the project's util module.
        sigma = append_dims(sigma, input.ndim)
        c_skip, c_out, c_in, c_noise = self.scaling(sigma)
        c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
        return (
            network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out
            + input * c_skip
        )
|
||||
|
||||
|
||||
class DiscreteDenoiser(Denoiser):
|
||||
def __init__(
|
||||
self,
|
||||
weighting_config,
|
||||
scaling_config,
|
||||
num_idx,
|
||||
discretization_config,
|
||||
do_append_zero=False,
|
||||
quantize_c_noise=True,
|
||||
flip=True,
|
||||
scaling_config: Dict,
|
||||
num_idx: int,
|
||||
discretization_config: Dict,
|
||||
do_append_zero: bool = False,
|
||||
quantize_c_noise: bool = True,
|
||||
flip: bool = True,
|
||||
):
|
||||
super().__init__(weighting_config, scaling_config)
|
||||
sigmas = instantiate_from_config(discretization_config)(
|
||||
num_idx, do_append_zero=do_append_zero, flip=flip
|
||||
super().__init__(scaling_config)
|
||||
self.discretization: Discretization = instantiate_from_config(
|
||||
discretization_config
|
||||
)
|
||||
sigmas = self.discretization(num_idx, do_append_zero=do_append_zero, flip=flip)
|
||||
self.register_buffer("sigmas", sigmas)
|
||||
self.quantize_c_noise = quantize_c_noise
|
||||
self.num_idx = num_idx
|
||||
|
||||
def sigma_to_idx(self, sigma):
|
||||
def sigma_to_idx(self, sigma: torch.Tensor) -> torch.Tensor:
|
||||
dists = sigma - self.sigmas[:, None]
|
||||
return dists.abs().argmin(dim=0).view(sigma.shape)
|
||||
|
||||
def idx_to_sigma(self, idx):
|
||||
def idx_to_sigma(self, idx: Union[torch.Tensor, int]) -> torch.Tensor:
|
||||
return self.sigmas[idx]
|
||||
|
||||
def possibly_quantize_sigma(self, sigma):
|
||||
def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor:
|
||||
return self.idx_to_sigma(self.sigma_to_idx(sigma))
|
||||
|
||||
def possibly_quantize_c_noise(self, c_noise):
|
||||
def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor:
|
||||
if self.quantize_c_noise:
|
||||
return self.sigma_to_idx(c_noise)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user