From 8864fd0aa75a40cff678e0feb430175decafa73e Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Tue, 24 May 2022 18:14:35 -0700 Subject: [PATCH] bring in the dynamic thresholding technique from the Imagen paper, which purportedly improves classifier free guidance for the cascading ddpm --- README.md | 8 ++++++++ dalle2_pytorch/dalle2_pytorch.py | 23 ++++++++++++++++++++++- setup.py | 2 +- 3 files changed, 31 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index e51e7c2..ad0169f 100644 --- a/README.md +++ b/README.md @@ -1195,4 +1195,12 @@ This library would not have gotten to this working state without the help of } ``` +```bibtex +@misc{Saharia2022, + title = {Imagen: unprecedented photorealism × deep level of language understanding}, + author = {Chitwan Saharia* and William Chan* and Saurabh Saxena† and Lala Li† and Jay Whang† and Emily Denton and Seyed Kamyar Seyed Ghasemipour and Burcu Karagol Ayan and S. Sara Mahdavi and Rapha Gontijo Lopes and Tim Salimans and Jonathan Ho† and David Fleet† and Mohammad Norouzi*}, + year = {2022} +} +``` + *Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's paper diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 4907b75..5dc4f41 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1704,6 +1704,8 @@ class Decoder(BaseGaussianDiffusion): vb_loss_weight = 0.001, unconditional = False, auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader + use_dynamic_thres = False, # from the Imagen paper + dynamic_thres_percentile = 0.9 ): super().__init__( beta_schedule = beta_schedule, @@ -1826,6 +1828,11 @@ class Decoder(BaseGaussianDiffusion): self.clip_denoised = clip_denoised self.clip_x_start = clip_x_start + # dynamic thresholding settings, if clipping denoised during sampling + 
+ self.use_dynamic_thres = use_dynamic_thres + self.dynamic_thres_percentile = dynamic_thres_percentile + # normalize and unnormalize image functions self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity @@ -1868,7 +1875,21 @@ class Decoder(BaseGaussianDiffusion): x_recon = self.predict_start_from_noise(x, t = t, noise = pred) if clip_denoised: - x_recon.clamp_(-1., 1.) + # s is the threshold amount + # static thresholding would just be s = 1 + s = 1. + if self.use_dynamic_thres: + s = torch.quantile( + rearrange(x_recon, 'b ... -> b (...)').abs(), + self.dynamic_thres_percentile, + dim = -1 + ) + + s.clamp_(min = 1.) + s = s.view(-1, *((1,) * (x_recon.ndim - 1))) + + # clip by threshold, depending on whether static or dynamic + x_recon = x_recon.clamp(-s, s) / s model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) diff --git a/setup.py b/setup.py index c564511..86a12c3 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.4.14', + version = '0.5.0', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',