From 59fa101c4d21843477a5202560a877328ad47afd Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Tue, 23 Aug 2022 08:28:54 -0700 Subject: [PATCH] fix classifier free guidance for diffusion prior, thanks to @jaykim9870 for spotting the issue --- dalle2_pytorch/dalle2_pytorch.py | 64 ++++++++++++++++++++++---------- dalle2_pytorch/version.py | 2 +- 2 files changed, 45 insertions(+), 21 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 2b11928..5292d44 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -978,7 +978,10 @@ class DiffusionPriorNetwork(nn.Module): # dalle1 learned padding strategy self.max_text_len = max_text_len - self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim)) + + self.null_text_encodings = nn.Parameter(torch.randn(1, max_text_len, dim)) + self.null_text_embeds = nn.Parameter(torch.randn(1, num_text_embeds, dim)) + self.null_image_embed = nn.Parameter(torch.randn(1, dim)) # whether to use self conditioning, Hinton's group's new ddpm technique @@ -995,7 +998,7 @@ class DiffusionPriorNetwork(nn.Module): if cond_scale == 1: return logits - null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs) + null_logits = self.forward(*args, text_cond_drop_prob = 1., image_cond_drop_prob = 1, **kwargs) return null_logits + (logits - null_logits) * cond_scale def forward( @@ -1006,7 +1009,8 @@ class DiffusionPriorNetwork(nn.Module): text_embed, text_encodings = None, self_cond = None, - cond_drop_prob = 0. + text_cond_drop_prob = 0., + image_cond_drop_prob = 0. ): batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype @@ -1024,6 +1028,14 @@ class DiffusionPriorNetwork(nn.Module): text_embed = self.to_text_embeds(text_embed) image_embed = self.to_image_embeds(image_embed) + # classifier free guidance masks + + text_keep_mask = prob_mask_like((batch,), 1 - text_cond_drop_prob, device = device) + text_keep_mask = rearrange(text_keep_mask, 'b -> b 1 1') + + image_keep_mask = prob_mask_like((batch,), 1 - image_cond_drop_prob, device = device) + image_keep_mask = rearrange(image_keep_mask, 'b -> b 1 1') + # make text encodings optional # although the paper seems to suggest it is present <-- @@ -1044,32 +1056,39 @@ class DiffusionPriorNetwork(nn.Module): text_encodings = F.pad(text_encodings, (0, 0, 0, remainder), value = 0.) mask = F.pad(mask, (0, remainder), value = False) - null_text_embeds = self.null_text_embed.to(text_encodings.dtype) + # mask out text encodings with null encodings + + null_text_encodings = self.null_text_encodings.to(text_encodings.dtype) text_encodings = torch.where( - rearrange(mask, 'b n -> b n 1').clone(), + rearrange(mask, 'b n -> b n 1').clone() & text_keep_mask, text_encodings, + null_text_encodings + ) + + # mask out text embeddings with null text embeddings + + null_text_embeds = self.null_text_embeds.to(text_embed.dtype) + + text_embeds = torch.where( + text_keep_mask, + text_embed, null_text_embeds ) - # classifier free guidance + # mask out image embeddings with null image embeddings - keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = device) - keep_mask = rearrange(keep_mask, 'b -> b 1') + null_image_embed = self.null_image_embed.to(image_embed.dtype) - mask &= keep_mask - - # whether text embedding is masked or not depends on the classifier free guidance conditional masking - - keep_mask = repeat(keep_mask, 'b 1 -> b n', n = num_text_embeds) - mask = torch.cat((mask, keep_mask), dim = 1) + image_embed = torch.where( + image_keep_mask, + image_embed, + null_image_embed + ) # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different) # but let's just do it right - attend_padding = 1 + num_time_embeds + num_image_embeds + int(self.self_cond) # 1 for learned queries + number of image embeds + time embeds - mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query - time_embed = self.to_time_embeds(diffusion_timesteps) learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch) @@ -1107,6 +1126,8 @@ class DiffusionPrior(nn.Module): timesteps = 1000, sample_timesteps = None, cond_drop_prob = 0., + text_cond_drop_prob = None, + image_cond_drop_prob = None, loss_type = "l2", predict_x_start = True, beta_schedule = "cosine", @@ -1147,8 +1168,10 @@ class DiffusionPrior(nn.Module): self.image_embed_dim = default(image_embed_dim, lambda: clip.dim_latent) self.channels = default(image_channels, lambda: clip.image_channels) - self.cond_drop_prob = cond_drop_prob - self.can_classifier_guidance = cond_drop_prob > 0. + self.text_cond_drop_prob = default(text_cond_drop_prob, cond_drop_prob) + self.image_cond_drop_prob = default(image_cond_drop_prob, cond_drop_prob) + + self.can_classifier_guidance = self.text_cond_drop_prob > 0. and self.image_cond_drop_prob > 0. self.condition_on_text_encodings = condition_on_text_encodings # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both. @@ -1308,7 +1331,8 @@ class DiffusionPrior(nn.Module): image_embed_noisy, times, self_cond = self_cond, - cond_drop_prob = self.cond_drop_prob, + text_cond_drop_prob = self.text_cond_drop_prob, + image_cond_drop_prob = self.image_cond_drop_prob, **text_cond ) diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index fa2822c..e5102d3 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.8.4' +__version__ = '1.9.0'