mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
allow for training the Prior network with precomputed CLIP embeddings (or text encodings)
This commit is contained in:
69
README.md
69
README.md
@@ -376,6 +376,75 @@ You can also train the decoder on images of greater than the size (say 512x512)
|
|||||||
|
|
||||||
For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.
|
For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.
|
||||||
|
|
||||||
|
## Training on Preprocessed CLIP Embeddings
|
||||||
|
|
||||||
|
It is likely, when scaling up, that you would first preprocess your images and text into corresponding embeddings before training the prior network. You can do so easily by simply passing in `image_embed`, `text_embed`, and optionally `text_encodings` and `text_mask`
|
||||||
|
|
||||||
|
Working example below
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, CLIP
|
||||||
|
|
||||||
|
# get trained CLIP from step one
|
||||||
|
|
||||||
|
clip = CLIP(
|
||||||
|
dim_text = 512,
|
||||||
|
dim_image = 512,
|
||||||
|
dim_latent = 512,
|
||||||
|
num_text_tokens = 49408,
|
||||||
|
text_enc_depth = 6,
|
||||||
|
text_seq_len = 256,
|
||||||
|
text_heads = 8,
|
||||||
|
visual_enc_depth = 6,
|
||||||
|
visual_image_size = 256,
|
||||||
|
visual_patch_size = 32,
|
||||||
|
visual_heads = 8,
|
||||||
|
).cuda()
|
||||||
|
|
||||||
|
# setup prior network, which contains an autoregressive transformer
|
||||||
|
|
||||||
|
prior_network = DiffusionPriorNetwork(
|
||||||
|
dim = 512,
|
||||||
|
depth = 6,
|
||||||
|
dim_head = 64,
|
||||||
|
heads = 8
|
||||||
|
).cuda()
|
||||||
|
|
||||||
|
# diffusion prior network, which contains the CLIP and network (with transformer) above
|
||||||
|
|
||||||
|
diffusion_prior = DiffusionPrior(
|
||||||
|
net = prior_network,
|
||||||
|
clip = clip,
|
||||||
|
timesteps = 100,
|
||||||
|
cond_drop_prob = 0.2,
|
||||||
|
condition_on_text_encodings = False # this probably should be true, but just to get Laion started
|
||||||
|
).cuda()
|
||||||
|
|
||||||
|
# mock data
|
||||||
|
|
||||||
|
text = torch.randint(0, 49408, (4, 256)).cuda()
|
||||||
|
images = torch.randn(4, 3, 256, 256).cuda()
|
||||||
|
|
||||||
|
# precompute the text and image embeddings
|
||||||
|
# here using the diffusion prior class, but could be done with CLIP alone
|
||||||
|
|
||||||
|
clip_image_embeds = diffusion_prior.get_image_embed(images)
|
||||||
|
clip_text_embeds = diffusion_prior.get_text_cond(text).get('text_embed')
|
||||||
|
|
||||||
|
# feed text and images into diffusion prior network
|
||||||
|
|
||||||
|
loss = diffusion_prior(
|
||||||
|
text_embed = clip_text_embeds,
|
||||||
|
image_embed = clip_image_embeds
|
||||||
|
)
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
# do the above for many many many steps
|
||||||
|
# now the diffusion prior can generate image embeddings from the text embeddings
|
||||||
|
```
|
||||||
|
|
||||||
## Experimental
|
## Experimental
|
||||||
|
|
||||||
### DALL-E2 with Latent Diffusion
|
### DALL-E2 with Latent Diffusion
|
||||||
|
|||||||
@@ -692,13 +692,41 @@ class DiffusionPrior(nn.Module):
|
|||||||
top_image_embeds = image_embeds.gather(1, top_sim_indices)
|
top_image_embeds = image_embeds.gather(1, top_sim_indices)
|
||||||
return rearrange(top_image_embeds, 'b 1 d -> b d')
|
return rearrange(top_image_embeds, 'b 1 d -> b d')
|
||||||
|
|
||||||
def forward(self, text, image, *args, **kwargs):
|
def forward(
|
||||||
b, device, img_size, = image.shape[0], image.device, self.image_size
|
self,
|
||||||
check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)
|
text = None,
|
||||||
|
image = None,
|
||||||
|
text_embed = None, # allow for training on preprocessed CLIP text and image embeddings
|
||||||
|
image_embed = None,
|
||||||
|
text_encodings = None, # as well as CLIP text encodings
|
||||||
|
text_mask = None, # text mask <- may eventually opt for the learned padding tokens technique from DALL-E1 to reduce complexity
|
||||||
|
*args,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
assert exists(text) ^ exists(text_embed), 'either text or text embedding must be supplied'
|
||||||
|
assert exists(image) ^ exists(image_embed), 'either text or text embedding must be supplied'
|
||||||
|
assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization'
|
||||||
|
|
||||||
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
|
if exists(image):
|
||||||
image_embed = self.get_image_embed(image)
|
image_embed = self.get_image_embed(image)
|
||||||
|
|
||||||
|
# calculate text conditionings, based on what is passed in
|
||||||
|
|
||||||
|
if exists(text):
|
||||||
text_cond = self.get_text_cond(text)
|
text_cond = self.get_text_cond(text)
|
||||||
|
else:
|
||||||
|
text_cond = dict(
|
||||||
|
text_embed = text_embed,
|
||||||
|
text_encodings = text_encodings,
|
||||||
|
mask = text_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
# timestep conditioning from ddpm
|
||||||
|
|
||||||
|
batch, device = image_embed.shape[0], image_embed.device
|
||||||
|
times = torch.randint(0, self.num_timesteps, (batch,), device = device, dtype = torch.long)
|
||||||
|
|
||||||
|
# calculate forward loss
|
||||||
|
|
||||||
loss = self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)
|
loss = self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)
|
||||||
return loss
|
return loss
|
||||||
|
|||||||
Reference in New Issue
Block a user