mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-20 18:24:19 +01:00
just take care of the logic for AdamW and transformers
This commit is contained in:
84
dalle2_pytorch/openai_clip.py
Normal file
84
dalle2_pytorch/openai_clip.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from dalle2_pytorch.dalle2_pytorch import BaseClipAdapter
|
||||||
|
import torchvision.transforms as T
|
||||||
|
|
||||||
|
def find_layer(model, layer):
|
||||||
|
modules = dict([*model.named_modules()])
|
||||||
|
return modules.get(layer, None)
|
||||||
|
|
||||||
|
def hook(_, input, output):
|
||||||
|
print(output.shape)
|
||||||
|
|
||||||
|
import clip
|
||||||
|
# image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
|
||||||
|
text = clip.tokenize(["a diagram", "a dog", "a cat"]).cuda()
|
||||||
|
image = torch.randn(1, 3, 224, 224).cuda()
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIClipAdapter(BaseClipAdapter):
|
||||||
|
def __init__(self, name = 'ViT-B/32'):
|
||||||
|
try:
|
||||||
|
import clip
|
||||||
|
except ImportError:
|
||||||
|
print('you must install openai clip in order to use this adapter - `pip install git+https://github.com/openai/CLIP.git` - more instructions at https://github.com/openai/CLIP#usage')
|
||||||
|
|
||||||
|
openai_clip, _ = clip.load(name)
|
||||||
|
super().__init__(openai_clip)
|
||||||
|
|
||||||
|
text_attention_final = self.find_layer(self.clip, 'ln_final')
|
||||||
|
self.handle = text_attention_final.register_forward_hook(self._hook)
|
||||||
|
self.clip_normalize = T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
|
||||||
|
self.cleared = False
|
||||||
|
|
||||||
|
def find_layer(self, layer):
|
||||||
|
modules = dict([*self.clip.named_modules()])
|
||||||
|
return modules.get(layer, None)
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
if self.cleared:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.handle()
|
||||||
|
|
||||||
|
def _hook(self, _, inputs, outputs):
|
||||||
|
self.text_encodings = outputs
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dim_latent(self):
|
||||||
|
return 512
|
||||||
|
|
||||||
|
@property
|
||||||
|
def image_size(self):
|
||||||
|
return self.clip.visual.input_resolution
|
||||||
|
|
||||||
|
@property
|
||||||
|
def image_channels(self):
|
||||||
|
return 3
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def embed_text(self, text):
|
||||||
|
assert not self.cleared
|
||||||
|
|
||||||
|
text_embed = self.clip.encode_text(text)
|
||||||
|
text_encodings = self.text_encodings
|
||||||
|
del self.text_encodings
|
||||||
|
return text_embed, text_encodings
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def embed_image(self, image):
|
||||||
|
assert not self.cleared
|
||||||
|
|
||||||
|
image = self.clip_normalize(image)
|
||||||
|
image_embed = self.clip.encode_image(image)
|
||||||
|
return image_embed, None
|
||||||
|
|
||||||
|
clip_adapter = OpenAIClipAdapter().cuda()
|
||||||
|
|
||||||
|
# print(model)
|
||||||
|
with torch.no_grad():
|
||||||
|
image_features, _ = clip_adapter.embed_image(image)
|
||||||
|
text_features, text_encodings = clip_adapter.embed_text(text)
|
||||||
|
print(text_features.shape, image_features.shape)
|
||||||
|
print(text_encodings.shape)
|
||||||
29
dalle2_pytorch/optimizer.py
Normal file
29
dalle2_pytorch/optimizer.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
from torch.optim import AdamW, Adam
|
||||||
|
|
||||||
|
def separate_weight_decayable_params(params):
|
||||||
|
no_wd_params = set([param for param in params if param.ndim < 2])
|
||||||
|
wd_params = set(params) - no_wd_params
|
||||||
|
return wd_params, no_wd_params
|
||||||
|
|
||||||
|
def get_optimizer(
|
||||||
|
params,
|
||||||
|
lr = 3e-4,
|
||||||
|
wd = 1e-2,
|
||||||
|
betas = (0.9, 0.999),
|
||||||
|
filter_by_requires_grad = False
|
||||||
|
):
|
||||||
|
if filter_by_requires_grad:
|
||||||
|
params = list(filter(lambda t: t.requires_grad, params))
|
||||||
|
|
||||||
|
if wd == 0:
|
||||||
|
return Adam(params, lr = lr, betas = betas)
|
||||||
|
|
||||||
|
params = set(params)
|
||||||
|
wd_params, no_wd_params = separate_weight_decayable_params(params)
|
||||||
|
|
||||||
|
param_groups = [
|
||||||
|
{'params': list(wd_params)},
|
||||||
|
{'params': list(no_wd_params), 'weight_decay': 0},
|
||||||
|
]
|
||||||
|
|
||||||
|
return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas)
|
||||||
Reference in New Issue
Block a user