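"""Scratch script for the OpenAI CLIP adapter from
https://github.com/lucidrains/DALLE2-pytorch.

Wraps a pretrained CLIP model in the BaseClipAdapter interface, using a
forward hook on CLIP's final text layer norm (ln_final) to capture per-token
text encodings alongside the pooled text embedding.
"""
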
import torch
import torchvision.transforms as T
from PIL import Image  # only used by the commented-out file-loading example below

import clip  # OpenAI CLIP - pip install git+https://github.com/openai/CLIP.git

from dalle2_pytorch.dalle2_pytorch import BaseClipAdapter

# standalone debugging helpers (the adapter below carries its own versions)

def find_layer(model, layer):
    # look up a submodule by its dotted name, e.g. 'ln_final'
    modules = dict(model.named_modules())
    return modules.get(layer, None)

def hook(_, inputs, outputs):
    # forward hook that just reports the shape of whatever the layer emits
    print(outputs.shape)

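# A minimal sketch of how the helpers above fit together (assumes a loaded
# CLIP model bound to a hypothetical name `clip_model`; not executed here):
#
#   layer = find_layer(clip_model, 'ln_final')
#   handle = layer.register_forward_hook(hook)  # prints the output shape on each forward pass
#   handle.remove()                             # detach the hook when done
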
# image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)

text = clip.tokenize(["a diagram", "a dog", "a cat"]).cuda()  # (3, 77) token ids
image = torch.randn(1, 3, 224, 224).cuda()                    # dummy image batch

class OpenAIClipAdapter(BaseClipAdapter):
    def __init__(self, name = 'ViT-B/32'):
        try:
            import clip
        except ImportError:
            print('you must install openai clip in order to use this adapter - `pip install git+https://github.com/openai/CLIP.git` - more instructions at https://github.com/openai/CLIP#usage')
            raise

        openai_clip, _ = clip.load(name)
        super().__init__(openai_clip)

        # hook the final text layer norm to capture per-token text encodings
        text_attention_final = self.find_layer('ln_final')
        self.handle = text_attention_final.register_forward_hook(self._hook)

        # normalization constants from the official CLIP preprocessing pipeline
        self.clip_normalize = T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        self.cleared = False

    def find_layer(self, layer):
        modules = dict(self.clip.named_modules())
        return modules.get(layer, None)

    def clear(self):
        if self.cleared:
            return

        # detach the forward hook so text encodings are no longer cached
        self.handle.remove()
        self.cleared = True

    def _hook(self, _, inputs, outputs):
        # cache the full (batch, seq_len, dim) output of ln_final
        self.text_encodings = outputs

    @property
    def dim_latent(self):
        return 512  # embedding width of the default ViT-B/32 checkpoint

    @property
    def image_size(self):
        return self.clip.visual.input_resolution

    @property
    def image_channels(self):
        return 3

    @torch.no_grad()
    def embed_text(self, text):
        assert not self.cleared

        text_embed = self.clip.encode_text(text)  # pooled (batch, dim) embedding; fires the ln_final hook
        text_encodings = self.text_encodings      # per-token encodings captured by the hook
        del self.text_encodings
        return text_embed, text_encodings

    @torch.no_grad()
    def embed_image(self, image):
        assert not self.cleared

        image = self.clip_normalize(image)
        image_embed = self.clip.encode_image(image)
        return image_embed, None  # no per-patch encodings are captured for images

clip_adapter = OpenAIClipAdapter().cuda()

# print(clip_adapter)

# quick smoke test on the dummy inputs defined above
with torch.no_grad():
    image_features, _ = clip_adapter.embed_image(image)
    text_features, text_encodings = clip_adapter.embed_text(text)
    print(text_features.shape, image_features.shape)
    print(text_encodings.shape)

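# Expected shapes for the default ViT-B/32 checkpoint (context length 77,
# embedding width 512):
#   text_features:  torch.Size([3, 512])
#   image_features: torch.Size([1, 512])
#   text_encodings: torch.Size([3, 77, 512])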