mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
add p2 loss reweighting for decoder training as an option
This commit is contained in:
@@ -11,7 +11,7 @@ import torch.nn.functional as F
|
||||
from torch import nn, einsum
|
||||
import torchvision.transforms as T
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops import rearrange, repeat, reduce
|
||||
from einops.layers.torch import Rearrange
|
||||
from einops_exts import rearrange_many, repeat_many, check_shape
|
||||
from einops_exts.torch import EinopsToAndFrom
|
||||
@@ -379,7 +379,7 @@ def sigmoid_beta_schedule(timesteps):
|
||||
|
||||
|
||||
class BaseGaussianDiffusion(nn.Module):
|
||||
def __init__(self, *, beta_schedule, timesteps, loss_type):
|
||||
def __init__(self, *, beta_schedule, timesteps, loss_type, p2_loss_weight_gamma = 0., p2_loss_weight_k = 1):
|
||||
super().__init__()
|
||||
|
||||
if beta_schedule == "cosine":
|
||||
@@ -444,6 +444,11 @@ class BaseGaussianDiffusion(nn.Module):
|
||||
register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
||||
register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
|
||||
|
||||
# p2 loss reweighting
|
||||
|
||||
self.has_p2_loss_reweighting = p2_loss_weight_gamma > 0.
|
||||
register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma)
|
||||
|
||||
def q_posterior(self, x_start, x_t, t):
|
||||
posterior_mean = (
|
||||
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
|
||||
@@ -1755,12 +1760,16 @@ class Decoder(BaseGaussianDiffusion):
|
||||
unconditional = False,
|
||||
auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
|
||||
use_dynamic_thres = False, # from the Imagen paper
|
||||
dynamic_thres_percentile = 0.9
|
||||
dynamic_thres_percentile = 0.9,
|
||||
p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
|
||||
p2_loss_weight_k = 1
|
||||
):
|
||||
super().__init__(
|
||||
beta_schedule = beta_schedule,
|
||||
timesteps = timesteps,
|
||||
loss_type = loss_type
|
||||
loss_type = loss_type,
|
||||
p2_loss_weight_gamma = p2_loss_weight_gamma,
|
||||
p2_loss_weight_k = p2_loss_weight_k
|
||||
)
|
||||
|
||||
self.unconditional = unconditional
|
||||
@@ -2028,7 +2037,13 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
target = noise if not predict_x_start else x_start
|
||||
|
||||
loss = self.loss_fn(pred, target)
|
||||
loss = self.loss_fn(pred, target, reduction = 'none')
|
||||
loss = reduce(loss, 'b ... -> b (...)', 'mean')
|
||||
|
||||
if self.has_p2_loss_reweighting:
|
||||
loss = loss * extract(self.p2_loss_weight, times, loss.shape)
|
||||
|
||||
loss = loss.mean()
|
||||
|
||||
if not learned_variance:
|
||||
# return simple loss if not using learned variance
|
||||
|
||||
Reference in New Issue
Block a user