Compare commits

..

2 Commits

Author SHA1 Message Date
Phil Wang
846162ef3e just take care of the logic for AdamW and transformers 2022-04-29 11:43:26 -07:00
Phil Wang
39d3659ad9 now completely OpenAI CLIP compatible for training 2022-04-29 11:26:24 -07:00
4 changed files with 91 additions and 12 deletions

View File

@@ -430,8 +430,8 @@ images = torch.randn(4, 3, 256, 256).cuda()
# precompute the text and image embeddings
# here using the diffusion prior class, but could be done with CLIP alone
clip_image_embeds = diffusion_prior.clip.embed_image(images).image_embed
clip_text_embeds = diffusion_prior.clip.embed_text(text).text_embed
clip_image_embeds = diffusion_prior.get_image_embed(images)
clip_text_embeds = diffusion_prior.get_text_cond(text).get('text_embed')
# feed text and images into diffusion prior network
@@ -741,7 +741,6 @@ Once built, images will be saved to the same directory the command is invoked
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
- [ ] just take care of the training for the decoder in a wrapper class, as each unet in the cascade will need its own optimizer
- [ ] train on a toy task, offer in colab
- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference

View File

@@ -3,7 +3,6 @@ from tqdm import tqdm
from inspect import isfunction
from functools import partial
from contextlib import contextmanager
from collections import namedtuple
import torch
import torch.nn.functional as F
@@ -103,9 +102,6 @@ def unnormalize_img(normed_img):
# clip related adapters
EmbeddedText = namedtuple('EmbedTextReturn', ['text_embed', 'text_encodings', 'text_mask'])
EmbeddedImage = namedtuple('EmbedImageReturn', ['image_embed', 'image_encodings'])
class BaseClipAdapter(nn.Module):
def __init__(self, clip):
super().__init__()
@@ -157,7 +153,7 @@ class XClipAdapter(BaseClipAdapter):
encoder_output = self.clip.text_transformer(text)
text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:]
text_embed = self.clip.to_text_latent(text_cls)
return EmbeddedText(l2norm(text_embed), text_encodings, text_mask)
return l2norm(text_embed), text_encodings, text_mask
@torch.no_grad()
def embed_image(self, image):
@@ -165,7 +161,7 @@ class XClipAdapter(BaseClipAdapter):
encoder_output = self.clip.visual_transformer(image)
image_cls, image_encodings = encoder_output[:, 0], encoder_output[:, 1:]
image_embed = self.clip.to_visual_latent(image_cls)
return EmbeddedImage(l2norm(image_embed), image_encodings)
return l2norm(image_embed), image_encodings
class OpenAIClipAdapter(BaseClipAdapter):
def __init__(
@@ -223,7 +219,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
text_embed = self.clip.encode_text(text)
text_encodings = self.text_encodings
del self.text_encodings
return EmbeddedText(text_embed.float(), text_encodings.float(), text_mask)
return text_embed.float(), text_encodings.float(), text_mask
@torch.no_grad()
def embed_image(self, image):
@@ -231,7 +227,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
image = resize_image_to(image, self.image_size)
image = self.clip_normalize(unnormalize_img(image))
image_embed = self.clip.encode_image(image)
return EmbeddedImage(image_embed.float(), None)
return image_embed.float(), None
# classifier free guidance functions

View File

@@ -0,0 +1,84 @@
import torch
from PIL import Image
from dalle2_pytorch.dalle2_pytorch import BaseClipAdapter
import torchvision.transforms as T
def find_layer(model, layer):
modules = dict([*model.named_modules()])
return modules.get(layer, None)
def hook(_, input, output):
print(output.shape)
import clip
# image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).cuda()
image = torch.randn(1, 3, 224, 224).cuda()
class OpenAIClipAdapter(BaseClipAdapter):
def __init__(self, name = 'ViT-B/32'):
try:
import clip
except ImportError:
print('you must install openai clip in order to use this adapter - `pip install git+https://github.com/openai/CLIP.git` - more instructions at https://github.com/openai/CLIP#usage')
openai_clip, _ = clip.load(name)
super().__init__(openai_clip)
text_attention_final = self.find_layer(self.clip, 'ln_final')
self.handle = text_attention_final.register_forward_hook(self._hook)
self.clip_normalize = T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
self.cleared = False
def find_layer(self, layer):
modules = dict([*self.clip.named_modules()])
return modules.get(layer, None)
def clear(self):
if self.cleared:
return
self.handle()
def _hook(self, _, inputs, outputs):
self.text_encodings = outputs
@property
def dim_latent(self):
return 512
@property
def image_size(self):
return self.clip.visual.input_resolution
@property
def image_channels(self):
return 3
@torch.no_grad()
def embed_text(self, text):
assert not self.cleared
text_embed = self.clip.encode_text(text)
text_encodings = self.text_encodings
del self.text_encodings
return text_embed, text_encodings
@torch.no_grad()
def embed_image(self, image):
assert not self.cleared
image = self.clip_normalize(image)
image_embed = self.clip.encode_image(image)
return image_embed, None
clip_adapter = OpenAIClipAdapter().cuda()
# print(model)
with torch.no_grad():
image_features, _ = clip_adapter.embed_image(image)
text_features, text_encodings = clip_adapter.embed_text(text)
print(text_features.shape, image_features.shape)
print(text_encodings.shape)

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.71',
version = '0.0.70',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',