allow for training the Prior network with precomputed CLIP embeddings (or text encodings)

This commit is contained in:
Phil Wang
2022-04-26 09:29:51 -07:00
parent 76e063e8b7
commit 7ba6357c05
3 changed files with 104 additions and 7 deletions

README.md

@@ -376,6 +376,75 @@ You can also train the decoder on images of greater than the size (say 512x512)
For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.
## Training on Preprocessed CLIP Embeddings
It is likely that, when scaling up, you would first preprocess your images and text into the corresponding embeddings before training the prior network. You can do so by simply passing in `image_embed`, `text_embed`, and optionally `text_encodings` and `text_mask`.
Working example below:
```python
import torch
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, CLIP
# get trained CLIP from step one
clip = CLIP(
dim_text = 512,
dim_image = 512,
dim_latent = 512,
num_text_tokens = 49408,
text_enc_depth = 6,
text_seq_len = 256,
text_heads = 8,
visual_enc_depth = 6,
visual_image_size = 256,
visual_patch_size = 32,
visual_heads = 8,
).cuda()
# setup prior network, which contains an autoregressive transformer
prior_network = DiffusionPriorNetwork(
dim = 512,
depth = 6,
dim_head = 64,
heads = 8
).cuda()
# diffusion prior, which contains the CLIP and the prior network (with transformer) above
diffusion_prior = DiffusionPrior(
net = prior_network,
clip = clip,
timesteps = 100,
cond_drop_prob = 0.2,
condition_on_text_encodings = False # this probably should be true, but just to get Laion started
).cuda()
# mock data
text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
# precompute the text and image embeddings
# here using the diffusion prior class, but could be done with CLIP alone
clip_image_embeds = diffusion_prior.get_image_embed(images)
clip_text_embeds = diffusion_prior.get_text_cond(text).get('text_embed')
# feed text and images into diffusion prior network
loss = diffusion_prior(
text_embed = clip_text_embeds,
image_embed = clip_image_embeds
)
loss.backward()
# do the above for many many many steps
# now the diffusion prior can generate image embeddings from the text embeddings
```
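The example above leaves `condition_on_text_encodings = False`, so only the pooled `text_embed` is used for conditioning. If the prior is instead instantiated with `condition_on_text_encodings = True`, the per-token CLIP text encodings and their mask can also be precomputed and passed in. A minimal sketch under that assumption, using mock tensors in place of real CLIP outputs (encodings assumed to be of shape `(batch, text_seq_len, dim_text)`, mask a boolean of shape `(batch, text_seq_len)`):

```python
# sketch only - assumes the DiffusionPrior above was built with condition_on_text_encodings = True

# mock precomputed per-token text encodings and mask (stand-ins for real CLIP outputs)
clip_text_encodings = torch.randn(4, 256, 512).cuda()          # (batch, text_seq_len, dim_text)
clip_text_mask = torch.ones(4, 256, dtype = torch.bool).cuda() # True where a real token is present

loss = diffusion_prior(
    text_embed = clip_text_embeds,
    image_embed = clip_image_embeds,
    text_encodings = clip_text_encodings,
    text_mask = clip_text_mask
)
loss.backward()
```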
## Experimental
### DALL-E2 with Latent Diffusion

dalle2_pytorch/dalle2_pytorch.py

@@ -692,13 +692,41 @@ class DiffusionPrior(nn.Module):
         top_image_embeds = image_embeds.gather(1, top_sim_indices)
         return rearrange(top_image_embeds, 'b 1 d -> b d')
 
-    def forward(self, text, image, *args, **kwargs):
-        b, device, img_size, = image.shape[0], image.device, self.image_size
-        check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)
-
-        times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
-        image_embed = self.get_image_embed(image)
-        text_cond = self.get_text_cond(text)
+    def forward(
+        self,
+        text = None,
+        image = None,
+        text_embed = None,      # allow for training on preprocessed CLIP text and image embeddings
+        image_embed = None,
+        text_encodings = None,  # as well as CLIP text encodings
+        text_mask = None,       # text mask <- may eventually opt for the learned padding tokens technique from DALL-E1 to reduce complexity
+        *args,
+        **kwargs
+    ):
+        assert exists(text) ^ exists(text_embed), 'either text or text embedding must be supplied'
+        assert exists(image) ^ exists(image_embed), 'either image or image embedding must be supplied'
+        assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization'
+
+        if exists(image):
+            image_embed = self.get_image_embed(image)
+
+        # calculate text conditionings, based on what is passed in
+
+        if exists(text):
+            text_cond = self.get_text_cond(text)
+        else:
+            text_cond = dict(
+                text_embed = text_embed,
+                text_encodings = text_encodings,
+                mask = text_mask
+            )
+
+        # timestep conditioning from ddpm
+
+        batch, device = image_embed.shape[0], image_embed.device
+        times = torch.randint(0, self.num_timesteps, (batch,), device = device, dtype = torch.long)
+
+        # calculate forward loss
 
         loss = self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)
         return loss
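With the reworked `forward`, CLIP itself only has to run during preprocessing. A minimal sketch of that workflow, reusing `diffusion_prior`, `images`, and `text` from the README example above (the cache file name `prior_embeds.pt` is illustrative, not part of the library):

```python
import torch

# one-off preprocessing pass: embed the dataset with CLIP and cache the embeddings
with torch.no_grad():
    image_embeds = diffusion_prior.get_image_embed(images)
    text_embeds = diffusion_prior.get_text_cond(text).get('text_embed')

torch.save({'image_embed': image_embeds.cpu(), 'text_embed': text_embeds.cpu()}, 'prior_embeds.pt')

# later, during prior training, no CLIP forward pass is needed
cached = torch.load('prior_embeds.pt')

loss = diffusion_prior(
    text_embed = cached['text_embed'].cuda(),
    image_embed = cached['image_embed'].cuda()
)
loss.backward()
```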

setup.py

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.47',
+  version = '0.0.48',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',