From 9eea9b9862910a22627d50e6174d4b3cab4dd172 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Tue, 14 Jun 2022 10:58:57 -0700
Subject: [PATCH] add p2 loss reweighting for decoder training as an option

---
 README.md                        | 10 ++++++++++
 dalle2_pytorch/dalle2_pytorch.py | 25 ++++++++++++++++++++-----
 dalle2_pytorch/version.py        |  2 +-
 3 files changed, 31 insertions(+), 6 deletions(-)

diff --git a/README.md b/README.md
index 043c29a..0a6f442 100644
--- a/README.md
+++ b/README.md
@@ -1207,4 +1207,14 @@ This library would not have gotten to this working state without the help of
 }
 ```
 
+```bibtex
+@article{Choi2022PerceptionPT,
+    title = {Perception Prioritized Training of Diffusion Models},
+    author = {Jooyoung Choi and Jungbeom Lee and Chaehun Shin and Sungwon Kim and Hyunwoo J. Kim and Sung-Hoon Yoon},
+    journal = {ArXiv},
+    year = {2022},
+    volume = {abs/2204.00227}
+}
+```
+
 *Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's paper

diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index bea84d7..c5ce08a 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -11,7 +11,7 @@ import torch.nn.functional as F
 from torch import nn, einsum
 import torchvision.transforms as T
 
-from einops import rearrange, repeat
+from einops import rearrange, repeat, reduce
 from einops.layers.torch import Rearrange
 from einops_exts import rearrange_many, repeat_many, check_shape
 from einops_exts.torch import EinopsToAndFrom
@@ -379,7 +379,7 @@ def sigmoid_beta_schedule(timesteps):
 
 class BaseGaussianDiffusion(nn.Module):
-    def __init__(self, *, beta_schedule, timesteps, loss_type):
+    def __init__(self, *, beta_schedule, timesteps, loss_type, p2_loss_weight_gamma = 0., p2_loss_weight_k = 1):
         super().__init__()
 
         if beta_schedule == "cosine":
@@ -444,6 +444,11 @@ class BaseGaussianDiffusion(nn.Module):
         register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
         register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
 
+        # p2 loss reweighting
+
+        self.has_p2_loss_reweighting = p2_loss_weight_gamma > 0.
+        register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma)
+
     def q_posterior(self, x_start, x_t, t):
         posterior_mean = (
             extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
@@ -1755,12 +1760,16 @@ class Decoder(BaseGaussianDiffusion):
         unconditional = False,
         auto_normalize_img = True,          # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
         use_dynamic_thres = False,          # from the Imagen paper
-        dynamic_thres_percentile = 0.9
+        dynamic_thres_percentile = 0.9,
+        p2_loss_weight_gamma = 0.,          # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
+        p2_loss_weight_k = 1
     ):
         super().__init__(
             beta_schedule = beta_schedule,
             timesteps = timesteps,
-            loss_type = loss_type
+            loss_type = loss_type,
+            p2_loss_weight_gamma = p2_loss_weight_gamma,
+            p2_loss_weight_k = p2_loss_weight_k
         )
 
         self.unconditional = unconditional
@@ -2028,7 +2037,13 @@ class Decoder(BaseGaussianDiffusion):
         target = noise if not predict_x_start else x_start
 
-        loss = self.loss_fn(pred, target)
+        loss = self.loss_fn(pred, target, reduction = 'none')
+        loss = reduce(loss, 'b ... -> b (...)', 'mean')
+
+        if self.has_p2_loss_reweighting:
+            loss = loss * extract(self.p2_loss_weight, times, loss.shape)
+
+        loss = loss.mean()
 
         if not learned_variance:
             # return simple loss if not using learned variance

diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py
index a71c5c7..f0788a8 100644
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.7.0'
+__version__ = '0.7.1'
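For reference, below is a minimal standalone sketch of the p2 reweighting this patch wires in: the weight `(k + SNR(t)) ** -gamma` from Choi et al. 2022, where `SNR(t) = alphas_cumprod / (1 - alphas_cumprod)`, applied to a per-example loss exactly as in the patched `Decoder` loss path. The `cosine_alphas_cumprod` helper is a simplified stand-in for the library's `cosine_beta_schedule`, and plain tensor indexing stands in for the library's `extract` helper; names and shapes here are illustrative, not the library API.

```python
import torch

# Simplified stand-in for the library's cosine beta schedule:
# returns alphas_cumprod over `timesteps` steps.
def cosine_alphas_cumprod(timesteps, s = 0.008):
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    return alphas_cumprod[1:]

timesteps = 1000
alphas_cumprod = cosine_alphas_cumprod(timesteps)

# p2 weight from https://arxiv.org/abs/2204.00227:
# (k + SNR(t)) ** -gamma, with SNR(t) = alphas_cumprod / (1 - alphas_cumprod).
# gamma = 0. recovers a uniform weight of 1 across time; the paper recommends 1.
p2_loss_weight_k = 1
p2_loss_weight_gamma = 1.
p2_loss_weight = (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma

# Apply it to a per-example loss, mirroring the patched loss computation.
batch = 4
times = torch.randint(0, timesteps, (batch,))
per_pixel_loss = torch.rand(batch, 3, 64, 64)               # stand-in for F.mse_loss(pred, target, reduction = 'none')
per_example_loss = per_pixel_loss.flatten(1).mean(dim = 1)  # einops reduce 'b ... -> b (...)' then mean
loss = (per_example_loss * p2_loss_weight[times]).mean()    # weight each example by its timestep, then reduce
```

Because the weight decays as SNR grows, this sketch down-weights low-noise (late, high-SNR) timesteps and emphasizes the mid-noise regime, which is the stated motivation for the `p2_loss_weight_gamma = 1.` recommendation in the patch.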