# Source: DALLE2-pytorch/dalle2_pytorch/openai_clip.py (Python, 85 lines, 2.4 KiB)

import torch
from PIL import Image
from dalle2_pytorch.dalle2_pytorch import BaseClipAdapter
import torchvision.transforms as T
def find_layer(model, layer):
    """Return the submodule of ``model`` registered under the name ``layer``.

    The lookup goes through ``model.named_modules()``; ``None`` is returned
    when no submodule carries that name.
    """
    by_name = {name: module for name, module in model.named_modules()}
    return by_name.get(layer)
def hook(_, input, output):
    """Debug callback with a forward-hook signature: print ``output.shape``.

    The first two arguments are accepted but not used.
    """
    shape = output.shape
    print(shape)
import clip
# Dummy inputs consumed by the smoke test at the end of this file.
# A random tensor stands in for a real, preprocessed picture:
# image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
image = torch.randn((1, 3, 224, 224)).cuda()
# Three tokenized prompts, moved to the GPU.
text = clip.tokenize(["a diagram", "a dog", "a cat"]).cuda()
class OpenAIClipAdapter(BaseClipAdapter):
    """Adapter that wraps an OpenAI CLIP model behind ``BaseClipAdapter``.

    A forward hook is registered on the model's ``ln_final`` module so that
    its output is captured every time text is encoded; ``embed_text``
    returns that captured output together with the text embedding.
    """

    def __init__(self, name = 'ViT-B/32'):
        """Load the CLIP checkpoint ``name`` and install the ``ln_final`` hook.

        Args:
            name: model identifier forwarded to ``clip.load``.

        Raises:
            ImportError: if the ``clip`` package is not installed.
            ValueError: if the loaded model has no module named ``ln_final``.
        """
        try:
            import clip
        except ImportError as err:
            # Fail here with the install hint instead of continuing and
            # crashing on the unbound ``clip`` name a line later.
            raise ImportError('you must install openai clip in order to use this adapter - `pip install git+https://github.com/openai/CLIP.git` - more instructions at https://github.com/openai/CLIP#usage') from err

        # The second value returned by clip.load is not needed here.
        openai_clip, _ = clip.load(name)
        # NOTE(review): the rest of this class reads the model back as
        # ``self.clip`` — presumably set by BaseClipAdapter.__init__; confirm.
        super().__init__(openai_clip)

        # find_layer is a method that already looks inside self.clip, so it
        # takes only the layer name.
        text_attention_final = self.find_layer('ln_final')
        if text_attention_final is None:
            raise ValueError(f"CLIP model {name!r} has no 'ln_final' module to hook")

        self.handle = text_attention_final.register_forward_hook(self._hook)
        # Per-channel mean / std applied to images before encoding.
        self.clip_normalize = T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        self.cleared = False

    def find_layer(self, layer):
        """Return the submodule of ``self.clip`` named ``layer``, or ``None``."""
        modules = dict(self.clip.named_modules())
        return modules.get(layer, None)

    def clear(self):
        """Remove the forward hook; calling this more than once is a no-op.

        After clearing, ``embed_text`` and ``embed_image`` refuse to run.
        """
        if self.cleared:
            return

        # register_forward_hook returns a handle that is detached with
        # .remove(); the handle itself is not callable.
        self.handle.remove()
        self.cleared = True

    def _hook(self, _, inputs, outputs):
        """Forward-hook callback: stash the hooked module's output."""
        self.text_encodings = outputs

    @property
    def dim_latent(self):
        """Size of the embedding vectors.

        NOTE(review): hard-coded to 512 regardless of ``name`` — confirm it
        is right for any checkpoint other than the default.
        """
        return 512

    @property
    def image_size(self):
        """Input resolution reported by the CLIP visual module."""
        return self.clip.visual.input_resolution

    @property
    def image_channels(self):
        """Number of image channels expected (always 3)."""
        return 3

    @torch.no_grad()
    def embed_text(self, text):
        """Encode ``text`` without tracking gradients.

        Args:
            text: input passed straight to ``self.clip.encode_text``.

        Returns:
            ``(text_embed, text_encodings)`` where ``text_embed`` is the
            result of ``encode_text`` and ``text_encodings`` is the
            ``ln_final`` output captured by the hook during that call.
        """
        assert not self.cleared, 'adapter has been cleared'

        text_embed = self.clip.encode_text(text)
        # Hand the captured tensor to the caller and drop our reference so
        # it is not kept alive on the module between calls.
        text_encodings = self.text_encodings
        del self.text_encodings
        return text_embed, text_encodings

    @torch.no_grad()
    def embed_image(self, image):
        """Normalize and encode ``image`` without tracking gradients.

        Args:
            image: input passed through ``self.clip_normalize`` and then
                ``self.clip.encode_image``.

        Returns:
            ``(image_embed, None)``; the second slot is always ``None``.
        """
        assert not self.cleared, 'adapter has been cleared'

        image = self.clip_normalize(image)
        image_embed = self.clip.encode_image(image)
        return image_embed, None
# Smoke test: build the adapter on the GPU, embed the dummy inputs defined
# above, and print the shapes of everything that comes back.
clip_adapter = OpenAIClipAdapter()
clip_adapter = clip_adapter.cuda()
# print(model)
with torch.no_grad():
    image_features, _ = clip_adapter.embed_image(image)
    text_features, text_encodings = clip_adapter.embed_text(text)

# First line: text / image embedding shapes; second line: text encodings.
print(f"{text_features.shape} {image_features.shape}\n{text_encodings.shape}")