mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-21 10:44:18 +01:00
now completely OpenAI CLIP compatible for training
This commit is contained in:
@@ -90,6 +90,16 @@ def resize_image_to(t, image_size, mode = 'bilinear'): # take a look at https://
|
||||
|
||||
return F.interpolate(t, size = shape, mode = mode, align_corners = False)
|
||||
|
||||
# image normalization functions
|
||||
# ddpms expect images to be in the range of -1 to 1
|
||||
# but CLIP may otherwise
|
||||
|
||||
def normalize_img(img):
|
||||
return img * 2 - 1
|
||||
|
||||
def unnormalize_img(normed_img):
|
||||
return (normed_img + 1) * 0.5
|
||||
|
||||
# clip related adapters
|
||||
|
||||
class BaseClipAdapter(nn.Module):
|
||||
@@ -109,6 +119,10 @@ class BaseClipAdapter(nn.Module):
|
||||
def image_channels(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def max_text_len(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def embed_text(self, text):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -128,12 +142,18 @@ class XClipAdapter(BaseClipAdapter):
|
||||
def image_channels(self):
|
||||
return self.clip.image_channels
|
||||
|
||||
@property
|
||||
def max_text_len(self):
|
||||
return self.clip.text_seq_len
|
||||
|
||||
@torch.no_grad()
|
||||
def embed_text(self, text):
|
||||
text = text[..., :self.max_text_len]
|
||||
text_mask = text != 0
|
||||
encoder_output = self.clip.text_transformer(text)
|
||||
text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:]
|
||||
text_embed = self.clip.to_text_latent(text_cls)
|
||||
return l2norm(text_embed), text_encodings
|
||||
return l2norm(text_embed), text_encodings, text_mask
|
||||
|
||||
@torch.no_grad()
|
||||
def embed_image(self, image):
|
||||
@@ -143,6 +163,72 @@ class XClipAdapter(BaseClipAdapter):
|
||||
image_embed = self.clip.to_visual_latent(image_cls)
|
||||
return l2norm(image_embed), image_encodings
|
||||
|
||||
class OpenAIClipAdapter(BaseClipAdapter):
|
||||
def __init__(
|
||||
self,
|
||||
name = 'ViT-B/32'
|
||||
):
|
||||
try:
|
||||
import clip
|
||||
except ImportError:
|
||||
print('you must install openai clip in order to use this adapter - `pip install git+https://github.com/openai/CLIP.git` - more instructions at https://github.com/openai/CLIP#usage')
|
||||
|
||||
openai_clip, _ = clip.load(name)
|
||||
super().__init__(openai_clip)
|
||||
|
||||
text_attention_final = self.find_layer('ln_final')
|
||||
self.handle = text_attention_final.register_forward_hook(self._hook)
|
||||
self.clip_normalize = T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
|
||||
self.cleared = False
|
||||
|
||||
def find_layer(self, layer):
|
||||
modules = dict([*self.clip.named_modules()])
|
||||
return modules.get(layer, None)
|
||||
|
||||
def clear(self):
|
||||
if self.cleared:
|
||||
return
|
||||
|
||||
self.handle()
|
||||
|
||||
def _hook(self, _, inputs, outputs):
|
||||
self.text_encodings = outputs
|
||||
|
||||
@property
|
||||
def dim_latent(self):
|
||||
return 512
|
||||
|
||||
@property
|
||||
def image_size(self):
|
||||
return self.clip.visual.input_resolution
|
||||
|
||||
@property
|
||||
def image_channels(self):
|
||||
return 3
|
||||
|
||||
@property
|
||||
def max_text_len(self):
|
||||
return self.clip.context_length
|
||||
|
||||
@torch.no_grad()
|
||||
def embed_text(self, text):
|
||||
text = text[..., :self.max_text_len]
|
||||
text_mask = text != 0
|
||||
assert not self.cleared
|
||||
|
||||
text_embed = self.clip.encode_text(text)
|
||||
text_encodings = self.text_encodings
|
||||
del self.text_encodings
|
||||
return text_embed.float(), text_encodings.float(), text_mask
|
||||
|
||||
@torch.no_grad()
|
||||
def embed_image(self, image):
|
||||
assert not self.cleared
|
||||
image = resize_image_to(image, self.image_size)
|
||||
image = self.clip_normalize(unnormalize_img(image))
|
||||
image_embed = self.clip.encode_image(image)
|
||||
return image_embed.float(), None
|
||||
|
||||
# classifier free guidance functions
|
||||
|
||||
def prob_mask_like(shape, prob, device):
|
||||
@@ -741,12 +827,12 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
batch_size = text.shape[0]
|
||||
image_embed_dim = self.image_embed_dim
|
||||
|
||||
text_embed, text_encodings = self.clip.embed_text(text)
|
||||
text_embed, text_encodings, text_mask = self.clip.embed_text(text)
|
||||
|
||||
text_cond = dict(text_embed = text_embed)
|
||||
|
||||
if self.condition_on_text_encodings:
|
||||
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text != 0}
|
||||
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
|
||||
|
||||
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
|
||||
text_embeds = text_cond['text_embed']
|
||||
@@ -783,8 +869,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
# calculate text conditionings, based on what is passed in
|
||||
|
||||
if exists(text):
|
||||
text_embed, text_encodings = self.clip.embed_text(text)
|
||||
text_mask = text != 0
|
||||
text_embed, text_encodings, text_mask = self.clip.embed_text(text)
|
||||
|
||||
text_cond = dict(text_embed = text_embed)
|
||||
|
||||
@@ -1341,11 +1426,8 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
@torch.no_grad()
|
||||
def get_image_embed(self, image):
|
||||
image = resize_image_to(image, self.clip_image_size)
|
||||
image_encoding = self.clip.visual_transformer(image)
|
||||
image_cls = image_encoding[:, 0]
|
||||
image_embed = self.clip.to_visual_latent(image_cls)
|
||||
return l2norm(image_embed)
|
||||
image_embed, _ = self.clip.embed_image(image)
|
||||
return image_embed
|
||||
|
||||
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
|
||||
pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
|
||||
@@ -1417,7 +1499,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
text_encodings = None
|
||||
if exists(text):
|
||||
_, text_encodings = self.clip.embed_text(text)
|
||||
_, text_encodings, _ = self.clip.embed_text(text)
|
||||
|
||||
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
|
||||
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
|
||||
@@ -1485,7 +1567,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
text_encodings = None
|
||||
if exists(text) and not exists(text_encodings):
|
||||
_, text_encodings = self.clip.embed_text(text)
|
||||
_, text_encodings, _ = self.clip.embed_text(text)
|
||||
|
||||
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
|
||||
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
|
||||
|
||||
Reference in New Issue
Block a user