diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 2b11928..5292d44 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -978,7 +978,10 @@ class DiffusionPriorNetwork(nn.Module):
         # dalle1 learned padding strategy

         self.max_text_len = max_text_len

-        self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim))
+
+        self.null_text_encodings = nn.Parameter(torch.randn(1, max_text_len, dim))
+        self.null_text_embeds = nn.Parameter(torch.randn(1, num_text_embeds, dim))
+        self.null_image_embed = nn.Parameter(torch.randn(1, dim))

         # whether to use self conditioning, Hinton's group's new ddpm technique
@@ -995,7 +998,7 @@ class DiffusionPriorNetwork(nn.Module):
         if cond_scale == 1:
             return logits

-        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
+        null_logits = self.forward(*args, text_cond_drop_prob = 1., image_cond_drop_prob = 1, **kwargs)
         return null_logits + (logits - null_logits) * cond_scale

     def forward(
@@ -1006,7 +1009,8 @@ class DiffusionPriorNetwork(nn.Module):
         text_embed,
         text_encodings = None,
         self_cond = None,
-        cond_drop_prob = 0.
+        text_cond_drop_prob = 0.,
+        image_cond_drop_prob = 0.
     ):
         batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype

@@ -1024,6 +1028,14 @@ class DiffusionPriorNetwork(nn.Module):
         text_embed = self.to_text_embeds(text_embed)
         image_embed = self.to_image_embeds(image_embed)

+        # classifier free guidance masks
+
+        text_keep_mask = prob_mask_like((batch,), 1 - text_cond_drop_prob, device = device)
+        text_keep_mask = rearrange(text_keep_mask, 'b -> b 1 1')
+
+        image_keep_mask = prob_mask_like((batch,), 1 - image_cond_drop_prob, device = device)
+        image_keep_mask = rearrange(image_keep_mask, 'b -> b 1 1')
+
         # make text encodings optional
         # although the paper seems to suggest it is present <--

@@ -1044,32 +1056,39 @@ class DiffusionPriorNetwork(nn.Module):
             text_encodings = F.pad(text_encodings, (0, 0, 0, remainder), value = 0.)
             mask = F.pad(mask, (0, remainder), value = False)

-        null_text_embeds = self.null_text_embed.to(text_encodings.dtype)
+        # mask out text encodings with null encodings
+
+        null_text_encodings = self.null_text_encodings.to(text_encodings.dtype)

         text_encodings = torch.where(
-            rearrange(mask, 'b n -> b n 1').clone(),
+            rearrange(mask, 'b n -> b n 1').clone() & text_keep_mask,
             text_encodings,
+            null_text_encodings
+        )
+
+        # mask out text embeddings with null text embeddings
+
+        null_text_embeds = self.null_text_embeds.to(text_embed.dtype)
+
+        text_embed = torch.where(
+            text_keep_mask,
+            text_embed,
             null_text_embeds
         )

-        # classifier free guidance
+        # mask out image embeddings with null image embeddings

-        keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = device)
-        keep_mask = rearrange(keep_mask, 'b -> b 1')
+        null_image_embed = self.null_image_embed.to(image_embed.dtype)

-        mask &= keep_mask
-
-        # whether text embedding is masked or not depends on the classifier free guidance conditional masking
-
-        keep_mask = repeat(keep_mask, 'b 1 -> b n', n = num_text_embeds)
-        mask = torch.cat((mask, keep_mask), dim = 1)
+        image_embed = torch.where(
+            image_keep_mask,
+            image_embed,
+            null_image_embed
+        )

         # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
         # but let's just do it right

-        attend_padding = 1 + num_time_embeds + num_image_embeds + int(self.self_cond) # 1 for learned queries + number of image embeds + time embeds
-        mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
-
         time_embed = self.to_time_embeds(diffusion_timesteps)

         learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
@@ -1107,6 +1126,8 @@ class DiffusionPrior(nn.Module):
         timesteps = 1000,
         sample_timesteps = None,
         cond_drop_prob = 0.,
+        text_cond_drop_prob = None,
+        image_cond_drop_prob = None,
         loss_type = "l2",
         predict_x_start = True,
         beta_schedule = "cosine",
@@ -1147,8 +1168,10 @@ class DiffusionPrior(nn.Module):
         self.image_embed_dim = default(image_embed_dim, lambda: clip.dim_latent)
         self.channels = default(image_channels, lambda: clip.image_channels)

-        self.cond_drop_prob = cond_drop_prob
-        self.can_classifier_guidance = cond_drop_prob > 0.
+        self.text_cond_drop_prob = default(text_cond_drop_prob, cond_drop_prob)
+        self.image_cond_drop_prob = default(image_cond_drop_prob, cond_drop_prob)
+
+        self.can_classifier_guidance = self.text_cond_drop_prob > 0. and self.image_cond_drop_prob > 0.
         self.condition_on_text_encodings = condition_on_text_encodings

         # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
@@ -1308,7 +1331,8 @@ class DiffusionPrior(nn.Module):
             image_embed_noisy,
             times,
             self_cond = self_cond,
-            cond_drop_prob = self.cond_drop_prob,
+            text_cond_drop_prob = self.text_cond_drop_prob,
+            image_cond_drop_prob = self.image_cond_drop_prob,
             **text_cond
         )

diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py
index fa2822c..e5102d3 100644
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.8.4'
+__version__ = '1.9.0'
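For reference, a minimal sketch of how the separate conditioning drop probabilities introduced by this patch might be configured. It follows the repository's README-style setup; the hyperparameters, the choice of `OpenAIClipAdapter`, and the mock tensors are illustrative assumptions, not part of the diff.

```python
import torch
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter

# prior network over CLIP embeddings (hyperparameters are illustrative only)
prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
)

# the prior now accepts independent drop probabilities for text and image
# conditioning, in place of the single cond_drop_prob used before this change
diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = OpenAIClipAdapter(),
    timesteps = 1000,
    text_cond_drop_prob = 0.2,   # drop text conditioning on 20% of training samples
    image_cond_drop_prob = 0.1   # drop image embedding conditioning on 10% of training samples
)

# mock data, mirroring the README example
text = torch.randint(0, 49408, (4, 256))
images = torch.randn(4, 3, 256, 256)

loss = diffusion_prior(text, images)
loss.backward()
```

Passing only `cond_drop_prob` remains valid: per the `DiffusionPrior.__init__` change above, `text_cond_drop_prob` and `image_cond_drop_prob` fall back to it via `default(...)` when not given.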