From e46eaec817172e502f13e374a26b3834d3ab6df5 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Mon, 9 May 2022 11:08:46 -0700 Subject: [PATCH] deal the diffusion prior problem yet another blow --- dalle2_pytorch/dalle2_pytorch.py | 14 +++++++++++--- setup.py | 2 +- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 0ab1f76..7717a6b 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -743,11 +743,18 @@ class DiffusionPriorNetwork(nn.Module): num_timesteps = None, num_time_embeds = 1, num_image_embeds = 1, + num_text_embeds = 1, **kwargs ): super().__init__() self.num_time_embeds = num_time_embeds self.num_image_embeds = num_image_embeds + self.num_text_embeds = num_text_embeds + + self.to_text_embeds = nn.Sequential( + nn.Linear(dim, dim * num_text_embeds) if num_text_embeds > 1 else nn.Identity(), + Rearrange('b (n d) -> b n d', n = num_text_embeds) + ) self.to_time_embeds = nn.Sequential( nn.Embedding(num_timesteps, dim * num_time_embeds) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim * num_time_embeds)), # also offer a continuous version of timestep embeddings, with a 2 layer MLP @@ -755,7 +762,7 @@ class DiffusionPriorNetwork(nn.Module): ) self.to_image_embeds = nn.Sequential( - nn.Linear(dim, dim * num_image_embeds), + nn.Linear(dim, dim * num_image_embeds) if num_image_embeds > 1 else nn.Identity(), Rearrange('b (n d) -> b n d', n = num_image_embeds) ) @@ -788,12 +795,12 @@ class DiffusionPriorNetwork(nn.Module): ): batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype - num_time_embeds, num_image_embeds = self.num_time_embeds, self.num_image_embeds + num_time_embeds, num_image_embeds, num_text_embeds = self.num_time_embeds, self.num_image_embeds, self.num_text_embeds # in section 2.2, last paragraph # "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction" - text_embed = rearrange(text_embed, 'b d -> b 1 d') + text_embed = self.to_text_embeds(text_embed) image_embed = self.to_image_embeds(image_embed) # make text encodings optional @@ -814,6 +821,7 @@ class DiffusionPriorNetwork(nn.Module): # whether text embedding is masked or not depends on the classifier free guidance conditional masking + keep_mask = repeat(keep_mask, 'b 1 -> b n', n = num_text_embeds) mask = torch.cat((mask, keep_mask), dim = 1) # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different) diff --git a/setup.py b/setup.py index f21114c..30601d0 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.2.0', + version = '0.2.1', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',