From 28e944f3284e27edec2714f63efc1e78e6c6583c Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Fri, 6 May 2022 10:12:03 -0700 Subject: [PATCH] make sure openai clip adapter outputs l2normed embeddings --- dalle2_pytorch/dalle2_pytorch.py | 4 ++-- setup.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 7d78db6..0b781a0 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -264,7 +264,7 @@ class OpenAIClipAdapter(BaseClipAdapter): text_embed = self.clip.encode_text(text) text_encodings = self.text_encodings del self.text_encodings - return EmbeddedText(text_embed.float(), text_encodings.float(), text_mask) + return EmbeddedText(l2norm(text_embed.float()), text_encodings.float(), text_mask) @torch.no_grad() def embed_image(self, image): @@ -272,7 +272,7 @@ class OpenAIClipAdapter(BaseClipAdapter): image = resize_image_to(image, self.image_size) image = self.clip_normalize(unnormalize_img(image)) image_embed = self.clip.encode_image(image) - return EmbeddedImage(image_embed.float(), None) + return EmbeddedImage(l2norm(image_embed.float()), None) # classifier free guidance functions diff --git a/setup.py b/setup.py index 7528174..c981089 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.1.2', + version = '0.1.4', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',