now completely OpenAI CLIP compatible for training

This commit is contained in:
Phil Wang
2022-04-29 11:26:24 -07:00
parent f4a54e475e
commit 39d3659ad9
5 changed files with 187 additions and 23 deletions

View File

@@ -495,6 +495,96 @@ loss.backward()
# now the diffusion prior can generate image embeddings from the text embeddings # now the diffusion prior can generate image embeddings from the text embeddings
``` ```
## OpenAI CLIP
Although there is the possibility they are using an unreleased, more powerful CLIP, you can use one of the released ones, if you do not wish to train your own CLIP from scratch. This will also allow the community to more quickly validate the conclusions of the paper.
First you'll need to install <a href="https://github.com/openai/CLIP#usage">the prerequisites</a>
Then to use a pretrained OpenAI CLIP, simply import `OpenAIClipAdapter` and pass it into the `DiffusionPrior` or `Decoder` like so
```python
import torch
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, OpenAIClipAdapter
# openai pretrained clip - defaults to ViT/B-32
clip = OpenAIClipAdapter()
# mock data
text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
# prior networks (with transformer)
prior_network = DiffusionPriorNetwork(
dim = 512,
depth = 6,
dim_head = 64,
heads = 8
).cuda()
diffusion_prior = DiffusionPrior(
net = prior_network,
clip = clip,
timesteps = 100,
cond_drop_prob = 0.2
).cuda()
loss = diffusion_prior(text, images)
loss.backward()
# do above for many steps ...
# decoder (with unet)
unet1 = Unet(
dim = 128,
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults=(1, 2, 4, 8)
).cuda()
unet2 = Unet(
dim = 16,
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults = (1, 2, 4, 8, 16)
).cuda()
decoder = Decoder(
unet = (unet1, unet2),
image_sizes = (128, 256),
clip = clip,
timesteps = 100,
cond_drop_prob = 0.2,
condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
).cuda()
for unet_number in (1, 2):
loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
loss.backward()
# do above for many steps
dalle2 = DALLE2(
prior = diffusion_prior,
decoder = decoder
)
images = dalle2(
['a butterfly trying to escape a tornado'],
cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)
)
# save your image (in this example, of size 256x256)
```
Now you'll just have to worry about training the Prior and the Decoder!
## Experimental ## Experimental
### DALL-E2 with Latent Diffusion ### DALL-E2 with Latent Diffusion

View File

@@ -1,4 +1,5 @@
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
from dalle2_pytorch.vqgan_vae import VQGanVAE from dalle2_pytorch.vqgan_vae import VQGanVAE
from x_clip import CLIP from x_clip import CLIP

View File

