mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-20 18:24:19 +01:00
just take care of the logic for AdamW and transformers
This commit is contained in:
84
dalle2_pytorch/openai_clip.py
Normal file
84
dalle2_pytorch/openai_clip.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from dalle2_pytorch.dalle2_pytorch import BaseClipAdapter
|
||||
import torchvision.transforms as T
|
||||
|
||||
def find_layer(model, layer):
|
||||
modules = dict([*model.named_modules()])
|
||||
return modules.get(layer, None)
|
||||
|
||||
def hook(_, input, output):
|
||||
print(output.shape)
|
||||
|
||||
import clip
|
||||
# image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
|
||||
text = clip.tokenize(["a diagram", "a dog", "a cat"]).cuda()
|
||||
image = torch.randn(1, 3, 224, 224).cuda()
|
||||
|
||||
|
||||
class OpenAIClipAdapter(BaseClipAdapter):
|
||||
def __init__(self, name = 'ViT-B/32'):
|
||||
try:
|
||||
import clip
|
||||
except ImportError:
|
||||
print('you must install openai clip in order to use this adapter - `pip install git+https://github.com/openai/CLIP.git` - more instructions at https://github.com/openai/CLIP#usage')
|
||||
|
||||
openai_clip, _ = clip.load(name)
|
||||
super().__init__(openai_clip)
|
||||
|
||||
text_attention_final = self.find_layer(self.clip, 'ln_final')
|
||||
self.handle = text_attention_final.register_forward_hook(self._hook)
|
||||
self.clip_normalize = T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
|
||||
self.cleared = False
|
||||
|
||||
def find_layer(self, layer):
|
||||
modules = dict([*self.clip.named_modules()])
|
||||
return modules.get(layer, None)
|
||||
|
||||
def clear(self):
|
||||
if self.cleared:
|
||||
return
|
||||
|
||||
self.handle()
|
||||
|
||||
def _hook(self, _, inputs, outputs):
|
||||
self.text_encodings = outputs
|
||||
|
||||
@property
|
||||
def dim_latent(self):
|
||||
return 512
|
||||
|
||||
@property
|
||||
def image_size(self):
|
||||
return self.clip.visual.input_resolution
|
||||
|
||||
@property
|
||||
def image_channels(self):
|
||||
return 3
|
||||
|
||||
@torch.no_grad()
|
||||
def embed_text(self, text):
|
||||
assert not self.cleared
|
||||
|
||||
text_embed = self.clip.encode_text(text)
|
||||
text_encodings = self.text_encodings
|
||||
del self.text_encodings
|
||||
return text_embed, text_encodings
|
||||
|
||||
@torch.no_grad()
|
||||
def embed_image(self, image):
|
||||
assert not self.cleared
|
||||
|
||||
image = self.clip_normalize(image)
|
||||
image_embed = self.clip.encode_image(image)
|
||||
return image_embed, None
|
||||
|
||||
clip_adapter = OpenAIClipAdapter().cuda()
|
||||
|
||||
# print(model)
|
||||
with torch.no_grad():
|
||||
image_features, _ = clip_adapter.embed_image(image)
|
||||
text_features, text_encodings = clip_adapter.embed_text(text)
|
||||
print(text_features.shape, image_features.shape)
|
||||
print(text_encodings.shape)
|
||||
29
dalle2_pytorch/optimizer.py
Normal file
29
dalle2_pytorch/optimizer.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from torch.optim import AdamW, Adam
|
||||
|
||||
def separate_weight_decayable_params(params):
|
||||
no_wd_params = set([param for param in params if param.ndim < 2])
|
||||
wd_params = set(params) - no_wd_params
|
||||
return wd_params, no_wd_params
|
||||
|
||||
def get_optimizer(
|
||||
params,
|
||||
lr = 3e-4,
|
||||
wd = 1e-2,
|
||||
betas = (0.9, 0.999),
|
||||
filter_by_requires_grad = False
|
||||
):
|
||||
if filter_by_requires_grad:
|
||||
params = list(filter(lambda t: t.requires_grad, params))
|
||||
|
||||
if wd == 0:
|
||||
return Adam(params, lr = lr, betas = betas)
|
||||
|
||||
params = set(params)
|
||||
wd_params, no_wd_params = separate_weight_decayable_params(params)
|
||||
|
||||
param_groups = [
|
||||
{'params': list(wd_params)},
|
||||
{'params': list(no_wd_params), 'weight_decay': 0},
|
||||
]
|
||||
|
||||
return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas)
|
||||
Reference in New Issue
Block a user