From fd38eb83c492c433de742ba98371227ce7d3ef81 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Tue, 12 Apr 2022 11:43:59 -0700 Subject: [PATCH] complete the main contribution of the paper, the diffusion prior network, minus the diffusion training setup --- dalle2_pytorch/dalle2_pytorch.py | 50 ++++++++++++++++++++++++++++++++ setup.py | 1 + 2 files changed, 51 insertions(+) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index db145aa..5712c6a 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -149,6 +149,56 @@ class Transformer(nn.Module): return self.norm(x) +class PriorNetwork(nn.Module): + def __init__( + self, + dim, + num_timesteps = 1000, + **kwargs + ): + super().__init__() + self.time_embeddings = nn.Embedding(num_timesteps, dim) # also offer a continuous version of timestep embeddings, with a 2 layer MLP + self.learned_query = nn.Parameter(torch.randn(dim)) + self.causal_transformer = Transformer(**kwargs) + + def forward( + self, + image_embed, + *, + diffusion_timesteps, + text_encodings, + text_embed, + mask = None, + ): + batch = image_embed.shape[0] + + # in section 2.2, last paragraph + # "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction" + + text_embed, image_embed = rearrange_many((text_embed, image_embed), 'b d -> b 1 d') + + if exists(mask): + mask = F.pad(mask, (0, 4), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query + + time_embed = self.time_embeddings(diffusion_timesteps) + + learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch) + + tokens = torch.cat(( + text_encodings, + text_embed, + time_embed, + learned_queries + ), dim = -2) + + tokens = self.causal_transformer(tokens, mask = mask) + + # get learned query, which should predict the image embedding (per DDPM timestep) + + pred_image_embed = tokens[..., -1, :] + + return pred_image_embed + class DiffusionPrior(nn.Module): def __init__( self, diff --git a/setup.py b/setup.py index 1244b5d..8ffea1f 100644 --- a/setup.py +++ b/setup.py @@ -28,6 +28,7 @@ setup( 'pillow', 'torch>=1.10', 'torchvision', + 'tqdm', 'x-clip>=0.4.1', 'youtokentome' ],