@@ -90,6 +90,16 @@ def resize_image_to(t, image_size, mode = 'bilinear'): # take a look at https://
return F.interpolate(t, size = shape, mode = mode, align_corners = False) return F.interpolate(t, size = shape, mode = mode, align_corners = False)
# image normalization functions
# ddpms expect images to be in the range of -1 to 1
# but CLIP may otherwise
def normalize_img(img):
return img * 2 - 1
def unnormalize_img(normed_img):
return (normed_img + 1) * 0.5
# clip related adapters # clip related adapters
class BaseClipAdapter(nn.Module): class BaseClipAdapter(nn.Module):
@@ -109,6 +119,10 @@ class BaseClipAdapter(nn.Module):
def image_channels(self): def image_channels(self):
raise NotImplementedError raise NotImplementedError
@property
def max_text_len(self):
raise NotImplementedError
def embed_text(self, text): def embed_text(self, text):
raise NotImplementedError raise NotImplementedError
@@ -128,12 +142,18 @@ class XClipAdapter(BaseClipAdapter):
def image_channels(self): def image_channels(self):
return self.clip.image_channels return self.clip.image_channels
@property
def max_text_len(self):
return self.clip.text_seq_len
@torch.no_grad() @torch.no_grad()
def embed_text(self, text): def embed_text(self, text):
text = text[..., :self.max_text_len]
text_mask = text != 0
encoder_output = self.clip.text_transformer(text) encoder_output = self.clip.text_transformer(text)
text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:] text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:]
text_embed = self.clip.to_text_latent(text_cls) text_embed = self.clip.to_text_latent(text_cls)
return l2norm(text_embed), text_encodings return l2norm(text_embed), text_encodings, text_mask
@torch.no_grad() @torch.no_grad()
def embed_image(self, image): def embed_image(self, image):
@@ -143,6 +163,72 @@ class XClipAdapter(BaseClipAdapter):
image_embed = self.clip.to_visual_latent(image_cls) image_embed = self.clip.to_visual_latent(image_cls)
return l2norm(image_embed), image_encodings return l2norm(image_embed), image_encodings
class OpenAIClipAdapter(BaseClipAdapter):
def __init__(
self,
name = 'ViT-B/32'
):
try:
import clip
except ImportError:
print('you must install openai clip in order to use this adapter - `pip install git+https://github.com/openai/CLIP.git` - more instructions at https://github.com/openai/CLIP#usage')
openai_clip, _ = clip.load(name)
super().__init__(openai_clip)
text_attention_final = self.find_layer('ln_final')
self.handle = text_attention_final.register_forward_hook(self._hook)
self.clip_normalize = T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
self.cleared = False
def find_layer(self, layer):
modules = dict([*self.clip.named_modules()])
return modules.get(layer, None)
def clear(self):
if self.cleared:
return
self.handle()
def _hook(self, _, inputs, outputs):
self.text_encodings = outputs
@property
def dim_latent(self):
return 512
@property
def image_size(self):
return self.clip.visual.input_resolution
@property
def image_channels(self):
return 3
@property
def max_text_len(self):
return self.clip.context_length
@torch.no_grad()
def embed_text(self, text):
text = text[..., :self.max_text_len]
text_mask = text != 0
assert not self.cleared
text_embed = self.clip.encode_text(text)
text_encodings = self.text_encodings
del self.text_encodings
return text_embed.float(), text_encodings.float(), text_mask
@torch.no_grad()
def embed_image(self, image):
assert not self.cleared
image = resize_image_to(image, self.image_size)
image = self.clip_normalize(unnormalize_img(image))
image_embed = self.clip.encode_image(image)
return image_embed.float(), None
# classifier free guidance functions # classifier free guidance functions
def prob_mask_like(shape, prob, device): def prob_mask_like(shape, prob, device):
@@ -741,12 +827,12 @@ class DiffusionPrior(BaseGaussianDiffusion):
batch_size = text.shape[0] batch_size = text.shape[0]
image_embed_dim = self.image_embed_dim image_embed_dim = self.image_embed_dim
text_embed, text_encodings = self.clip.embed_text(text) text_embed, text_encodings, text_mask = self.clip.embed_text(text)
text_cond = dict(text_embed = text_embed) text_cond = dict(text_embed = text_embed)
if self.condition_on_text_encodings: if self.condition_on_text_encodings:
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text != 0} text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond) image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
text_embeds = text_cond['text_embed'] text_embeds = text_cond['text_embed']
@@ -783,8 +869,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
# calculate text conditionings, based on what is passed in # calculate text conditionings, based on what is passed in
if exists(text): if exists(text):
text_embed, text_encodings = self.clip.embed_text(text) text_embed, text_encodings, text_mask = self.clip.embed_text(text)
text_mask = text != 0
text_cond = dict(text_embed = text_embed) text_cond = dict(text_embed = text_embed)
@@ -1341,11 +1426,8 @@ class Decoder(BaseGaussianDiffusion):
@torch.no_grad() @torch.no_grad()
def get_image_embed(self, image): def get_image_embed(self, image):
image = resize_image_to(image, self.clip_image_size) image_embed, _ = self.clip.embed_image(image)
image_encoding = self.clip.visual_transformer(image) return image_embed
image_cls = image_encoding[:, 0]
image_embed = self.clip.to_visual_latent(image_cls)
return l2norm(image_embed)
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.): def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img) pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
@@ -1417,7 +1499,7 @@ class Decoder(BaseGaussianDiffusion):
text_encodings = None text_encodings = None
if exists(text): if exists(text):
_, text_encodings = self.clip.embed_text(text) _, text_encodings, _ = self.clip.embed_text(text)
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified' assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented' assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
@@ -1485,7 +1567,7 @@ class Decoder(BaseGaussianDiffusion):
text_encodings = None text_encodings = None
if exists(text) and not exists(text_encodings): if exists(text) and not exists(text_encodings):
_, text_encodings = self.clip.embed_text(text) _, text_encodings, _ = self.clip.embed_text(text)
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified' assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented' assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'

View File

@@ -2,15 +2,6 @@ import copy
import torch import torch
from torch import nn from torch import nn
# image related normalizations
# ddpms expect images to be in the range of -1 to 1
def normalize_img(img):
return img * 2 - 1
def unnormalize_img(normed_img):
return (normed_img + 1) * 0.5
# exponential moving average wrapper # exponential moving average wrapper
class EMA(nn.Module): class EMA(nn.Module):

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream' 'dream = dalle2_pytorch.cli:dream'
], ],
}, },
version = '0.0.65', version = '0.0.67',
license='MIT', license='MIT',
description = 'DALL-E 2', description = 'DALL-E 2',
author = 'Phil Wang', author = 'Phil Wang',
@@ -31,7 +31,7 @@ setup(
'torchvision', 'torchvision',
'tqdm', 'tqdm',
'vector-quantize-pytorch', 'vector-quantize-pytorch',
'x-clip>=0.4.4', 'x-clip>=0.5.1',
'youtokentome' 'youtokentome'
], ],
classifiers=[ classifiers=[