mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
no CLIP altogether for training DiffusionPrior
This commit is contained in:
49
README.md
49
README.md
@@ -446,6 +446,55 @@ loss.backward()
|
|||||||
# now the diffusion prior can generate image embeddings from the text embeddings
|
# now the diffusion prior can generate image embeddings from the text embeddings
|
||||||
```
|
```
|
||||||
|
|
||||||
|
You can also completely go `CLIP`-less, in which case you will need to pass in the `image_embed_dim` into the `DiffusionPrior` on initialization
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior
|
||||||
|
|
||||||
|
# setup prior network, which contains an autoregressive transformer
|
||||||
|
|
||||||
|
prior_network = DiffusionPriorNetwork(
|
||||||
|
dim = 512,
|
||||||
|
depth = 6,
|
||||||
|
dim_head = 64,
|
||||||
|
heads = 8
|
||||||
|
).cuda()
|
||||||
|
|
||||||
|
# diffusion prior network, which contains the CLIP and network (with transformer) above
|
||||||
|
|
||||||
|
diffusion_prior = DiffusionPrior(
|
||||||
|
net = prior_network,
|
||||||
|
image_embed_dim = 512, # this needs to be set
|
||||||
|
timesteps = 100,
|
||||||
|
cond_drop_prob = 0.2,
|
||||||
|
condition_on_text_encodings = False # this probably should be true, but just to get Laion started
|
||||||
|
).cuda()
|
||||||
|
|
||||||
|
# mock data
|
||||||
|
|
||||||
|
text = torch.randint(0, 49408, (4, 256)).cuda()
|
||||||
|
images = torch.randn(4, 3, 256, 256).cuda()
|
||||||
|
|
||||||
|
# precompute the text and image embeddings
|
||||||
|
# here using the diffusion prior class, but could be done with CLIP alone
|
||||||
|
|
||||||
|
clip_image_embeds = torch.randn(4, 512).cuda()
|
||||||
|
clip_text_embeds = torch.randn(4, 512).cuda()
|
||||||
|
|
||||||
|
# feed text and images into diffusion prior network
|
||||||
|
|
||||||
|
loss = diffusion_prior(
|
||||||
|
text_embed = clip_text_embeds,
|
||||||
|
image_embed = clip_image_embeds
|
||||||
|
)
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
# do the above for many many many steps
|
||||||
|
# now the diffusion prior can generate image embeddings from the text embeddings
|
||||||
|
```
|
||||||
|
|
||||||
## Experimental
|
## Experimental
|
||||||
|
|
||||||
### DALL-E2 with Latent Diffusion
|
### DALL-E2 with Latent Diffusion
|
||||||
|
|||||||
@@ -486,7 +486,10 @@ class DiffusionPrior(nn.Module):
|
|||||||
self,
|
self,
|
||||||
net,
|
net,
|
||||||
*,
|
*,
|
||||||
clip,
|
clip = None,
|
||||||
|
image_embed_dim = None,
|
||||||
|
image_size = None,
|
||||||
|
image_channels = 3,
|
||||||
timesteps = 1000,
|
timesteps = 1000,
|
||||||
cond_drop_prob = 0.2,
|
cond_drop_prob = 0.2,
|
||||||
loss_type = "l1",
|
loss_type = "l1",
|
||||||
@@ -495,14 +498,18 @@ class DiffusionPrior(nn.Module):
|
|||||||
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
|
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
if exists(clip):
|
||||||
assert isinstance(clip, CLIP)
|
assert isinstance(clip, CLIP)
|
||||||
freeze_model_and_make_eval_(clip)
|
freeze_model_and_make_eval_(clip)
|
||||||
self.clip = clip
|
self.clip = clip
|
||||||
|
else:
|
||||||
|
assert exists(image_embed_dim), 'latent dimension must be given, if training prior network without CLIP given'
|
||||||
|
self.clip = None
|
||||||
|
|
||||||
self.net = net
|
self.net = net
|
||||||
self.image_embed_dim = clip.dim_latent
|
self.image_embed_dim = default(image_embed_dim, lambda: clip.dim_latent)
|
||||||
self.channels = clip.image_channels
|
self.channels = default(image_channels, lambda: clip.image_channels)
|
||||||
self.image_size = clip.image_size
|
|
||||||
|
|
||||||
self.cond_drop_prob = cond_drop_prob
|
self.cond_drop_prob = cond_drop_prob
|
||||||
self.condition_on_text_encodings = condition_on_text_encodings
|
self.condition_on_text_encodings = condition_on_text_encodings
|
||||||
@@ -559,6 +566,8 @@ class DiffusionPrior(nn.Module):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def get_image_embed(self, image):
|
def get_image_embed(self, image):
|
||||||
|
assert exists(self.clip)
|
||||||
|
|
||||||
image_encoding = self.clip.visual_transformer(image)
|
image_encoding = self.clip.visual_transformer(image)
|
||||||
image_cls = image_encoding[:, 0]
|
image_cls = image_encoding[:, 0]
|
||||||
image_embed = self.clip.to_visual_latent(image_cls)
|
image_embed = self.clip.to_visual_latent(image_cls)
|
||||||
@@ -566,6 +575,8 @@ class DiffusionPrior(nn.Module):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def get_text_cond(self, text):
|
def get_text_cond(self, text):
|
||||||
|
assert exists(self.clip)
|
||||||
|
|
||||||
text_encodings = self.clip.text_transformer(text)
|
text_encodings = self.clip.text_transformer(text)
|
||||||
text_cls, text_encodings = text_encodings[:, 0], text_encodings[:, 1:]
|
text_cls, text_encodings = text_encodings[:, 0], text_encodings[:, 1:]
|
||||||
text_embed = self.clip.to_text_latent(text_cls)
|
text_embed = self.clip.to_text_latent(text_cls)
|
||||||
|
|||||||
Reference in New Issue
Block a user