mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2025-12-19 17:54:20 +01:00)
add p2 loss reweighting for decoder training as an option
README.md | 10 ++++++++++
@@ -1207,4 +1207,14 @@ This library would not have gotten to this working state without the help of
 }
 ```
 
+```bibtex
+@article{Choi2022PerceptionPT,
+    title   = {Perception Prioritized Training of Diffusion Models},
+    author  = {Jooyoung Choi and Jungbeom Lee and Chaehun Shin and Sungwon Kim and Hyunwoo J. Kim and Sung-Hoon Yoon},
+    journal = {ArXiv},
+    year    = {2022},
+    volume  = {abs/2204.00227}
+}
+```
+
 *Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
dalle2_pytorch/dalle2_pytorch.py
@@ -11,7 +11,7 @@ import torch.nn.functional as F
 from torch import nn, einsum
 import torchvision.transforms as T
 
-from einops import rearrange, repeat
+from einops import rearrange, repeat, reduce
 from einops.layers.torch import Rearrange
 from einops_exts import rearrange_many, repeat_many, check_shape
 from einops_exts.torch import EinopsToAndFrom
@@ -379,7 +379,7 @@ def sigmoid_beta_schedule(timesteps):
 
 
 class BaseGaussianDiffusion(nn.Module):
-    def __init__(self, *, beta_schedule, timesteps, loss_type):
+    def __init__(self, *, beta_schedule, timesteps, loss_type, p2_loss_weight_gamma = 0., p2_loss_weight_k = 1):
         super().__init__()
 
         if beta_schedule == "cosine":
@@ -444,6 +444,11 @@ class BaseGaussianDiffusion(nn.Module):
         register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
         register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
 
+        # p2 loss reweighting
+
+        self.has_p2_loss_reweighting = p2_loss_weight_gamma > 0.
+        register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma)
+
     def q_posterior(self, x_start, x_t, t):
         posterior_mean = (
             extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
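For context, the new buffer implements the weighting from Choi et al. (https://arxiv.org/abs/2204.00227): each timestep's loss is scaled by (k + SNR_t) ** -gamma, where SNR_t = alpha_bar_t / (1 - alpha_bar_t). Below is a minimal standalone sketch of that computation; the cosine-style `alphas_cumprod` is only an illustrative stand-in for whatever beta schedule the class was actually built with.

```python
import math
import torch

# illustrative alphas_cumprod from a cosine-style schedule (stand-in for the class's real schedule)
timesteps = 1000
t = torch.linspace(0, 1, timesteps)
alphas_cumprod = torch.cos((t + 0.008) / 1.008 * math.pi * 0.5) ** 2

# p2 weighting from Choi et al.: (k + SNR_t) ** -gamma, with SNR_t = alpha_bar_t / (1 - alpha_bar_t)
p2_loss_weight_k = 1
p2_loss_weight_gamma = 1.   # 0. recovers the usual uniform weighting across time
snr = alphas_cumprod / (1 - alphas_cumprod)
p2_loss_weight = (p2_loss_weight_k + snr) ** -p2_loss_weight_gamma

# low-noise (early) timesteps are down-weighted, high-noise (late) ones approach a weight of 1 / k
print(p2_loss_weight[:3], p2_loss_weight[-3:])
```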
@@ -1755,12 +1760,16 @@ class Decoder(BaseGaussianDiffusion):
         unconditional = False,
         auto_normalize_img = True,     # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
         use_dynamic_thres = False,     # from the Imagen paper
-        dynamic_thres_percentile = 0.9
+        dynamic_thres_percentile = 0.9,
+        p2_loss_weight_gamma = 0.,     # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
+        p2_loss_weight_k = 1
     ):
         super().__init__(
             beta_schedule = beta_schedule,
             timesteps = timesteps,
-            loss_type = loss_type
+            loss_type = loss_type,
+            p2_loss_weight_gamma = p2_loss_weight_gamma,
+            p2_loss_weight_k = p2_loss_weight_k
         )
 
         self.unconditional = unconditional
@@ -2028,7 +2037,13 @@ class Decoder(BaseGaussianDiffusion):
 
         target = noise if not predict_x_start else x_start
 
-        loss = self.loss_fn(pred, target)
+        loss = self.loss_fn(pred, target, reduction = 'none')
+        loss = reduce(loss, 'b ... -> b (...)', 'mean')
+
+        if self.has_p2_loss_reweighting:
+            loss = loss * extract(self.p2_loss_weight, times, loss.shape)
+
+        loss = loss.mean()
 
         if not learned_variance:
             # return simple loss if not using learned variance
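In isolation, the reweighted loss amounts to: compute an unreduced element-wise loss, average it down to one value per sample, scale each sample by the weight of its sampled timestep, then take the batch mean. A rough standalone sketch under the assumption of an MSE loss (the repo's own code goes through its configurable `loss_fn` and its `extract` helper rather than the `gather` used here):

```python
import torch
import torch.nn.functional as F
from einops import reduce

def p2_reweighted_mse(pred, target, times, p2_loss_weight):
    # element-wise loss, kept unreduced so each sample can receive its own weight
    loss = F.mse_loss(pred, target, reduction = 'none')
    # collapse every non-batch dimension to a single per-sample loss
    loss = reduce(loss, 'b ... -> b', 'mean')
    # look up the weight for each sample's timestep and apply it
    loss = loss * p2_loss_weight.gather(-1, times)
    return loss.mean()

# toy usage
pred, target = torch.randn(4, 3, 8, 8), torch.randn(4, 3, 8, 8)
times = torch.randint(0, 1000, (4,))
p2_loss_weight = torch.rand(1000)   # would normally come from the registered buffer
print(p2_reweighted_mse(pred, target, times, p2_loss_weight))
```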
dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.7.0'
+__version__ = '0.7.1'