Compare commits


9 Commits

Author SHA1 Message Date
Phil Wang
f19c99ecb0 fix decoder needing separate conditional dropping probabilities for image embeddings and text encodings, thanks to @xiankgx ! 2022-04-30 08:48:05 -07:00
Phil Wang
721a444686 Merge pull request #37 from ProGamerGov/patch-1
Fix spelling and grammatical errors
2022-04-30 08:19:07 -07:00
ProGamerGov
63450b466d Fix spelling and grammatical errors 2022-04-30 09:18:13 -06:00
Phil Wang
20e7eb5a9b cleanup 2022-04-30 07:22:57 -07:00
Phil Wang
e2f9615afa use @clip-anytorch , thanks to @rom1504 2022-04-30 06:40:54 -07:00
Phil Wang
0d1c07c803 fix a bug with classifier free guidance, thanks to @xiankgx again! 2022-04-30 06:34:57 -07:00
Phil Wang
a389f81138 todo 2022-04-29 15:40:51 -07:00
Phil Wang
0283556608 fix example in readme, since api changed 2022-04-29 13:40:55 -07:00
Phil Wang
5063d192b6 now completely OpenAI CLIP compatible for training
just take care of the logic for AdamW and transformers

used namedtuples for clip adapter embedding outputs
2022-04-29 13:05:01 -07:00
4 changed files with 48 additions and 122 deletions

View File

@@ -47,7 +47,7 @@ clip = CLIP(
use_all_token_embeds = True, # whether to use fine-grained contrastive learning (FILIP)
decoupled_contrastive_learning = True, # use decoupled contrastive learning (DCL) objective function, removing positive pairs from the denominator of the InfoNCE loss (CLOOB + DCL)
extra_latent_projection = True, # whether to use separate projections for text-to-image vs image-to-text comparisons (CLOOB)
-use_visual_ssl = True, # whether to do self supervised learning on iages
+use_visual_ssl = True, # whether to do self supervised learning on images
visual_ssl_type = 'simclr', # can be either 'simclr' or 'simsiam', depending on using DeCLIP or SLIP
use_mlm = False, # use masked language learning (MLM) on text (DeCLIP)
text_ssl_loss_weight = 0.05, # weight for text MLM loss
@@ -110,7 +110,8 @@ decoder = Decoder(
unet = unet,
clip = clip,
timesteps = 100,
-cond_drop_prob = 0.2
+image_cond_drop_prob = 0.1,
+text_cond_drop_prob = 0.5
).cuda()
# mock images (get a lot of this)
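
This hunk (and the similar ones below) replaces the single `cond_drop_prob` argument with separate `image_cond_drop_prob` and `text_cond_drop_prob` arguments, so the image embedding and the text encodings can be dropped independently during training. A minimal sketch of the idea in plain torch, with hypothetical names rather than the library's internals:

```python
import torch

# sketch only: per-modality conditional dropout as two independent Bernoulli
# "keep" masks; names and shapes are illustrative, not the repository's code
def make_keep_masks(batch_size, image_cond_drop_prob = 0.1, text_cond_drop_prob = 0.5):
    image_keep = torch.rand(batch_size) >= image_cond_drop_prob  # keep image conditioning ~90% of the time
    text_keep  = torch.rand(batch_size) >= text_cond_drop_prob   # keep text conditioning ~50% of the time
    return image_keep, text_keep

image_keep, text_keep = make_keep_masks(4)
print(image_keep, text_keep)  # one boolean per sample and per modality
```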
@@ -229,7 +230,8 @@ decoder = Decoder(
unet = (unet1, unet2), # insert both unets in order of low resolution to highest resolution (you can have as many stages as you want here)
image_sizes = (256, 512), # resolutions, 256 for first unet, 512 for second. these must be unique and in ascending order (matches with the unets passed in)
timesteps = 1000,
-cond_drop_prob = 0.2
+image_cond_drop_prob = 0.1,
+text_cond_drop_prob = 0.5
).cuda()
# mock images (get a lot of this)
@@ -348,7 +350,8 @@ decoder = Decoder(
image_sizes = (128, 256),
clip = clip,
timesteps = 100,
-cond_drop_prob = 0.2,
+image_cond_drop_prob = 0.1,
+text_cond_drop_prob = 0.5,
condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
).cuda()
@@ -430,8 +433,8 @@ images = torch.randn(4, 3, 256, 256).cuda()
# precompute the text and image embeddings
# here using the diffusion prior class, but could be done with CLIP alone
-clip_image_embeds = diffusion_prior.get_image_embed(images)
-clip_text_embeds = diffusion_prior.get_text_cond(text).get('text_embed')
+clip_image_embeds = diffusion_prior.clip.embed_image(images).image_embed
+clip_text_embeds = diffusion_prior.clip.embed_text(text).text_embed
# feed text and images into diffusion prior network
@@ -499,9 +502,7 @@ loss.backward()
Although there is the possibility they are using an unreleased, more powerful CLIP, you can use one of the released ones, if you do not wish to train your own CLIP from scratch. This will also allow the community to more quickly validate the conclusions of the paper.
-First you'll need to install <a href="https://github.com/openai/CLIP#usage">the prerequisites</a>
-Then to use a pretrained OpenAI CLIP, simply import `OpenAIClipAdapter` and pass it into the `DiffusionPrior` or `Decoder` like so
+To use a pretrained OpenAI CLIP, simply import `OpenAIClipAdapter` and pass it into the `DiffusionPrior` or `Decoder` like so
```python
import torch
@@ -560,7 +561,8 @@ decoder = Decoder(
image_sizes = (128, 256),
clip = clip,
timesteps = 100,
-cond_drop_prob = 0.2,
+image_cond_drop_prob = 0.1,
+text_cond_drop_prob = 0.5,
condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
).cuda()
@@ -618,7 +620,7 @@ clip = CLIP(
# 3 unets for the decoder (a la cascading DDPM)
# first two unets are doing latent diffusion
-# vqgan-vae must be trained before hand
+# vqgan-vae must be trained beforehand
vae1 = VQGanVAE(
dim = 32,
@@ -671,7 +673,8 @@ decoder = Decoder(
unet = (unet1, unet2, unet3), # insert unets in order of low resolution to highest resolution (you can have as many stages as you want here)
image_sizes = (256, 512, 1024), # resolutions, 256 for first unet, 512 for second, 1024 for third
timesteps = 100,
-cond_drop_prob = 0.2
+image_cond_drop_prob = 0.1,
+text_cond_drop_prob = 0.5
).cuda()
# mock images (get a lot of this)
@@ -741,6 +744,7 @@ Once built, images will be saved to the same directory the command is invoked
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
- [ ] just take care of the training for the decoder in a wrapper class, as each unet in the cascade will need its own optimizer
- [ ] train on a toy task, offer in colab
- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference

View File

@@ -3,6 +3,7 @@ from tqdm import tqdm
from inspect import isfunction
from functools import partial
from contextlib import contextmanager
+from collections import namedtuple
import torch
import torch.nn.functional as F
@@ -102,6 +103,9 @@ def unnormalize_img(normed_img):
# clip related adapters
+EmbeddedText = namedtuple('EmbedTextReturn', ['text_embed', 'text_encodings', 'text_mask'])
+EmbeddedImage = namedtuple('EmbedImageReturn', ['image_embed', 'image_encodings'])
class BaseClipAdapter(nn.Module):
def __init__(self, clip):
super().__init__()
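
The two named tuples added above change the clip adapters' return type from a plain tuple to a `namedtuple`, which is what lets the README call `diffusion_prior.clip.embed_image(images).image_embed`. A self-contained sketch of the pattern with a stub embedder (the stub and its shapes are made up for illustration):

```python
from collections import namedtuple
import torch

EmbeddedImage = namedtuple('EmbedImageReturn', ['image_embed', 'image_encodings'])

def embed_image_stub(images):
    # stand-in for an adapter's embed_image; the 512-dim latent is arbitrary
    return EmbeddedImage(torch.randn(images.shape[0], 512), None)

out = embed_image_stub(torch.randn(4, 3, 256, 256))
print(out.image_embed.shape)        # fields can be accessed by name ...
image_embed, image_encodings = out  # ... while tuple unpacking still works
```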
@@ -153,7 +157,7 @@ class XClipAdapter(BaseClipAdapter):
encoder_output = self.clip.text_transformer(text)
text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:]
text_embed = self.clip.to_text_latent(text_cls)
-return l2norm(text_embed), text_encodings, text_mask
+return EmbeddedText(l2norm(text_embed), text_encodings, text_mask)
@torch.no_grad()
def embed_image(self, image):
@@ -161,24 +165,20 @@ class XClipAdapter(BaseClipAdapter):
encoder_output = self.clip.visual_transformer(image)
image_cls, image_encodings = encoder_output[:, 0], encoder_output[:, 1:]
image_embed = self.clip.to_visual_latent(image_cls)
-return l2norm(image_embed), image_encodings
+return EmbeddedImage(l2norm(image_embed), image_encodings)
class OpenAIClipAdapter(BaseClipAdapter):
def __init__(
self,
name = 'ViT-B/32'
):
-try:
-import clip
-except ImportError:
-print('you must install openai clip in order to use this adapter - `pip install git+https://github.com/openai/CLIP.git` - more instructions at https://github.com/openai/CLIP#usage')
-openai_clip, _ = clip.load(name)
+import clip
+openai_clip, preprocess = clip.load(name)
super().__init__(openai_clip)
text_attention_final = self.find_layer('ln_final')
self.handle = text_attention_final.register_forward_hook(self._hook)
-self.clip_normalize = T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+self.clip_normalize = preprocess.transforms[-1]
self.cleared = False
def find_layer(self, layer):
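
`OpenAIClipAdapter.__init__` above now imports `clip` directly (declared as the `clip-anytorch` dependency in setup.py below) and takes the normalization transform from CLIP's own `preprocess` pipeline; the per-token text encodings are captured with a forward hook on `ln_final`. A generic, self-contained sketch of that forward-hook capture pattern (plain torch, not the adapter itself):

```python
import torch
import torch.nn as nn

# capture an intermediate layer's output with a forward hook, the same
# mechanism the adapter uses on CLIP's `ln_final`
model = nn.Sequential(nn.Linear(16, 32), nn.LayerNorm(32), nn.Linear(32, 8))

captured = {}
def _hook(module, inputs, output):
    captured['encodings'] = output

handle = model[1].register_forward_hook(_hook)  # hook the LayerNorm

with torch.no_grad():
    model(torch.randn(2, 16))

print(captured['encodings'].shape)  # (2, 32): grabbed during the forward pass
handle.remove()                     # detach the hook once it is no longer needed
```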
@@ -219,7 +219,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
text_embed = self.clip.encode_text(text)
text_encodings = self.text_encodings
del self.text_encodings
-return text_embed.float(), text_encodings.float(), text_mask
+return EmbeddedText(text_embed.float(), text_encodings.float(), text_mask)
@torch.no_grad()
def embed_image(self, image):
@@ -227,7 +227,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
image = resize_image_to(image, self.image_size)
image = self.clip_normalize(unnormalize_img(image))
image_embed = self.clip.encode_image(image)
-return image_embed.float(), None
+return EmbeddedImage(image_embed.float(), None)
# classifier free guidance functions
@@ -684,14 +684,14 @@ class DiffusionPriorNetwork(nn.Module):
# classifier free guidance
-cond_prob_mask = prob_mask_like((batch,), cond_drop_prob, device = device)
-cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1')
+keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = device)
+keep_mask = rearrange(keep_mask, 'b -> b 1')
-mask &= cond_prob_mask
+mask &= keep_mask
# whether text embedding is masked or not depends on the classifier free guidance conditional masking
-mask = torch.cat((mask, cond_prob_mask), dim = 1)
+mask = torch.cat((mask, keep_mask), dim = 1)
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
# but let's just do it right
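
This hunk is the classifier-free guidance fix mentioned in commit 0d1c07c803: assuming `prob_mask_like` returns a mask that is True with the given probability, the old code built the mask from `cond_drop_prob` and AND-ed it into the attention mask, so conditioning was kept only `cond_drop_prob` of the time; sampling with `1 - cond_drop_prob` yields a proper keep mask. A sketch with a plausible (not necessarily exact) definition of the helper:

```python
import torch

def prob_mask_like(shape, prob, device = None):
    # plausible helper: boolean mask that is True with probability `prob`
    return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob

cond_drop_prob = 0.2
buggy_keep = prob_mask_like((10000,), cond_drop_prob)      # True ~20% of the time: conditioning mostly dropped
keep_mask  = prob_mask_like((10000,), 1 - cond_drop_prob)  # True ~80% of the time: dropped with prob 0.2 as intended
print(buggy_keep.float().mean().item(), keep_mask.float().mean().item())
```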
@@ -1174,7 +1174,7 @@ class Unet(nn.Module):
if cond_scale == 1:
return logits
-null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
+null_logits = self.forward(*args, text_cond_drop_prob = 1., image_cond_drop_prob = 1., **kwargs)
return null_logits + (logits - null_logits) * cond_scale
def forward(
@@ -1185,7 +1185,8 @@ class Unet(nn.Module):
image_embed,
lowres_cond_img = None,
text_encodings = None,
-cond_drop_prob = 0.,
+image_cond_drop_prob = 0.,
+text_cond_drop_prob = 0.,
blur_sigma = None,
blur_kernel_size = None
):
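
With the split dropout probabilities, `forward_with_cond_scale` above drops both the image and the text conditioning for its unconditional pass, then extrapolates between the two predictions. A toy illustration of that combination step with stand-in tensors (not the unet itself):

```python
import torch

cond_scale = 3.
logits = torch.randn(2, 3, 64, 64)       # stand-in for the conditional prediction
null_logits = torch.randn(2, 3, 64, 64)  # stand-in for the fully unconditional prediction

guided = null_logits + (logits - null_logits) * cond_scale

# cond_scale = 1 recovers the conditional prediction (hence the early return above)
assert torch.allclose(null_logits + (logits - null_logits) * 1., logits)
```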
@@ -1204,8 +1205,10 @@ class Unet(nn.Module):
# conditional dropout
-cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device)
-cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1 1')
+image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
+text_keep_mask = prob_mask_like((batch_size,), 1 - text_cond_drop_prob, device = device)
+image_keep_mask, text_keep_mask = rearrange_many((image_keep_mask, text_keep_mask), 'b -> b 1 1')
# mask out image embedding depending on condition dropout
# for classifier free guidance
@@ -1216,7 +1219,7 @@ class Unet(nn.Module):
image_tokens = self.image_to_cond(image_embed)
image_tokens = torch.where(
-cond_prob_mask,
+image_keep_mask,
image_tokens,
self.null_image_embed
)
@@ -1228,7 +1231,7 @@ class Unet(nn.Module):
if exists(text_encodings) and self.cond_on_text_encodings:
text_tokens = self.text_to_cond(text_encodings)
text_tokens = torch.where(
-cond_prob_mask,
+text_keep_mask,
text_tokens,
self.null_text_embed[:, :text_tokens.shape[1]]
)
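
The two `torch.where` calls above swap in the learned null embeddings wherever a sample's keep mask is False, which is how the independently dropped image and text conditioning is realized inside the unet. A shape-only sketch of that substitution (all names and shapes here are illustrative):

```python
import torch

batch, num_tokens, dim = 4, 7, 32
text_tokens = torch.randn(batch, num_tokens, dim)
null_text_embed = torch.zeros(1, num_tokens, dim)  # stand-in for a learned null embedding

# True = keep this sample's text conditioning; shaped for broadcasting over tokens and channels
text_keep_mask = (torch.rand(batch) >= 0.5).reshape(batch, 1, 1)

text_tokens = torch.where(text_keep_mask, text_tokens, null_text_embed)
print(text_tokens.shape)  # (4, 7, 32): dropped rows now hold the null embedding
```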
@@ -1318,7 +1321,8 @@ class Decoder(BaseGaussianDiffusion):
clip,
vae = tuple(),
timesteps = 1000,
-cond_drop_prob = 0.2,
+image_cond_drop_prob = 0.1,
+text_cond_drop_prob = 0.5,
loss_type = 'l1',
beta_schedule = 'cosine',
predict_x_start = False,
@@ -1402,7 +1406,8 @@ class Decoder(BaseGaussianDiffusion):
# classifier free guidance
-self.cond_drop_prob = cond_drop_prob
+self.image_cond_drop_prob = image_cond_drop_prob
+self.text_cond_drop_prob = text_cond_drop_prob
def get_unet(self, unet_number):
assert 0 < unet_number <= len(self.unets)
@@ -1484,7 +1489,8 @@ class Decoder(BaseGaussianDiffusion):
image_embed = image_embed,
text_encodings = text_encodings,
lowres_cond_img = lowres_cond_img,
-cond_drop_prob = self.cond_drop_prob
+image_cond_drop_prob = self.image_cond_drop_prob,
+text_cond_drop_prob = self.text_cond_drop_prob,
)
target = noise if not predict_x_start else x_start
@@ -1632,4 +1638,3 @@ class DALLE2(nn.Module):
return images[0]
return images

View File

@@ -1,84 +0,0 @@
import torch
from PIL import Image
from dalle2_pytorch.dalle2_pytorch import BaseClipAdapter
import torchvision.transforms as T
def find_layer(model, layer):
modules = dict([*model.named_modules()])
return modules.get(layer, None)
def hook(_, input, output):
print(output.shape)
import clip
# image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).cuda()
image = torch.randn(1, 3, 224, 224).cuda()
class OpenAIClipAdapter(BaseClipAdapter):
def __init__(self, name = 'ViT-B/32'):
try:
import clip
except ImportError:
print('you must install openai clip in order to use this adapter - `pip install git+https://github.com/openai/CLIP.git` - more instructions at https://github.com/openai/CLIP#usage')
openai_clip, _ = clip.load(name)
super().__init__(openai_clip)
text_attention_final = self.find_layer(self.clip, 'ln_final')
self.handle = text_attention_final.register_forward_hook(self._hook)
self.clip_normalize = T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
self.cleared = False
def find_layer(self, layer):
modules = dict([*self.clip.named_modules()])
return modules.get(layer, None)
def clear(self):
if self.cleared:
return
self.handle()
def _hook(self, _, inputs, outputs):
self.text_encodings = outputs
@property
def dim_latent(self):
return 512
@property
def image_size(self):
return self.clip.visual.input_resolution
@property
def image_channels(self):
return 3
@torch.no_grad()
def embed_text(self, text):
assert not self.cleared
text_embed = self.clip.encode_text(text)
text_encodings = self.text_encodings
del self.text_encodings
return text_embed, text_encodings
@torch.no_grad()
def embed_image(self, image):
assert not self.cleared
image = self.clip_normalize(image)
image_embed = self.clip.encode_image(image)
return image_embed, None
clip_adapter = OpenAIClipAdapter().cuda()
# print(model)
with torch.no_grad():
image_features, _ = clip_adapter.embed_image(image)
text_features, text_encodings = clip_adapter.embed_text(text)
print(text_features.shape, image_features.shape)
print(text_encodings.shape)

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
-version = '0.0.70',
+version = '0.0.74',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
@@ -23,6 +23,7 @@ setup(
],
install_requires=[
'click',
+'clip-anytorch',
'einops>=0.4',
'einops-exts>=0.0.3',
'kornia>=0.5.4',