From f22e8c8741c36b847828c9f1ef9b2c3744bcd6d8 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Sat, 30 Jul 2022 09:02:31 -0700
Subject: [PATCH] make open clip available for use with dalle2 pytorch

---
 README.md                        | 12 ++++++
 dalle2_pytorch/dalle2_pytorch.py | 69 ++++++++++++++++++++++++++++++++
 dalle2_pytorch/version.py        |  2 +-
 3 files changed, 82 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index d72779a..c88b700 100644
--- a/README.md
+++ b/README.md
@@ -627,6 +627,18 @@ images = dalle2(
 # save your image (in this example, of size 256x256)
 ```
 
+Alternatively, you can also use Open Clip
+
+```bash
+$ pip install open-clip-torch
+```
+
+```python
+from dalle2_pytorch import OpenClipAdapter
+
+clip = OpenClipAdapter()
+```
+
 Now you'll just have to worry about training the Prior and the Decoder!
 
 ## Inpainting
diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 005a920..427e3e1 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -339,6 +339,75 @@ class OpenAIClipAdapter(BaseClipAdapter):
         image_embed = self.clip.encode_image(image)
         return EmbeddedImage(l2norm(image_embed.float()), None)
 
+class OpenClipAdapter(BaseClipAdapter):
+    def __init__(
+        self,
+        name = 'ViT-B/32',
+        pretrained = 'laion400m_e32'
+    ):
+        import open_clip
+        clip, _, preprocess = open_clip.create_model_and_transforms(name, pretrained = pretrained)
+
+        super().__init__(clip)
+        self.eos_id = 49407
+
+        text_attention_final = self.find_layer('ln_final')
+        self.handle = text_attention_final.register_forward_hook(self._hook)
+        self.clip_normalize = preprocess.transforms[-1]
+        self.cleared = False
+
+    def find_layer(self, layer):
+        modules = dict([*self.clip.named_modules()])
+        return modules.get(layer, None)
+
+    def clear(self):
+        if self.cleared:
+            return
+
+        self.handle()
+
+    def _hook(self, _, inputs, outputs):
+        self.text_encodings = outputs
+
+    @property
+    def dim_latent(self):
+        return 512
+
+    @property
+    def image_size(self):
+        return self.clip.visual.image_size
+
+    @property
+    def image_channels(self):
+        return 3
+
+    @property
+    def max_text_len(self):
+        return self.clip.context_length
+
+    @torch.no_grad()
+    def embed_text(self, text):
+        text = text[..., :self.max_text_len]
+
+        is_eos_id = (text == self.eos_id)
+        text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
+        text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
+        assert not self.cleared
+
+        text_embed = self.clip.encode_text(text)
+        text_encodings = self.text_encodings
+        text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
+        del self.text_encodings
+        return EmbeddedText(l2norm(text_embed.float()), text_encodings.float())
+
+    @torch.no_grad()
+    def embed_image(self, image):
+        assert not self.cleared
+        image = self.validate_and_resize_image(image)
+        image = self.clip_normalize(image)
+        image_embed = self.clip.encode_image(image)
+        return EmbeddedImage(l2norm(image_embed.float()), None)
+
 # classifier free guidance functions
 
 def prob_mask_like(shape, prob, device):
diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py
index 9e0feee..5e235ea 100644
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.4.4'
+__version__ = '1.4.5'
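
Usage note (not part of the patch): the new `OpenClipAdapter` is intended as a drop-in replacement for `OpenAIClipAdapter`, so it should slot into the prior-training example already in the README. The sketch below is a minimal, hedged illustration: it assumes the `DiffusionPriorNetwork` / `DiffusionPrior` arguments shown elsewhere in the README, and OpenCLIP ViT-B/32 defaults (224x224 images, 77-token context, ~49408-token vocabulary). Treat the dimensions and mock data as illustrative.

```python
# a minimal sketch, assuming the DiffusionPriorNetwork / DiffusionPrior setup
# from the existing README examples (not introduced by this patch)
import torch
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, OpenClipAdapter

clip = OpenClipAdapter()  # defaults to ViT-B/32 with LAION-400M weights

# mock data - OpenCLIP ViT-B/32 expects 224x224 images and a 77-token context
text = torch.randint(0, 49408, (4, 77))
images = torch.randn(4, 3, 224, 224)

prior_network = DiffusionPriorNetwork(
    dim = 512,        # matches the adapter's dim_latent
    depth = 6,
    dim_head = 64,
    heads = 8
)

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 100,
    cond_drop_prob = 0.2
)

loss = diffusion_prior(text, images)
loss.backward()  # then step your optimizer as usual
```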