From 5063d192b61f71f3c48e8b3d2037599c2156d472 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Fri, 29 Apr 2022 13:05:01 -0700 Subject: [PATCH] now completely OpenAI CLIP compatible for training just take care of the logic for AdamW and transformers used namedtuples for clip adapter embedding outputs --- README.md | 90 +++++++++++++++++++++++++ dalle2_pytorch/__init__.py | 1 + dalle2_pytorch/dalle2_pytorch.py | 112 +++++++++++++++++++++++++++---- dalle2_pytorch/optimizer.py | 29 ++++++++ dalle2_pytorch/train.py | 9 --- setup.py | 4 +- 6 files changed, 221 insertions(+), 24 deletions(-) create mode 100644 dalle2_pytorch/optimizer.py diff --git a/README.md b/README.md index 277fef0..e4fee0b 100644 --- a/README.md +++ b/README.md @@ -495,6 +495,96 @@ loss.backward() # now the diffusion prior can generate image embeddings from the text embeddings ``` +## OpenAI CLIP + +Although there is the possibility they are using an unreleased, more powerful CLIP, you can use one of the released ones, if you do not wish to train your own CLIP from scratch. This will also allow the community to more quickly validate the conclusions of the paper. + +First you'll need to install the prerequisites + +Then to use a pretrained OpenAI CLIP, simply import `OpenAIClipAdapter` and pass it into the `DiffusionPrior` or `Decoder` like so + +```python +import torch +from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, OpenAIClipAdapter + +# openai pretrained clip - defaults to ViT/B-32 + +clip = OpenAIClipAdapter() + +# mock data + +text = torch.randint(0, 49408, (4, 256)).cuda() +images = torch.randn(4, 3, 256, 256).cuda() + +# prior networks (with transformer) + +prior_network = DiffusionPriorNetwork( + dim = 512, + depth = 6, + dim_head = 64, + heads = 8 +).cuda() + +diffusion_prior = DiffusionPrior( + net = prior_network, + clip = clip, + timesteps = 100, + cond_drop_prob = 0.2 +).cuda() + +loss = diffusion_prior(text, images) +loss.backward() + +# do above for many steps ... + +# decoder (with unet) + +unet1 = Unet( + dim = 128, + image_embed_dim = 512, + cond_dim = 128, + channels = 3, + dim_mults=(1, 2, 4, 8) +).cuda() + +unet2 = Unet( + dim = 16, + image_embed_dim = 512, + cond_dim = 128, + channels = 3, + dim_mults = (1, 2, 4, 8, 16) +).cuda() + +decoder = Decoder( + unet = (unet1, unet2), + image_sizes = (128, 256), + clip = clip, + timesteps = 100, + cond_drop_prob = 0.2, + condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling +).cuda() + +for unet_number in (1, 2): + loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much + loss.backward() + +# do above for many steps + +dalle2 = DALLE2( + prior = diffusion_prior, + decoder = decoder +) + +images = dalle2( + ['a butterfly trying to escape a tornado'], + cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition) +) + +# save your image (in this example, of size 256x256) +``` + +Now you'll just have to worry about training the Prior and the Decoder! + ## Experimental ### DALL-E2 with Latent Diffusion diff --git a/dalle2_pytorch/__init__.py b/dalle2_pytorch/__init__.py index 5c3290d..dc0cc71 100644 --- a/dalle2_pytorch/__init__.py +++ b/dalle2_pytorch/__init__.py @@ -1,4 +1,5 @@ from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder +from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter from dalle2_pytorch.vqgan_vae import VQGanVAE from x_clip import CLIP diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 9a1d6f9..e3d5801 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -3,6 +3,7 @@ from tqdm import tqdm from inspect import isfunction from functools import partial from contextlib import contextmanager +from collections import namedtuple import torch import torch.nn.functional as F @@ -90,8 +91,21 @@ def resize_image_to(t, image_size, mode = 'bilinear'): # take a look at https:// return F.interpolate(t, size = shape, mode = mode, align_corners = False) +# image normalization functions +# ddpms expect images to be in the range of -1 to 1 +# but CLIP may otherwise + +def normalize_img(img): + return img * 2 - 1 + +def unnormalize_img(normed_img): + return (normed_img + 1) * 0.5 + # clip related adapters +EmbeddedText = namedtuple('EmbedTextReturn', ['text_embed', 'text_encodings', 'text_mask']) +EmbeddedImage = namedtuple('EmbedImageReturn', ['image_embed', 'image_encodings']) + class BaseClipAdapter(nn.Module): def __init__(self, clip): super().__init__() @@ -109,6 +123,10 @@ class BaseClipAdapter(nn.Module): def image_channels(self): raise NotImplementedError + @property + def max_text_len(self): + raise NotImplementedError + def embed_text(self, text): raise NotImplementedError @@ -128,12 +146,18 @@ class XClipAdapter(BaseClipAdapter): def image_channels(self): return self.clip.image_channels + @property + def max_text_len(self): + return self.clip.text_seq_len + @torch.no_grad() def embed_text(self, text): + text = text[..., :self.max_text_len] + text_mask = text != 0 encoder_output = self.clip.text_transformer(text) text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:] text_embed = self.clip.to_text_latent(text_cls) - return l2norm(text_embed), text_encodings + return EmbeddedText(l2norm(text_embed), text_encodings, text_mask) @torch.no_grad() def embed_image(self, image): @@ -141,7 +165,73 @@ class XClipAdapter(BaseClipAdapter): encoder_output = self.clip.visual_transformer(image) image_cls, image_encodings = encoder_output[:, 0], encoder_output[:, 1:] image_embed = self.clip.to_visual_latent(image_cls) - return l2norm(image_embed), image_encodings + return EmbeddedImage(l2norm(image_embed), image_encodings) + +class OpenAIClipAdapter(BaseClipAdapter): + def __init__( + self, + name = 'ViT-B/32' + ): + try: + import clip + except ImportError: + print('you must install openai clip in order to use this adapter - `pip install git+https://github.com/openai/CLIP.git` - more instructions at https://github.com/openai/CLIP#usage') + + openai_clip, _ = clip.load(name) + super().__init__(openai_clip) + + text_attention_final = self.find_layer('ln_final') + self.handle = text_attention_final.register_forward_hook(self._hook) + self.clip_normalize = T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) + self.cleared = False + + def find_layer(self, layer): + modules = dict([*self.clip.named_modules()]) + return modules.get(layer, None) + + def clear(self): + if self.cleared: + return + + self.handle() + + def _hook(self, _, inputs, outputs): + self.text_encodings = outputs + + @property + def dim_latent(self): + return 512 + + @property + def image_size(self): + return self.clip.visual.input_resolution + + @property + def image_channels(self): + return 3 + + @property + def max_text_len(self): + return self.clip.context_length + + @torch.no_grad() + def embed_text(self, text): + text = text[..., :self.max_text_len] + text_mask = text != 0 + assert not self.cleared + + text_embed = self.clip.encode_text(text) + text_encodings = self.text_encodings + del self.text_encodings + return EmbeddedText(text_embed.float(), text_encodings.float(), text_mask) + + @torch.no_grad() + def embed_image(self, image): + assert not self.cleared + image = resize_image_to(image, self.image_size) + image = self.clip_normalize(unnormalize_img(image)) + image_embed = self.clip.encode_image(image) + return EmbeddedImage(image_embed.float(), None) # classifier free guidance functions @@ -741,12 +831,12 @@ class DiffusionPrior(BaseGaussianDiffusion): batch_size = text.shape[0] image_embed_dim = self.image_embed_dim - text_embed, text_encodings = self.clip.embed_text(text) + text_embed, text_encodings, text_mask = self.clip.embed_text(text) text_cond = dict(text_embed = text_embed) if self.condition_on_text_encodings: - text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text != 0} + text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask} image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond) text_embeds = text_cond['text_embed'] @@ -783,8 +873,7 @@ class DiffusionPrior(BaseGaussianDiffusion): # calculate text conditionings, based on what is passed in if exists(text): - text_embed, text_encodings = self.clip.embed_text(text) - text_mask = text != 0 + text_embed, text_encodings, text_mask = self.clip.embed_text(text) text_cond = dict(text_embed = text_embed) @@ -1341,11 +1430,8 @@ class Decoder(BaseGaussianDiffusion): @torch.no_grad() def get_image_embed(self, image): - image = resize_image_to(image, self.clip_image_size) - image_encoding = self.clip.visual_transformer(image) - image_cls = image_encoding[:, 0] - image_embed = self.clip.to_visual_latent(image_cls) - return l2norm(image_embed) + image_embed, _ = self.clip.embed_image(image) + return image_embed def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.): pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img) @@ -1417,7 +1503,7 @@ class Decoder(BaseGaussianDiffusion): text_encodings = None if exists(text): - _, text_encodings = self.clip.embed_text(text) + _, text_encodings, _ = self.clip.embed_text(text) assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified' assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented' @@ -1485,7 +1571,7 @@ class Decoder(BaseGaussianDiffusion): text_encodings = None if exists(text) and not exists(text_encodings): - _, text_encodings = self.clip.embed_text(text) + _, text_encodings, _ = self.clip.embed_text(text) assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified' assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented' diff --git a/dalle2_pytorch/optimizer.py b/dalle2_pytorch/optimizer.py new file mode 100644 index 0000000..5de2bfa --- /dev/null +++ b/dalle2_pytorch/optimizer.py @@ -0,0 +1,29 @@ +from torch.optim import AdamW, Adam + +def separate_weight_decayable_params(params): + no_wd_params = set([param for param in params if param.ndim < 2]) + wd_params = set(params) - no_wd_params + return wd_params, no_wd_params + +def get_optimizer( + params, + lr = 3e-4, + wd = 1e-2, + betas = (0.9, 0.999), + filter_by_requires_grad = False +): + if filter_by_requires_grad: + params = list(filter(lambda t: t.requires_grad, params)) + + if wd == 0: + return Adam(params, lr = lr, betas = betas) + + params = set(params) + wd_params, no_wd_params = separate_weight_decayable_params(params) + + param_groups = [ + {'params': list(wd_params)}, + {'params': list(no_wd_params), 'weight_decay': 0}, + ] + + return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas) diff --git a/dalle2_pytorch/train.py b/dalle2_pytorch/train.py index f6cebe3..c35bfe0 100644 --- a/dalle2_pytorch/train.py +++ b/dalle2_pytorch/train.py @@ -2,15 +2,6 @@ import copy import torch from torch import nn -# image related normalizations -# ddpms expect images to be in the range of -1 to 1 - -def normalize_img(img): - return img * 2 - 1 - -def unnormalize_img(normed_img): - return (normed_img + 1) * 0.5 - # exponential moving average wrapper class EMA(nn.Module): diff --git a/setup.py b/setup.py index 4293e0a..c0c34b7 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.0.65', + version = '0.0.71', license='MIT', description = 'DALL-E 2', author = 'Phil Wang', @@ -31,7 +31,7 @@ setup( 'torchvision', 'tqdm', 'vector-quantize-pytorch', - 'x-clip>=0.4.4', + 'x-clip>=0.5.1', 'youtokentome' ], classifiers=[