Mirror of https://github.com/Stability-AI/generative-models.git, synced 2025-12-20 06:44:22 +01:00
Stable Video Diffusion
sgm/modules/autoencoding/losses/lpips.py (new file, 73 lines)
@@ -0,0 +1,73 @@
import torch
import torch.nn as nn

from ....util import default, instantiate_from_config
from ..lpips.loss.lpips import LPIPS


class LatentLPIPS(nn.Module):
    """L2 loss in latent space, optionally combined with LPIPS perceptual
    losses computed on decoded images."""

    def __init__(
        self,
        decoder_config,
        perceptual_weight=1.0,
        latent_weight=1.0,
        scale_input_to_tgt_size=False,
        scale_tgt_to_input_size=False,
        perceptual_weight_on_inputs=0.0,
    ):
        super().__init__()
        self.scale_input_to_tgt_size = scale_input_to_tgt_size
        self.scale_tgt_to_input_size = scale_tgt_to_input_size
        self.init_decoder(decoder_config)
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight
        self.latent_weight = latent_weight
        self.perceptual_weight_on_inputs = perceptual_weight_on_inputs

    def init_decoder(self, config):
        self.decoder = instantiate_from_config(config)
        # Only decode() is needed here; drop the encoder to save memory.
        if hasattr(self.decoder, "encoder"):
            del self.decoder.encoder

    def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"):
        log = dict()
        # Element-wise L2 between target and predicted latents.
        loss = (latent_inputs - latent_predictions) ** 2
        log[f"{split}/latent_l2_loss"] = loss.mean().detach()
        image_reconstructions = None
        if self.perceptual_weight > 0.0:
            # LPIPS between the decodings of both latents.
            image_reconstructions = self.decoder.decode(latent_predictions)
            image_targets = self.decoder.decode(latent_inputs)
            perceptual_loss = self.perceptual_loss(
                image_targets.contiguous(), image_reconstructions.contiguous()
            )
            loss = (
                self.latent_weight * loss.mean()
                + self.perceptual_weight * perceptual_loss.mean()
            )
            log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()

        if self.perceptual_weight_on_inputs > 0.0:
            # Note: the fallback decode below is evaluated eagerly, even when
            # image_reconstructions was already computed above.
            image_reconstructions = default(
                image_reconstructions, self.decoder.decode(latent_predictions)
            )
            # Resize one side so both tensors share a spatial resolution.
            if self.scale_input_to_tgt_size:
                image_inputs = torch.nn.functional.interpolate(
                    image_inputs,
                    image_reconstructions.shape[2:],
                    mode="bicubic",
                    antialias=True,
                )
            elif self.scale_tgt_to_input_size:
                image_reconstructions = torch.nn.functional.interpolate(
                    image_reconstructions,
                    image_inputs.shape[2:],
                    mode="bicubic",
                    antialias=True,
                )

            # LPIPS between the original input images and the decoded predictions.
            perceptual_loss2 = self.perceptual_loss(
                image_inputs.contiguous(), image_reconstructions.contiguous()
            )
            loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
            log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
        return loss, log
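For orientation, here is a minimal usage sketch, not part of the commit. TinyDecoder is a hypothetical stand-in exposing only the decode() interface the loss expects, and the "__main__.TinyDecoder" target string is an assumption that resolves only when the sketch is run as a script; a real training config points instantiate_from_config at an actual autoencoder. Note also that LPIPS() loads pretrained VGG weights on first construction.

import torch
import torch.nn as nn

from sgm.modules.autoencoding.losses.lpips import LatentLPIPS


class TinyDecoder(nn.Module):
    # Toy stand-in: maps 4-channel 32x32 latents to 3-channel 256x256 images.
    def __init__(self):
        super().__init__()
        self.net = nn.Conv2d(4, 3, kernel_size=3, padding=1)

    def decode(self, z):
        return nn.functional.interpolate(self.net(z), scale_factor=8, mode="nearest")


if __name__ == "__main__":
    loss_fn = LatentLPIPS(
        decoder_config={"target": "__main__.TinyDecoder"},  # hypothetical target
        perceptual_weight=1.0,
        latent_weight=1.0,
    )
    latents_gt = torch.randn(2, 4, 32, 32)
    latents_pred = torch.randn(2, 4, 32, 32, requires_grad=True)
    images_gt = torch.randn(2, 3, 256, 256)  # used only if perceptual_weight_on_inputs > 0
    loss, log = loss_fn(latents_gt, latents_pred, images_gt, split="train")
    loss.backward()  # gradients flow to latents_pred through L2 and the decoder

With perceptual_weight_on_inputs left at 0.0, only the latent L2 and decoded-image LPIPS terms contribute, and the returned loss is a scalar; if both perceptual weights were 0.0, forward would instead return the unreduced element-wise L2 tensor.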
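Conceptually, LPIPS is a distance between unit-normalized deep features of two images, averaged over spatial positions and layers. The sketch below approximates that with raw torchvision VGG16 activations; it is a generic illustration, not the ..lpips.loss.lpips implementation this file imports, which additionally applies input normalization and learned per-channel linear weights on each layer.

import torch
import torchvision

# Shared feature extractor; indices 3/8/15/22/29 are VGG16's relu1_2 ... relu5_3
# activations, the layers LPIPS taps.
_VGG = torchvision.models.vgg16(weights="IMAGENET1K_V1").features.eval()


def perceptual_distance(x, y, layer_ids=(3, 8, 15, 22, 29)):
    # Rough LPIPS-like distance: unit-normalize features along channels,
    # take squared differences, average over space, sum over tapped layers.
    dist = x.new_zeros(())
    fx, fy = x, y
    for i, layer in enumerate(_VGG):
        fx, fy = layer(fx), layer(fy)
        if i in layer_ids:
            nx = fx / (fx.norm(dim=1, keepdim=True) + 1e-8)
            ny = fy / (fy.norm(dim=1, keepdim=True) + 1e-8)
            dist = dist + (nx - ny).pow(2).mean()
        if i >= max(layer_ids):
            break  # no need to run layers past the last tapped one
    return dist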