From 846162ef3ec0c862b1d8b5cdd36e5f7bd9ca66b4 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Fri, 29 Apr 2022 11:43:26 -0700
Subject: [PATCH] just take care of the logic for AdamW and transformers

---
 dalle2_pytorch/openai_clip.py | 78 ++++++++++++++++++++++++++++++++++
 dalle2_pytorch/optimizer.py   | 29 ++++++++++++
 setup.py                      |  2 +-
 3 files changed, 108 insertions(+), 1 deletion(-)
 create mode 100644 dalle2_pytorch/openai_clip.py
 create mode 100644 dalle2_pytorch/optimizer.py

diff --git a/dalle2_pytorch/openai_clip.py b/dalle2_pytorch/openai_clip.py
new file mode 100644
index 0000000..67dc03c
--- /dev/null
+++ b/dalle2_pytorch/openai_clip.py
@@ -0,0 +1,78 @@
+import torch
+import torchvision.transforms as T
+
+from dalle2_pytorch.dalle2_pytorch import BaseClipAdapter
+
+class OpenAIClipAdapter(BaseClipAdapter):
+    def __init__(self, name = 'ViT-B/32'):
+        try:
+            import clip
+        except ImportError:
+            raise ImportError('you must install openai clip in order to use this adapter - `pip install git+https://github.com/openai/CLIP.git` - more instructions at https://github.com/openai/CLIP#usage')
+
+        openai_clip, _ = clip.load(name)
+        super().__init__(openai_clip)
+
+        # hook the final text layernorm so per-token encodings can be captured
+        text_attention_final = self.find_layer('ln_final')
+        self.handle = text_attention_final.register_forward_hook(self._hook)
+        self.clip_normalize = T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+        self.cleared = False
+
+    def find_layer(self, layer):
+        modules = dict([*self.clip.named_modules()])
+        return modules.get(layer, None)
+
+    def clear(self):
+        if self.cleared:
+            return
+
+        self.handle.remove()
+        self.cleared = True
+
+    def _hook(self, _, inputs, outputs):
+        self.text_encodings = outputs
+
+    @property
+    def dim_latent(self):
+        return 512
+
+    @property
+    def image_size(self):
+        return self.clip.visual.input_resolution
+
+    @property
+    def image_channels(self):
+        return 3
+
+    @torch.no_grad()
+    def embed_text(self, text):
+        assert not self.cleared
+
+        text_embed = self.clip.encode_text(text)
+        text_encodings = self.text_encodings
+        del self.text_encodings
+        return text_embed, text_encodings
+
+    @torch.no_grad()
+    def embed_image(self, image):
+        assert not self.cleared
+
+        image = self.clip_normalize(image)
+        image_embed = self.clip.encode_image(image)
+        return image_embed, None
+
+if __name__ == '__main__':
+    # quick smoke test - requires openai clip and a gpu
+    import clip
+
+    clip_adapter = OpenAIClipAdapter().cuda()
+
+    text = clip.tokenize(['a diagram', 'a dog', 'a cat']).cuda()
+    image = torch.randn(1, 3, 224, 224).cuda()
+
+    with torch.no_grad():
+        image_features, _ = clip_adapter.embed_image(image)
+        text_features, text_encodings = clip_adapter.embed_text(text)
+        print(text_features.shape, image_features.shape)
+        print(text_encodings.shape)
diff --git a/dalle2_pytorch/optimizer.py b/dalle2_pytorch/optimizer.py
new file mode 100644
index 0000000..5de2bfa
--- /dev/null
+++ b/dalle2_pytorch/optimizer.py
@@ -0,0 +1,29 @@
+from torch.optim import AdamW, Adam
+
+def separate_weight_decayable_params(params):
+    no_wd_params = set([param for param in params if param.ndim < 2])
+    wd_params = set(params) - no_wd_params
+    return wd_params, no_wd_params
+
+def get_optimizer(
+    params,
+    lr = 3e-4,
+    wd = 1e-2,
+    betas = (0.9, 0.999),
+    filter_by_requires_grad = False
+):
+    if filter_by_requires_grad:
+        params = list(filter(lambda t: t.requires_grad, params))
+
+    if wd == 0:
+        return Adam(params, lr = lr, betas = betas)
+
+    params = set(params)
+    wd_params, no_wd_params = separate_weight_decayable_params(params)
+
+    param_groups = [
+        {'params': list(wd_params)},
+        {'params': list(no_wd_params), 'weight_decay': 0},
+    ]
+
+    return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas)
diff --git a/setup.py b/setup.py
index c53b8a8..000ae1e 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.67',
+  version = '0.0.70',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
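
For reference, a minimal sketch of how the two new modules compose. This is not part of the patch: it assumes openai clip is installed, a CUDA device is available, and it stands in a throwaway torch.nn.Linear for whatever model is actually being optimized.

    import torch
    import clip

    from dalle2_pytorch.openai_clip import OpenAIClipAdapter
    from dalle2_pytorch.optimizer import get_optimizer

    # embed tokenized prompts and a batch of images with the adapter
    adapter = OpenAIClipAdapter().cuda()
    text = clip.tokenize(['a diagram', 'a dog', 'a cat']).cuda()
    images = torch.randn(3, 3, 224, 224).cuda()

    text_embed, text_encodings = adapter.embed_text(text)  # (3, 512) and (3, 77, 512) for ViT-B/32
    image_embed, _ = adapter.embed_image(images)            # (3, 512)

    # build an optimizer that applies weight decay only to parameters with
    # ndim >= 2; biases and norm gains fall into the zero-decay group
    model = torch.nn.Linear(512, 512).cuda()  # placeholder model, purely illustrative
    opt = get_optimizer(model.parameters(), lr = 3e-4, wd = 1e-2)

Note that get_optimizer falls back to plain Adam when wd == 0, since the two parameter groups only matter once weight decay is actually applied.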