mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2025-12-19 09:44:19 +01:00)
now completely OpenAI CLIP compatible for training
just take care of the logic for AdamW and transformers; used namedtuples for clip adapter embedding outputs
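The central interface change is that every clip adapter now returns namedtuples instead of bare tuples. A minimal sketch of the new contract (not part of the diff; `clip_adapter`, `tokens` and `images` stand in for any adapter instance and its inputs):

    # sketch of the adapter return contract introduced by this commit
    text_out = clip_adapter.embed_text(tokens)      # EmbeddedText(text_embed, text_encodings, text_mask)
    image_out = clip_adapter.embed_image(images)    # EmbeddedImage(image_embed, image_encodings)

    text_embed, text_encodings, text_mask = text_out   # still unpacks positionally, like the old tuple
    image_embed = image_out.image_embed                # or access fields by name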
@@ -3,6 +3,7 @@ from tqdm import tqdm
 from inspect import isfunction
 from functools import partial
 from contextlib import contextmanager
+from collections import namedtuple

 import torch
 import torch.nn.functional as F
@@ -90,8 +91,21 @@ def resize_image_to(t, image_size, mode = 'bilinear'): # take a look at https://

     return F.interpolate(t, size = shape, mode = mode, align_corners = False)

+# image normalization functions
+# ddpms expect images to be in the range of -1 to 1
+# but CLIP may otherwise
+
+def normalize_img(img):
+    return img * 2 - 1
+
+def unnormalize_img(normed_img):
+    return (normed_img + 1) * 0.5
+
 # clip related adapters

+EmbeddedText = namedtuple('EmbedTextReturn', ['text_embed', 'text_encodings', 'text_mask'])
+EmbeddedImage = namedtuple('EmbedImageReturn', ['image_embed', 'image_encodings'])
+
 class BaseClipAdapter(nn.Module):
     def __init__(self, clip):
         super().__init__()
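A short sketch of how the two helpers above compose with CLIP's own preprocessing, mirroring what OpenAIClipAdapter.embed_image does further down (the input tensor and shapes are illustrative; normalize_img / unnormalize_img are assumed to be in scope from the hunk above):

    import torch
    import torchvision.transforms as T

    clip_normalize = T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))

    img = torch.rand(1, 3, 256, 256)                       # images as loaded, in [0, 1]
    ddpm_img = normalize_img(img)                          # [-1, 1], the range the ddpm trains in
    clip_img = clip_normalize(unnormalize_img(ddpm_img))   # back to [0, 1], then CLIP's channel statistics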
@@ -109,6 +123,10 @@ class BaseClipAdapter(nn.Module):
     def image_channels(self):
         raise NotImplementedError

+    @property
+    def max_text_len(self):
+        raise NotImplementedError
+
     def embed_text(self, text):
         raise NotImplementedError

@@ -128,12 +146,18 @@ class XClipAdapter(BaseClipAdapter):
     def image_channels(self):
         return self.clip.image_channels

+    @property
+    def max_text_len(self):
+        return self.clip.text_seq_len
+
     @torch.no_grad()
     def embed_text(self, text):
+        text = text[..., :self.max_text_len]
+        text_mask = text != 0
         encoder_output = self.clip.text_transformer(text)
         text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:]
         text_embed = self.clip.to_text_latent(text_cls)
-        return l2norm(text_embed), text_encodings
+        return EmbeddedText(l2norm(text_embed), text_encodings, text_mask)

     @torch.no_grad()
     def embed_image(self, image):
@@ -141,7 +165,73 @@ class XClipAdapter(BaseClipAdapter):
         encoder_output = self.clip.visual_transformer(image)
         image_cls, image_encodings = encoder_output[:, 0], encoder_output[:, 1:]
         image_embed = self.clip.to_visual_latent(image_cls)
-        return l2norm(image_embed), image_encodings
+        return EmbeddedImage(l2norm(image_embed), image_encodings)
+
+class OpenAIClipAdapter(BaseClipAdapter):
+    def __init__(
+        self,
+        name = 'ViT-B/32'
+    ):
+        try:
+            import clip
+        except ImportError:
+            print('you must install openai clip in order to use this adapter - `pip install git+https://github.com/openai/CLIP.git` - more instructions at https://github.com/openai/CLIP#usage')
+
+        openai_clip, _ = clip.load(name)
+        super().__init__(openai_clip)
+
+        text_attention_final = self.find_layer('ln_final')
+        self.handle = text_attention_final.register_forward_hook(self._hook)
+        self.clip_normalize = T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+        self.cleared = False
+
+    def find_layer(self, layer):
+        modules = dict([*self.clip.named_modules()])
+        return modules.get(layer, None)
+
+    def clear(self):
+        if self.cleared:
+            return
+
+        self.handle()
+
+    def _hook(self, _, inputs, outputs):
+        self.text_encodings = outputs
+
+    @property
+    def dim_latent(self):
+        return 512
+
+    @property
+    def image_size(self):
+        return self.clip.visual.input_resolution
+
+    @property
+    def image_channels(self):
+        return 3
+
+    @property
+    def max_text_len(self):
+        return self.clip.context_length
+
+    @torch.no_grad()
+    def embed_text(self, text):
+        text = text[..., :self.max_text_len]
+        text_mask = text != 0
+        assert not self.cleared
+
+        text_embed = self.clip.encode_text(text)
+        text_encodings = self.text_encodings
+        del self.text_encodings
+        return EmbeddedText(text_embed.float(), text_encodings.float(), text_mask)
+
+    @torch.no_grad()
+    def embed_image(self, image):
+        assert not self.cleared
+        image = resize_image_to(image, self.image_size)
+        image = self.clip_normalize(unnormalize_img(image))
+        image_embed = self.clip.encode_image(image)
+        return EmbeddedImage(image_embed.float(), None)
+
 # classifier free guidance functions

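For context, a minimal usage sketch of the new adapter (not part of the diff; the prompt is illustrative, the package-level import is an assumption, and clip.tokenize comes from the openai-clip package the except branch points to):

    import torch
    import clip  # openai clip, installed per the message above
    from dalle2_pytorch import OpenAIClipAdapter  # assumes the adapter is exported at package level

    adapter = OpenAIClipAdapter('ViT-B/32')

    # device placement omitted; move tokens / images to the clip model's device as needed
    tokens = clip.tokenize(['a corgi wearing a beret'])    # (1, 77) int64 token ids
    text_embed, text_encodings, text_mask = adapter.embed_text(tokens)

    images = torch.rand(1, 3, 256, 256) * 2 - 1            # ddpm-range images in [-1, 1]
    image_embed, _ = adapter.embed_image(images)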
@@ -741,12 +831,12 @@ class DiffusionPrior(BaseGaussianDiffusion):
         batch_size = text.shape[0]
         image_embed_dim = self.image_embed_dim

-        text_embed, text_encodings = self.clip.embed_text(text)
+        text_embed, text_encodings, text_mask = self.clip.embed_text(text)

         text_cond = dict(text_embed = text_embed)

         if self.condition_on_text_encodings:
-            text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text != 0}
+            text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}

         image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
         text_embeds = text_cond['text_embed']
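Worth noting (an observation, not part of the diff): the mask now comes from the adapter, so it is computed on tokens already truncated to max_text_len, whereas the old inline `text != 0` looked at the untruncated tokens. And since the return is a namedtuple, field access works alongside the positional unpack used above (adapter and tokens as in the earlier sketch):

    out = adapter.embed_text(tokens)                 # EmbeddedText namedtuple
    text_embed, text_encodings, text_mask = out      # positional unpack, as in the diff
    assert torch.equal(text_mask, out.text_mask)     # identical tensor either way
    assert text_mask.shape[-1] <= adapter.max_text_len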
@@ -783,8 +873,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
         # calculate text conditionings, based on what is passed in

         if exists(text):
-            text_embed, text_encodings = self.clip.embed_text(text)
-            text_mask = text != 0
+            text_embed, text_encodings, text_mask = self.clip.embed_text(text)

         text_cond = dict(text_embed = text_embed)

@@ -1341,11 +1430,8 @@ class Decoder(BaseGaussianDiffusion):

     @torch.no_grad()
     def get_image_embed(self, image):
-        image = resize_image_to(image, self.clip_image_size)
-        image_encoding = self.clip.visual_transformer(image)
-        image_cls = image_encoding[:, 0]
-        image_embed = self.clip.to_visual_latent(image_cls)
-        return l2norm(image_embed)
+        image_embed, _ = self.clip.embed_image(image)
+        return image_embed

     def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
         pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
@@ -1417,7 +1503,7 @@ class Decoder(BaseGaussianDiffusion):

         text_encodings = None
         if exists(text):
-            _, text_encodings = self.clip.embed_text(text)
+            _, text_encodings, _ = self.clip.embed_text(text)

         assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
         assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
@@ -1485,7 +1571,7 @@ class Decoder(BaseGaussianDiffusion):

         text_encodings = None
         if exists(text) and not exists(text_encodings):
-            _, text_encodings = self.clip.embed_text(text)
+            _, text_encodings, _ = self.clip.embed_text(text)

         assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
         assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'