mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-23 07:44:40 +01:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0d1c07c803 | ||
|
|
a389f81138 | ||
|
|
0283556608 | ||
|
|
5063d192b6 |
@@ -430,8 +430,8 @@ images = torch.randn(4, 3, 256, 256).cuda()
|
|||||||
# precompute the text and image embeddings
|
# precompute the text and image embeddings
|
||||||
# here using the diffusion prior class, but could be done with CLIP alone
|
# here using the diffusion prior class, but could be done with CLIP alone
|
||||||
|
|
||||||
clip_image_embeds = diffusion_prior.get_image_embed(images)
|
clip_image_embeds = diffusion_prior.clip.embed_image(images).image_embed
|
||||||
clip_text_embeds = diffusion_prior.get_text_cond(text).get('text_embed')
|
clip_text_embeds = diffusion_prior.clip.embed_text(text).text_embed
|
||||||
|
|
||||||
# feed text and images into diffusion prior network
|
# feed text and images into diffusion prior network
|
||||||
|
|
||||||
@@ -741,6 +741,7 @@ Once built, images will be saved to the same directory the command is invoked
|
|||||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
|
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
|
||||||
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
|
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
|
||||||
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
|
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
|
||||||
|
- [ ] just take care of the training for the decoder in a wrapper class, as each unet in the cascade will need its own optimizer
|
||||||
- [ ] train on a toy task, offer in colab
|
- [ ] train on a toy task, offer in colab
|
||||||
- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
|
- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
|
||||||
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
|
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from tqdm import tqdm
|
|||||||
from inspect import isfunction
|
from inspect import isfunction
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
from collections import namedtuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -102,6 +103,9 @@ def unnormalize_img(normed_img):
|
|||||||
|
|
||||||
# clip related adapters
|
# clip related adapters
|
||||||
|
|
||||||
|
EmbeddedText = namedtuple('EmbedTextReturn', ['text_embed', 'text_encodings', 'text_mask'])
|
||||||
|
EmbeddedImage = namedtuple('EmbedImageReturn', ['image_embed', 'image_encodings'])
|
||||||
|
|
||||||
class BaseClipAdapter(nn.Module):
|
class BaseClipAdapter(nn.Module):
|
||||||
def __init__(self, clip):
|
def __init__(self, clip):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -153,7 +157,7 @@ class XClipAdapter(BaseClipAdapter):
|
|||||||
encoder_output = self.clip.text_transformer(text)
|
encoder_output = self.clip.text_transformer(text)
|
||||||
text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:]
|
text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:]
|
||||||
text_embed = self.clip.to_text_latent(text_cls)
|
text_embed = self.clip.to_text_latent(text_cls)
|
||||||
return l2norm(text_embed), text_encodings, text_mask
|
return EmbeddedText(l2norm(text_embed), text_encodings, text_mask)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def embed_image(self, image):
|
def embed_image(self, image):
|
||||||
@@ -161,7 +165,7 @@ class XClipAdapter(BaseClipAdapter):
|
|||||||
encoder_output = self.clip.visual_transformer(image)
|
encoder_output = self.clip.visual_transformer(image)
|
||||||
image_cls, image_encodings = encoder_output[:, 0], encoder_output[:, 1:]
|
image_cls, image_encodings = encoder_output[:, 0], encoder_output[:, 1:]
|
||||||
image_embed = self.clip.to_visual_latent(image_cls)
|
image_embed = self.clip.to_visual_latent(image_cls)
|
||||||
return l2norm(image_embed), image_encodings
|
return EmbeddedImage(l2norm(image_embed), image_encodings)
|
||||||
|
|
||||||
class OpenAIClipAdapter(BaseClipAdapter):
|
class OpenAIClipAdapter(BaseClipAdapter):
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -219,7 +223,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
|
|||||||
text_embed = self.clip.encode_text(text)
|
text_embed = self.clip.encode_text(text)
|
||||||
text_encodings = self.text_encodings
|
text_encodings = self.text_encodings
|
||||||
del self.text_encodings
|
del self.text_encodings
|
||||||
return text_embed.float(), text_encodings.float(), text_mask
|
return EmbeddedText(text_embed.float(), text_encodings.float(), text_mask)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def embed_image(self, image):
|
def embed_image(self, image):
|
||||||
@@ -227,7 +231,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
|
|||||||
image = resize_image_to(image, self.image_size)
|
image = resize_image_to(image, self.image_size)
|
||||||
image = self.clip_normalize(unnormalize_img(image))
|
image = self.clip_normalize(unnormalize_img(image))
|
||||||
image_embed = self.clip.encode_image(image)
|
image_embed = self.clip.encode_image(image)
|
||||||
return image_embed.float(), None
|
return EmbeddedImage(image_embed.float(), None)
|
||||||
|
|
||||||
# classifier free guidance functions
|
# classifier free guidance functions
|
||||||
|
|
||||||
@@ -684,14 +688,14 @@ class DiffusionPriorNetwork(nn.Module):
|
|||||||
|
|
||||||
# classifier free guidance
|
# classifier free guidance
|
||||||
|
|
||||||
cond_prob_mask = prob_mask_like((batch,), cond_drop_prob, device = device)
|
keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = device)
|
||||||
cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1')
|
keep_mask = rearrange(keep_mask, 'b -> b 1')
|
||||||
|
|
||||||
mask &= cond_prob_mask
|
mask &= keep_mask
|
||||||
|
|
||||||
# whether text embedding is masked or not depends on the classifier free guidance conditional masking
|
# whether text embedding is masked or not depends on the classifier free guidance conditional masking
|
||||||
|
|
||||||
mask = torch.cat((mask, cond_prob_mask), dim = 1)
|
mask = torch.cat((mask, keep_mask), dim = 1)
|
||||||
|
|
||||||
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
|
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
|
||||||
# but let's just do it right
|
# but let's just do it right
|
||||||
@@ -1204,8 +1208,8 @@ class Unet(nn.Module):
|
|||||||
|
|
||||||
# conditional dropout
|
# conditional dropout
|
||||||
|
|
||||||
cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device)
|
keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device = device)
|
||||||
cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1 1')
|
keep_mask = rearrange(keep_mask, 'b -> b 1 1')
|
||||||
|
|
||||||
# mask out image embedding depending on condition dropout
|
# mask out image embedding depending on condition dropout
|
||||||
# for classifier free guidance
|
# for classifier free guidance
|
||||||
@@ -1216,7 +1220,7 @@ class Unet(nn.Module):
|
|||||||
image_tokens = self.image_to_cond(image_embed)
|
image_tokens = self.image_to_cond(image_embed)
|
||||||
|
|
||||||
image_tokens = torch.where(
|
image_tokens = torch.where(
|
||||||
cond_prob_mask,
|
keep_mask,
|
||||||
image_tokens,
|
image_tokens,
|
||||||
self.null_image_embed
|
self.null_image_embed
|
||||||
)
|
)
|
||||||
@@ -1228,7 +1232,7 @@ class Unet(nn.Module):
|
|||||||
if exists(text_encodings) and self.cond_on_text_encodings:
|
if exists(text_encodings) and self.cond_on_text_encodings:
|
||||||
text_tokens = self.text_to_cond(text_encodings)
|
text_tokens = self.text_to_cond(text_encodings)
|
||||||
text_tokens = torch.where(
|
text_tokens = torch.where(
|
||||||
cond_prob_mask,
|
keep_mask,
|
||||||
text_tokens,
|
text_tokens,
|
||||||
self.null_text_embed[:, :text_tokens.shape[1]]
|
self.null_text_embed[:, :text_tokens.shape[1]]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,84 +0,0 @@
|
|||||||
import torch
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from dalle2_pytorch.dalle2_pytorch import BaseClipAdapter
|
|
||||||
import torchvision.transforms as T
|
|
||||||
|
|
||||||
def find_layer(model, layer):
    """Look up a submodule of ``model`` by name.

    Args:
        model: object whose ``named_modules()`` yields ``(name, module)`` pairs.
        layer: name of the module to look up.

    Returns:
        The module registered under ``layer``, or ``None`` if there is no such name.
    """
    # named_modules() already yields (name, module) pairs, so dict() can
    # consume it directly - no need to unpack it into a temporary list first
    modules = dict(model.named_modules())
    return modules.get(layer, None)
|
|
||||||
|
|
||||||
def hook(_, input, output):
    """Debug callback: print the ``shape`` of ``output``.

    The first two arguments are accepted but not used.
    """
    output_shape = output.shape
    print(output_shape)
|
|
||||||
|
|
||||||
import clip
|
|
||||||
# image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)

# three tokenized prompts, moved to the GPU
text = clip.tokenize(["a diagram", "a dog", "a cat"])
text = text.cuda()

# a single random 3 x 224 x 224 tensor stands in for a real image
image = torch.randn(1, 3, 224, 224)
image = image.cuda()
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAIClipAdapter(BaseClipAdapter):
    """Adapter that wraps an OpenAI CLIP model behind ``BaseClipAdapter``.

    A forward hook on the model's ``ln_final`` module stashes that module's
    output on the adapter every time it runs, which lets ``embed_text`` return
    it next to the result of ``encode_text``.
    """

    def __init__(self, name = 'ViT-B/32'):
        """Load the CLIP model called ``name`` and install the text hook.

        Raises:
            ImportError: if the ``clip`` package is not installed.
            RuntimeError: if the loaded model has no ``ln_final`` module.
        """
        try:
            import clip
        except ImportError as err:
            # nothing below can work without the package, so fail here with the
            # install hint instead of printing it and crashing on the next line
            raise ImportError('you must install openai clip in order to use this adapter - `pip install git+https://github.com/openai/CLIP.git` - more instructions at https://github.com/openai/CLIP#usage') from err

        openai_clip, _ = clip.load(name)
        super().__init__(openai_clip)

        # find_layer is a method that already searches self.clip, so it takes
        # only the layer name
        text_attention_final = self.find_layer('ln_final')

        if text_attention_final is None:
            raise RuntimeError(f'could not find the `ln_final` module on clip model {name!r}')

        self.handle = text_attention_final.register_forward_hook(self._hook)

        # NOTE(review): presumably the per-channel mean / std of CLIP's own preprocessing - confirm
        self.clip_normalize = T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        self.cleared = False

    def find_layer(self, layer):
        """Return the submodule of ``self.clip`` named ``layer``, or ``None`` if absent."""
        modules = dict(self.clip.named_modules())
        return modules.get(layer, None)

    def clear(self):
        """Detach the forward hook. Calling it a second time is a no-op."""
        if self.cleared:
            return

        # the object returned by register_forward_hook is detached with
        # .remove(); it is not itself callable
        self.handle.remove()

        # record the removal so the guard above and the asserts in
        # embed_text / embed_image actually take effect
        self.cleared = True

    def _hook(self, _, inputs, outputs):
        """Forward hook: keep the hooked module's output for ``embed_text``."""
        self.text_encodings = outputs

    @property
    def dim_latent(self):
        # hard-coded; NOTE(review): presumably only right for some CLIP variants - confirm
        return 512

    @property
    def image_size(self):
        return self.clip.visual.input_resolution

    @property
    def image_channels(self):
        return 3

    @torch.no_grad()
    def embed_text(self, text):
        """Encode ``text`` with CLIP.

        Returns:
            A ``(text_embed, text_encodings)`` pair: the result of
            ``encode_text`` and the ``ln_final`` output captured by the hook
            during that call.
        """
        assert not self.cleared

        text_embed = self.clip.encode_text(text)

        # take the captured output and drop the attribute so a stale value can
        # never be returned by a later call
        text_encodings = self.text_encodings
        del self.text_encodings

        return text_embed, text_encodings

    @torch.no_grad()
    def embed_image(self, image):
        """Normalize ``image`` and encode it with CLIP.

        Returns:
            A ``(image_embed, None)`` pair; the second slot is always ``None``.
        """
        assert not self.cleared

        image = self.clip_normalize(image)
        image_embed = self.clip.encode_image(image)
        return image_embed, None
|
|
||||||
|
|
||||||
# build the adapter and move it to the GPU
clip_adapter = OpenAIClipAdapter()
clip_adapter = clip_adapter.cuda()

# print(model)

# embed the sample image and prompts without tracking gradients
with torch.no_grad():
    image_features, _ = clip_adapter.embed_image(image)
    text_features, text_encodings = clip_adapter.embed_text(text)

# show the shapes of everything that came back
print(text_features.shape, image_features.shape)
print(text_encodings.shape)
|
|
||||||
Reference in New Issue
Block a user