From 7ba6357c057b053cedb8e80693bb638b5b26c99c Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Tue, 26 Apr 2022 09:29:51 -0700
Subject: [PATCH] allow for training the Prior network with precomputed CLIP
 embeddings (or text encodings)

---
 README.md                        | 69 ++++++++++++++++++++++++++++++++
 dalle2_pytorch/dalle2_pytorch.py | 40 +++++++++++++++---
 setup.py                         |  2 +-
 3 files changed, 104 insertions(+), 7 deletions(-)

diff --git a/README.md b/README.md
index 345efea..3979503 100644
--- a/README.md
+++ b/README.md
@@ -376,6 +376,75 @@ You can also train the decoder on images of greater than the size (say 512x512)
 
 For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.
 
+## Training on Preprocessed CLIP Embeddings
+
+It is likely, when scaling up, that you would first preprocess your images and text into corresponding embeddings before training the prior network. You can do so by simply passing in `image_embed`, `text_embed`, and optionally `text_encodings` and `text_mask`.
+
+Working example below
+
+```python
+import torch
+from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, CLIP
+
+# get trained CLIP from step one
+
+clip = CLIP(
+    dim_text = 512,
+    dim_image = 512,
+    dim_latent = 512,
+    num_text_tokens = 49408,
+    text_enc_depth = 6,
+    text_seq_len = 256,
+    text_heads = 8,
+    visual_enc_depth = 6,
+    visual_image_size = 256,
+    visual_patch_size = 32,
+    visual_heads = 8,
+).cuda()
+
+# setup prior network, which contains an autoregressive transformer
+
+prior_network = DiffusionPriorNetwork(
+    dim = 512,
+    depth = 6,
+    dim_head = 64,
+    heads = 8
+).cuda()
+
+# diffusion prior, which contains the CLIP and network (with transformer) above
+
+diffusion_prior = DiffusionPrior(
+    net = prior_network,
+    clip = clip,
+    timesteps = 100,
+    cond_drop_prob = 0.2,
+    condition_on_text_encodings = False  # this probably should be true, but just to get Laion started
+).cuda()
+
+# mock data
+
+text = torch.randint(0, 49408, (4, 256)).cuda()
+images = torch.randn(4, 3, 256, 256).cuda()
+
+# precompute the text and image embeddings
+# here using the diffusion prior class, but could be done with CLIP alone
+
+clip_image_embeds = diffusion_prior.get_image_embed(images)
+clip_text_embeds = diffusion_prior.get_text_cond(text).get('text_embed')
+
+# feed the text and image embeddings into the diffusion prior network
+
+loss = diffusion_prior(
+    text_embed = clip_text_embeds,
+    image_embed = clip_image_embeds
+)
+
+loss.backward()
+
+# do the above for many many many steps
+# now the diffusion prior can generate image embeddings from the text embeddings
+```
+
 ## Experimental
 
 ### DALL-E2 with Latent Diffusion
diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 271c959..e1cedcc 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -692,13 +692,41 @@ class DiffusionPrior(nn.Module):
         top_image_embeds = image_embeds.gather(1, top_sim_indices)
         return rearrange(top_image_embeds, 'b 1 d -> b d')
 
-    def forward(self, text, image, *args, **kwargs):
-        b, device, img_size, = image.shape[0], image.device, self.image_size
-        check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)
+    def forward(
+        self,
+        text = None,
+        image = None,
+        text_embed = None,      # allow for training on preprocessed CLIP text and image embeddings
+        image_embed = None,
+        text_encodings = None,  # as well as CLIP text encodings
+        text_mask = None,       # text mask <- may eventually opt for the learned padding tokens technique from DALL-E1 to reduce complexity
+        *args,
+        **kwargs
+    ):
+        assert exists(text) ^ exists(text_embed), 'either text or text embedding must be supplied'
+        assert exists(image) ^ exists(image_embed), 'either image or image embedding must be supplied'
+        assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on them at initialization'
 
-        times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
-        image_embed = self.get_image_embed(image)
-        text_cond = self.get_text_cond(text)
+        if exists(image):
+            image_embed = self.get_image_embed(image)
+
+        # calculate text conditionings, based on what is passed in
+
+        if exists(text):
+            text_cond = self.get_text_cond(text)
+        else:
+            text_cond = dict(
+                text_embed = text_embed,
+                text_encodings = text_encodings,
+                mask = text_mask
+            )
+
+        # timestep conditioning from ddpm
+
+        batch, device = image_embed.shape[0], image_embed.device
+        times = torch.randint(0, self.num_timesteps, (batch,), device = device, dtype = torch.long)
+
+        # calculate forward loss
 
         loss = self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)
         return loss
diff --git a/setup.py b/setup.py
index 1cd68c5..dbbea31 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.47',
+  version = '0.0.48',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
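
The README example in this patch trains with `condition_on_text_encodings = False`. Below is a minimal sketch of what the fully precomputed path could look like once text encodings are turned on, passing `text_encodings` and `text_mask` alongside the embeddings through the new `forward` signature. The tensor shapes here (batch of 4, latent dimension 512, text sequence length 256) are assumptions chosen to match the README example, and the random tensors merely stand in for outputs precomputed with the trained CLIP.

```python
import torch
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, CLIP

# same CLIP and prior configuration as the README example above

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
)

prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
)

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 100,
    cond_drop_prob = 0.2,
    condition_on_text_encodings = True  # also condition on the per-token CLIP text encodings
)

# stand-ins for precomputed CLIP outputs (assumed shapes)
# text_embed / image_embed : (batch, dim_latent)
# text_encodings           : (batch, text_seq_len, dim_text)
# text_mask                : (batch, text_seq_len), True where a real token is present

clip_text_embeds = torch.randn(4, 512)
clip_image_embeds = torch.randn(4, 512)
clip_text_encodings = torch.randn(4, 256, 512)
text_mask = torch.ones(4, 256).bool()

# train the prior entirely from precomputed tensors

loss = diffusion_prior(
    text_embed = clip_text_embeds,
    image_embed = clip_image_embeds,
    text_encodings = clip_text_encodings,
    text_mask = text_mask
)

loss.backward()
```

In a real pipeline these tensors would be loaded from a preprocessing pass over the dataset with the trained CLIP, rather than generated with `torch.randn`.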