mirror of
https://github.com/Stability-AI/generative-models.git
synced 2025-12-20 06:44:22 +01:00
Stable Video Diffusion
This commit is contained in:
@@ -1,31 +1,34 @@
|
||||
from typing import List, Optional, Union
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from omegaconf import ListConfig
|
||||
|
||||
from ...util import append_dims, instantiate_from_config
|
||||
from ...modules.autoencoding.lpips.loss.lpips import LPIPS
|
||||
from ...modules.encoders.modules import GeneralConditioner
|
||||
from ...util import append_dims, instantiate_from_config
|
||||
from .denoiser import Denoiser
|
||||
|
||||
|
||||
class StandardDiffusionLoss(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
sigma_sampler_config,
|
||||
type="l2",
|
||||
offset_noise_level=0.0,
|
||||
batch2model_keys: Optional[Union[str, List[str], ListConfig]] = None,
|
||||
sigma_sampler_config: dict,
|
||||
loss_weighting_config: dict,
|
||||
loss_type: str = "l2",
|
||||
offset_noise_level: float = 0.0,
|
||||
batch2model_keys: Optional[Union[str, List[str]]] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert type in ["l2", "l1", "lpips"]
|
||||
assert loss_type in ["l2", "l1", "lpips"]
|
||||
|
||||
self.sigma_sampler = instantiate_from_config(sigma_sampler_config)
|
||||
self.loss_weighting = instantiate_from_config(loss_weighting_config)
|
||||
|
||||
self.type = type
|
||||
self.loss_type = loss_type
|
||||
self.offset_noise_level = offset_noise_level
|
||||
|
||||
if type == "lpips":
|
||||
if loss_type == "lpips":
|
||||
self.lpips = LPIPS().eval()
|
||||
|
||||
if not batch2model_keys:
|
||||
@@ -36,34 +39,67 @@ class StandardDiffusionLoss(nn.Module):
|
||||
|
||||
self.batch2model_keys = set(batch2model_keys)
|
||||
|
||||
def __call__(self, network, denoiser, conditioner, input, batch):
|
||||
def get_noised_input(
|
||||
self, sigmas_bc: torch.Tensor, noise: torch.Tensor, input: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
noised_input = input + noise * sigmas_bc
|
||||
return noised_input
|
||||
|
||||
def forward(
|
||||
self,
|
||||
network: nn.Module,
|
||||
denoiser: Denoiser,
|
||||
conditioner: GeneralConditioner,
|
||||
input: torch.Tensor,
|
||||
batch: Dict,
|
||||
) -> torch.Tensor:
|
||||
cond = conditioner(batch)
|
||||
return self._forward(network, denoiser, cond, input, batch)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
network: nn.Module,
|
||||
denoiser: Denoiser,
|
||||
cond: Dict,
|
||||
input: torch.Tensor,
|
||||
batch: Dict,
|
||||
) -> Tuple[torch.Tensor, Dict]:
|
||||
additional_model_inputs = {
|
||||
key: batch[key] for key in self.batch2model_keys.intersection(batch)
|
||||
}
|
||||
sigmas = self.sigma_sampler(input.shape[0]).to(input)
|
||||
|
||||
sigmas = self.sigma_sampler(input.shape[0]).to(input.device)
|
||||
noise = torch.randn_like(input)
|
||||
if self.offset_noise_level > 0.0:
|
||||
noise = noise + self.offset_noise_level * append_dims(
|
||||
torch.randn(input.shape[0], device=input.device), input.ndim
|
||||
offset_shape = (
|
||||
(input.shape[0], 1, input.shape[2])
|
||||
if self.n_frames is not None
|
||||
else (input.shape[0], input.shape[1])
|
||||
)
|
||||
noised_input = input + noise * append_dims(sigmas, input.ndim)
|
||||
noise = noise + self.offset_noise_level * append_dims(
|
||||
torch.randn(offset_shape, device=input.device),
|
||||
input.ndim,
|
||||
)
|
||||
sigmas_bc = append_dims(sigmas, input.ndim)
|
||||
noised_input = self.get_noised_input(sigmas_bc, noise, input)
|
||||
|
||||
model_output = denoiser(
|
||||
network, noised_input, sigmas, cond, **additional_model_inputs
|
||||
)
|
||||
w = append_dims(denoiser.w(sigmas), input.ndim)
|
||||
w = append_dims(self.loss_weighting(sigmas), input.ndim)
|
||||
return self.get_loss(model_output, input, w)
|
||||
|
||||
def get_loss(self, model_output, target, w):
|
||||
if self.type == "l2":
|
||||
if self.loss_type == "l2":
|
||||
return torch.mean(
|
||||
(w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1
|
||||
)
|
||||
elif self.type == "l1":
|
||||
elif self.loss_type == "l1":
|
||||
return torch.mean(
|
||||
(w * (model_output - target).abs()).reshape(target.shape[0], -1), 1
|
||||
)
|
||||
elif self.type == "lpips":
|
||||
elif self.loss_type == "lpips":
|
||||
loss = self.lpips(model_output, target).reshape(-1)
|
||||
return loss
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown loss type {self.loss_type}")
|
||||
|
||||
Reference in New Issue
Block a user