no CLIP altogether for training DiffusionPrior

Phil Wang
2022-04-26 10:23:34 -07:00
parent bdf5e9c009
commit c30544b73a
3 changed files with 68 additions and 8 deletions

README.md

@@ -446,6 +446,55 @@ loss.backward()
 # now the diffusion prior can generate image embeddings from the text embeddings
 ```
+
+You can also go completely `CLIP`-less, in which case you will need to pass the `image_embed_dim` into the `DiffusionPrior` on initialization
+
+```python
+import torch
+from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior
+
+# setup prior network, which contains an autoregressive transformer
+
+prior_network = DiffusionPriorNetwork(
+    dim = 512,
+    depth = 6,
+    dim_head = 64,
+    heads = 8
+).cuda()
+
+# diffusion prior, which contains the prior network (with transformer) above - no CLIP this time
+
+diffusion_prior = DiffusionPrior(
+    net = prior_network,
+    image_embed_dim = 512,                # this needs to be set
+    timesteps = 100,
+    cond_drop_prob = 0.2,
+    condition_on_text_encodings = False   # this probably should be true, but just to get Laion started
+).cuda()
+
+# mock data
+
+text = torch.randint(0, 49408, (4, 256)).cuda()
+images = torch.randn(4, 3, 256, 256).cuda()
+
+# precompute the text and image embeddings
+# mocked here with random tensors, but in practice they would come from CLIP itself
+
+clip_image_embeds = torch.randn(4, 512).cuda()
+clip_text_embeds = torch.randn(4, 512).cuda()
+
+# feed the text and image embeddings into the diffusion prior
+
+loss = diffusion_prior(
+    text_embed = clip_text_embeds,
+    image_embed = clip_image_embeds
+)
+
+loss.backward()
+
+# do the above for many many many steps
+# now the diffusion prior can generate image embeddings from the text embeddings
+```
+
 ## Experimental
 ### DALL-E2 with Latent Diffusion
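
In the CLIP-less example above, `clip_image_embeds` and `clip_text_embeds` are mocked with random tensors. A minimal sketch of precomputing real embeddings with OpenAI's `clip` package follows (an assumption for illustration only; any encoder that emits 512-dimensional embeddings matching `image_embed_dim` works, and the image path is hypothetical):

```python
import torch
import clip                      # assumes OpenAI's CLIP package is installed
from PIL import Image

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# ViT-B/32 produces 512-dim embeddings, matching image_embed_dim = 512 above
model, preprocess = clip.load('ViT-B/32', device = device)

tokens = clip.tokenize(['a corgi splashing in a puddle']).to(device)
image = preprocess(Image.open('corgi.jpg')).unsqueeze(0).to(device)   # hypothetical image path

with torch.no_grad():
    clip_text_embeds = model.encode_text(tokens).float()    # (1, 512)
    clip_image_embeds = model.encode_image(image).float()   # (1, 512)
```

The resulting tensors can then be passed as `text_embed` / `image_embed` exactly as in the mocked example above.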

dalle2_pytorch/dalle2_pytorch.py

@@ -486,7 +486,10 @@ class DiffusionPrior(nn.Module):
         self,
         net,
         *,
-        clip,
+        clip = None,
+        image_embed_dim = None,
+        image_size = None,
+        image_channels = 3,
         timesteps = 1000,
         cond_drop_prob = 0.2,
         loss_type = "l1",
@@ -495,14 +498,18 @@ class DiffusionPrior(nn.Module):
         condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
     ):
         super().__init__()
-        assert isinstance(clip, CLIP)
-        freeze_model_and_make_eval_(clip)
-        self.clip = clip
+        if exists(clip):
+            assert isinstance(clip, CLIP)
+            freeze_model_and_make_eval_(clip)
+            self.clip = clip
+        else:
+            assert exists(image_embed_dim), 'latent dimension must be given, if training prior network without CLIP given'
+            self.clip = None
 
         self.net = net
-        self.image_embed_dim = clip.dim_latent
-        self.channels = clip.image_channels
-        self.image_size = clip.image_size
+        self.image_embed_dim = default(image_embed_dim, lambda: clip.dim_latent)
+        self.channels = default(image_channels, lambda: clip.image_channels)
 
         self.cond_drop_prob = cond_drop_prob
         self.condition_on_text_encodings = condition_on_text_encodings
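
The constructor changes above lean on the repository's small `exists` / `default` helpers. For reference, they behave roughly like this (a paraphrase, assumed to match the helpers already defined near the top of dalle2_pytorch.py, not code from this commit):

```python
def exists(val):
    # a value was explicitly provided
    return val is not None

def default(val, d):
    # use val when given, otherwise fall back to d, calling it if it is a callable (e.g. a lambda)
    return val if exists(val) else d() if callable(d) else d
```

The lambdas matter here: `default(image_embed_dim, lambda: clip.dim_latent)` only evaluates `clip.dim_latent` when `image_embed_dim` was not supplied, so the CLIP-less path never dereferences a `None` clip.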
@@ -559,6 +566,8 @@ class DiffusionPrior(nn.Module):
     @torch.no_grad()
     def get_image_embed(self, image):
+        assert exists(self.clip)
+
         image_encoding = self.clip.visual_transformer(image)
         image_cls = image_encoding[:, 0]
         image_embed = self.clip.to_visual_latent(image_cls)
@@ -566,6 +575,8 @@ class DiffusionPrior(nn.Module):
     @torch.no_grad()
     def get_text_cond(self, text):
+        assert exists(self.clip)
+
         text_encodings = self.clip.text_transformer(text)
         text_cls, text_encodings = text_encodings[:, 0], text_encodings[:, 1:]
         text_embed = self.clip.to_text_latent(text_cls)
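
The two new `assert exists(self.clip)` guards fail fast when a CLIP-less prior is asked to encode raw inputs. A condensed, illustrative sketch of the README example above, showing the call pattern that stays valid without CLIP (shapes mirror the mocked data):

```python
import torch
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior

# a CLIP-less prior, as in the README example above
prior_network = DiffusionPriorNetwork(dim = 512, depth = 6, dim_head = 64, heads = 8)

diffusion_prior = DiffusionPrior(
    net = prior_network,
    image_embed_dim = 512,
    timesteps = 100,
    condition_on_text_encodings = False
)

# precomputed embeddings train fine without a CLIP attached
loss = diffusion_prior(
    text_embed = torch.randn(4, 512),
    image_embed = torch.randn(4, 512)
)
loss.backward()

# raw text tokens / images would instead be routed through get_text_cond / get_image_embed,
# where assert exists(self.clip) now fails fast rather than erroring deep inside a missing CLIP
```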

setup.py

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.49',
+  version = '0.0.50',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',