mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
make sure openai clip adapter outputs l2normed embeddings
This commit is contained in:
@@ -264,7 +264,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
|
|||||||
text_embed = self.clip.encode_text(text)
|
text_embed = self.clip.encode_text(text)
|
||||||
text_encodings = self.text_encodings
|
text_encodings = self.text_encodings
|
||||||
del self.text_encodings
|
del self.text_encodings
|
||||||
return EmbeddedText(text_embed.float(), text_encodings.float(), text_mask)
|
return EmbeddedText(l2norm(text_embed.float()), text_encodings.float(), text_mask)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def embed_image(self, image):
|
def embed_image(self, image):
|
||||||
@@ -272,7 +272,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
|
|||||||
image = resize_image_to(image, self.image_size)
|
image = resize_image_to(image, self.image_size)
|
||||||
image = self.clip_normalize(unnormalize_img(image))
|
image = self.clip_normalize(unnormalize_img(image))
|
||||||
image_embed = self.clip.encode_image(image)
|
image_embed = self.clip.encode_image(image)
|
||||||
return EmbeddedImage(image_embed.float(), None)
|
return EmbeddedImage(l2norm(image_embed.float()), None)
|
||||||
|
|
||||||
# classifier free guidance functions
|
# classifier free guidance functions
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user