diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index b157a3e..271c959 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -421,25 +421,41 @@ class DiffusionPriorNetwork(nn.Module):
         image_embed,
         diffusion_timesteps,
         *,
-        text_encodings,
         text_embed,
+        text_encodings = None,
         mask = None,
         cond_drop_prob = 0.2
     ):
-        batch, text_enc_len, device = image_embed.shape[0], text_encodings.shape[-2], image_embed.device
+        batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
 
         # in section 2.2, last paragraph
         # "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
 
         text_embed, image_embed = rearrange_many((text_embed, image_embed), 'b d -> b 1 d')
 
+        # make text encodings optional
+        # although the paper seems to suggest it is present <--
+
+        if not exists(text_encodings):
+            text_encodings = torch.empty((batch, 0, dim), device = device, dtype = dtype)
+
+        if not exists(mask):
+            mask = torch.ones((batch, text_encodings.shape[-2]), device = device, dtype = torch.bool)
+
+        # classifier free guidance
+
+        cond_prob_mask = prob_mask_like((batch,), cond_drop_prob, device = device)
+        cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1')
+
+        mask &= cond_prob_mask
+
+        # whether text embedding is masked or not depends on the classifier free guidance conditional masking
+
+        mask = torch.cat((mask, cond_prob_mask), dim = 1)
+
         # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
         # but let's just do it right
 
-        if exists(mask):
-            not_all_masked_out = mask.any(dim = -1)
-            mask = torch.cat((mask, rearrange(not_all_masked_out, 'b -> b 1')), dim = 1)
-
         if exists(mask):
             mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
 
@@ -455,16 +471,6 @@ class DiffusionPriorNetwork(nn.Module):
             learned_queries
         ), dim = -2)
 
-        # mask if it doesn't exist
-
-        if not exists(mask):
-            mask = torch.ones((batch, text_enc_len), device = device, dtype = torch.bool)
-
-        # classifier free guidance
-
-        cond_prob_mask = prob_mask_like((batch,), cond_drop_prob, device = device)
-        mask &= rearrange(cond_prob_mask, 'b -> b 1')
-
         # attend
 
         tokens = self.causal_transformer(tokens, mask = mask)
@@ -486,6 +492,7 @@ class DiffusionPrior(nn.Module):
         loss_type = "l1",
         predict_x_start = True,
         beta_schedule = "cosine",
+        condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
     ):
         super().__init__()
         assert isinstance(clip, CLIP)
@@ -496,7 +503,9 @@ class DiffusionPrior(nn.Module):
         self.image_embed_dim = clip.dim_latent
         self.channels = clip.image_channels
         self.image_size = clip.image_size
+        self.cond_drop_prob = cond_drop_prob
+        self.condition_on_text_encodings = condition_on_text_encodings
 
         self.predict_x_start = predict_x_start # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
@@ -561,6 +570,10 @@ class DiffusionPrior(nn.Module):
         text_cls, text_encodings = text_encodings[:, 0], text_encodings[:, 1:]
         text_embed = self.clip.to_text_latent(text_cls)
         text_embed = l2norm(text_embed)
+
+        if not self.condition_on_text_encodings:
+            return dict(text_embed = text_embed)
+
         return dict(text_encodings = text_encodings, text_embed = text_embed, mask = text != 0)
 
     def q_mean_variance(self, x_start, t):
diff --git a/setup.py b/setup.py
index 010e08c..1cd68c5 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.46',
+  version = '0.0.47',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
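
Usage sketch (not part of the diff above): with the new `condition_on_text_encodings = False` flag, the diffusion prior conditions only on the pooled CLIP text embedding and never builds or attends to the per-token text encodings, which is the intended setup for training on preprocessed text embed -> image embed pairs. The constructor arguments below are illustrative assumptions following the project README example of this era, not values fixed by this commit.

    # minimal sketch, hyperparameters are illustrative
    import torch
    from dalle2_pytorch import CLIP, DiffusionPriorNetwork, DiffusionPrior

    clip = CLIP(
        dim_text = 512,
        dim_image = 512,
        dim_latent = 512,
        num_text_tokens = 49408,
        text_enc_depth = 6,
        text_seq_len = 256,
        text_heads = 8,
        visual_enc_depth = 6,
        visual_image_size = 256,
        visual_patch_size = 32,
        visual_heads = 8
    )

    prior_network = DiffusionPriorNetwork(
        dim = 512,
        depth = 6,
        dim_head = 64,
        heads = 8
    )

    diffusion_prior = DiffusionPrior(
        net = prior_network,
        clip = clip,
        timesteps = 100,
        cond_drop_prob = 0.2,
        condition_on_text_encodings = False  # flag added in this diff: condition on the CLIP text embedding alone
    )

    # mock data

    text = torch.randint(0, 49408, (4, 256))
    images = torch.randn(4, 3, 256, 256)

    # train the prior on text embed -> image embed, without per-token text encodings

    loss = diffusion_prior(text, images)
    loss.backward()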