From 77fa34eae90f6bb321d5f461eca3a9094d1cf225 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Sat, 30 Apr 2022 10:08:24 -0700
Subject: [PATCH] fix all clipping / clamping issues

---
 dalle2_pytorch/dalle2_pytorch.py | 28 +++++++++++++++++++++++-----
 setup.py                         |  2 +-
 2 files changed, 24 insertions(+), 6 deletions(-)

diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index bf0d7fa..3988fae 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -736,6 +736,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
         predict_x_start = True,
         beta_schedule = "cosine",
         condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
+        sampling_clamp_l2norm = False
     ):
         super().__init__(
             beta_schedule = beta_schedule,
@@ -764,6 +765,9 @@ class DiffusionPrior(BaseGaussianDiffusion):

         self.predict_x_start = predict_x_start # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.

+        # whether to force an l2norm, similar to clipping denoised, when sampling
+        self.sampling_clamp_l2norm = sampling_clamp_l2norm
+
     def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
         pred = self.net(x, t, **text_cond)

@@ -777,6 +781,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
         if clip_denoised and not self.predict_x_start:
             x_recon.clamp_(-1., 1.)

+        if self.predict_x_start and self.sampling_clamp_l2norm:
+            x_recon = l2norm(x_recon)
+
         model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
         return model_mean, posterior_variance, posterior_log_variance

@@ -1232,6 +1239,7 @@ class Unet(nn.Module):

         text_tokens = None

         if exists(text_encodings) and self.cond_on_text_encodings:
+            text_tokens = self.text_to_cond(text_encodings)
             text_tokens = text_tokens[:, :self.max_text_len]
             text_tokens_len = text_tokens.shape[1]
@@ -1244,9 +1252,9 @@ class Unet(nn.Module):
                 if remainder > 0:
                     text_mask = F.pad(text_mask, (0, remainder), value = False)

-                text_keep_mask &= text_mask
+                text_mask = rearrange(text_mask, 'b n -> b n 1')
+                text_keep_mask = text_mask & text_keep_mask

-            text_tokens = self.text_to_cond(text_encodings)
             text_tokens = torch.where(
                 text_keep_mask,
                 text_tokens,
@@ -1350,6 +1358,8 @@ class Decoder(BaseGaussianDiffusion):
         blur_sigma = 0.1,                    # cascading ddpm - blur sigma
         blur_kernel_size = 3,                # cascading ddpm - blur kernel size
         condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
+        clip_denoised = True,
+        clip_x_start = True
     ):
         super().__init__(
             beta_schedule = beta_schedule,
@@ -1426,6 +1436,11 @@ class Decoder(BaseGaussianDiffusion):
         self.image_cond_drop_prob = image_cond_drop_prob
         self.text_cond_drop_prob = text_cond_drop_prob

+        # whether to clip when sampling
+
+        self.clip_denoised = clip_denoised
+        self.clip_x_start = clip_x_start
+
     def get_unet(self, unet_number):
         assert 0 < unet_number <= len(self.unets)
         index = unet_number - 1
@@ -1459,7 +1474,7 @@ class Decoder(BaseGaussianDiffusion):
         else:
             x_recon = self.predict_start_from_noise(x, t = t, noise = pred)

-        if clip_denoised and not predict_x_start:
+        if clip_denoised:
             x_recon.clamp_(-1., 1.)

         model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
@@ -1475,7 +1490,7 @@ class Decoder(BaseGaussianDiffusion):
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

     @torch.no_grad()
-    def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
+    def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
         device = self.betas.device

         b = shape[0]
@@ -1491,7 +1506,8 @@ class Decoder(BaseGaussianDiffusion):
                 text_mask = text_mask,
                 cond_scale = cond_scale,
                 lowres_cond_img = lowres_cond_img,
-                predict_x_start = predict_x_start
+                predict_x_start = predict_x_start,
+                clip_denoised = clip_denoised
             )

         return img
@@ -1542,6 +1558,7 @@ class Decoder(BaseGaussianDiffusion):
                 if unet.lowres_cond:
                     lowres_cond_img = self.to_lowres_cond(img, target_image_size = image_size)

+                is_latent_diffusion = isinstance(vae, VQGanVAE)
                 image_size = vae.get_encoded_fmap_size(image_size)
                 shape = (batch_size, vae.encoded_dim, image_size, image_size)
@@ -1556,6 +1573,7 @@ class Decoder(BaseGaussianDiffusion):
                     text_mask = text_mask,
                     cond_scale = cond_scale,
                     predict_x_start = predict_x_start,
+                    clip_denoised = not is_latent_diffusion,
                     lowres_cond_img = lowres_cond_img
                 )

diff --git a/setup.py b/setup.py
index 6d46224..7fe4fa5 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.75',
+  version = '0.0.76',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
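
Note on `sampling_clamp_l2norm`: CLIP image embeddings are conventionally compared after l2-normalization, so when the prior predicts x0 (the image embedding) directly, renormalizing the prediction at each sampling step keeps it near the unit hypersphere, playing the same role as the usual clamp to [-1., 1.] in pixel space. A minimal sketch of the idea, assuming an `l2norm` helper equivalent to the `F.normalize`-based one this repo defines:

    import torch
    import torch.nn.functional as F

    def l2norm(t):
        # project onto the unit hypersphere along the feature dimension
        return F.normalize(t, dim = -1)

    # stand-in for a predicted image embedding x0 during one sampling step
    x_recon = torch.randn(4, 512)

    # what the new flag does: renormalize instead of clamping to [-1., 1.]
    x_recon = l2norm(x_recon)

    assert torch.allclose(x_recon.norm(dim = -1), torch.ones(4), atol = 1e-5)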
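
Note on the Unet mask hunks: `text_keep_mask` comes from classifier-free guidance dropout with shape (b, 1, 1), while `text_mask` arrives as (b, n), so the old in-place `&=` asked PyTorch to write a broadcast result back into a (b, 1, 1) tensor, which fails; the same pair of hunks also moves `text_to_cond` up so `text_tokens` is populated before it is sliced. A small broadcasting sketch, with shapes chosen arbitrarily:

    import torch
    from einops import rearrange

    b, n = 2, 5
    text_mask = torch.rand(b, n) > 0.1              # per-token mask, shape (b, n)
    text_keep_mask = torch.rand(b, 1, 1) > 0.2      # per-sample dropout mask, shape (b, 1, 1)

    # old: text_keep_mask &= text_mask
    # in-place &= cannot expand a (b, 1, 1) tensor to hold the broadcast result

    text_mask = rearrange(text_mask, 'b n -> b n 1')
    text_keep_mask = text_mask & text_keep_mask     # broadcasts cleanly to (b, n, 1)

    assert text_keep_mask.shape == (b, n, 1)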
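
Note on `clip_denoised = not is_latent_diffusion`: clamping the denoised prediction to [-1., 1.] is only valid when the unet operates on normalized pixels; VQGAN latents are not bounded to that range, so the decoder now skips the clamp for latent-diffusion stages. For reference, a hedged usage sketch of the new constructor flags (all other arguments are elided and assumed to be set up as in the project README):

    from dalle2_pytorch import DiffusionPrior, Decoder

    prior = DiffusionPrior(
        net = prior_network,             # assumed built beforehand, as in the README
        clip = clip,                     # assumed pretrained CLIP wrapper
        sampling_clamp_l2norm = True     # new: l2-normalize x0 predictions while sampling
    )

    decoder = Decoder(
        unet = unet,                     # assumed built beforehand, as in the README
        clip = clip,
        clip_denoised = True             # new: clamp denoised pixel predictions to [-1., 1.]
    )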