mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
bring in the dynamic thresholding technique from the Imagen paper, which purportedly improves classifier free guidance for the cascading ddpm
This commit is contained in:
@@ -1195,4 +1195,12 @@ This library would not have gotten to this working state without the help of
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@misc{Saharia2022,
|
||||||
|
title = {Imagen: unprecedented photorealism × deep level of language understanding},
|
||||||
|
author = {Chitwan Saharia*, William Chan*, Saurabh Saxena†, Lala Li†, Jay Whang†, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S. Sara Mahdavi, Rapha Gontijo Lopes, Tim Salimans, Jonathan Ho†, David Fleet†, Mohammad Norouzi*},
|
||||||
|
year = {2022}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
|
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
|
||||||
|
|||||||
@@ -1704,6 +1704,8 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
vb_loss_weight = 0.001,
|
vb_loss_weight = 0.001,
|
||||||
unconditional = False,
|
unconditional = False,
|
||||||
auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
|
auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
|
||||||
|
use_dynamic_thres = False, # from the Imagen paper
|
||||||
|
dynamic_thres_percentile = 0.9
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
beta_schedule = beta_schedule,
|
beta_schedule = beta_schedule,
|
||||||
@@ -1826,6 +1828,11 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
self.clip_denoised = clip_denoised
|
self.clip_denoised = clip_denoised
|
||||||
self.clip_x_start = clip_x_start
|
self.clip_x_start = clip_x_start
|
||||||
|
|
||||||
|
# dynamic thresholding settings, if clipping denoised during sampling
|
||||||
|
|
||||||
|
self.use_dynamic_thres = use_dynamic_thres
|
||||||
|
self.dynamic_thres_percentile = dynamic_thres_percentile
|
||||||
|
|
||||||
# normalize and unnormalize image functions
|
# normalize and unnormalize image functions
|
||||||
|
|
||||||
self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
|
self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
|
||||||
@@ -1868,7 +1875,21 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
|
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
|
||||||
|
|
||||||
if clip_denoised:
|
if clip_denoised:
|
||||||
x_recon.clamp_(-1., 1.)
|
# s is the threshold amount
|
||||||
|
# static thresholding would just be s = 1
|
||||||
|
s = 1.
|
||||||
|
if self.use_dynamic_thres:
|
||||||
|
s = torch.quantile(
|
||||||
|
rearrange(x_recon, 'b ... -> b (...)').abs(),
|
||||||
|
self.dynamic_thres_percentile,
|
||||||
|
dim = -1
|
||||||
|
)
|
||||||
|
|
||||||
|
s.clamp_(min = 1.)
|
||||||
|
s = s.view(-1, *((1,) * (x_recon.ndim - 1)))
|
||||||
|
|
||||||
|
# clip by threshold, depending on whether static or dynamic
|
||||||
|
x_recon = x_recon.clamp(-s, s) / s
|
||||||
|
|
||||||
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user