From 1d5dc088109c5d606096d40bea59ff8c7b8e7d9f Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Thu, 5 May 2022 07:28:53 -0700
Subject: [PATCH] take @crowsonkb 's suggestion at
 https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132

---
 dalle2_pytorch/dalle2_pytorch.py | 19 ++++++++++++++++---
 setup.py                         |  2 +-
 2 files changed, 17 insertions(+), 4 deletions(-)

diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 8b9bfda..806a0d8 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -756,7 +756,8 @@ class DiffusionPrior(BaseGaussianDiffusion):
         predict_x_start = True,
         beta_schedule = "cosine",
         condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
-        sampling_clamp_l2norm = False
+        sampling_clamp_l2norm = False,
+        image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
     ):
         super().__init__(
             beta_schedule = beta_schedule,
@@ -782,8 +783,11 @@ class DiffusionPrior(BaseGaussianDiffusion):
 
         self.cond_drop_prob = cond_drop_prob
         self.condition_on_text_encodings = condition_on_text_encodings
-        self.predict_x_start = predict_x_start # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
+
+        self.predict_x_start = predict_x_start
+
+        # @crowsonkb 's suggestion - https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
+        self.image_embed_scale = default(image_embed_scale, image_embed_dim ** 0.5)
 
         # whether to force an l2norm, similar to clipping denoised, when sampling
         self.sampling_clamp_l2norm = sampling_clamp_l2norm
@@ -802,7 +806,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
             x_recon.clamp_(-1., 1.)
 
         if self.predict_x_start and self.sampling_clamp_l2norm:
-            x_recon = l2norm(x_recon)
+            x_recon = l2norm(x_recon) * self.image_embed_scale
 
         model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
         return model_mean, posterior_variance, posterior_log_variance
@@ -862,6 +866,11 @@ class DiffusionPrior(BaseGaussianDiffusion):
             text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
 
         image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
+
+        # retrieve original unscaled image embed
+
+        image_embeds /= self.image_embed_scale
+
         text_embeds = text_cond['text_embed']
 
         text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
@@ -909,6 +918,10 @@ class DiffusionPrior(BaseGaussianDiffusion):
         batch, device = image_embed.shape[0], image_embed.device
         times = torch.randint(0, self.num_timesteps, (batch,), device = device, dtype = torch.long)
+
+        # scale image embed (Katherine)
+
+        image_embed *= self.image_embed_scale
 
         # calculate forward loss
 
         return self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)
diff --git a/setup.py b/setup.py
index 06338ef..595d13d 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.102',
+  version = '0.0.104',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',