mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
move the inverse normalization function called before image embeddings are derived from clip to within the diffusion prior and decoder classes
This commit is contained in:
@@ -278,7 +278,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
|
|||||||
def embed_image(self, image):
|
def embed_image(self, image):
|
||||||
assert not self.cleared
|
assert not self.cleared
|
||||||
image = resize_image_to(image, self.image_size)
|
image = resize_image_to(image, self.image_size)
|
||||||
image = self.clip_normalize(unnormalize_img(image))
|
image = self.clip_normalize(image)
|
||||||
image_embed = self.clip.encode_image(image)
|
image_embed = self.clip.encode_image(image)
|
||||||
return EmbeddedImage(l2norm(image_embed.float()), None)
|
return EmbeddedImage(l2norm(image_embed.float()), None)
|
||||||
|
|
||||||
@@ -1037,7 +1037,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
|||||||
assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization'
|
assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization'
|
||||||
|
|
||||||
if exists(image):
|
if exists(image):
|
||||||
image_embed, _ = self.clip.embed_image(image)
|
image_embed, _ = self.clip.embed_image(unnormalize_img(image))
|
||||||
|
|
||||||
# calculate text conditionings, based on what is passed in
|
# calculate text conditionings, based on what is passed in
|
||||||
|
|
||||||
@@ -2011,7 +2011,7 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
|
|
||||||
if not exists(image_embed):
|
if not exists(image_embed):
|
||||||
assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
|
assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
|
||||||
image_embed, _ = self.clip.embed_image(image)
|
image_embed, _ = self.clip.embed_image(unnormalize_img(image))
|
||||||
|
|
||||||
text_encodings = text_mask = None
|
text_encodings = text_mask = None
|
||||||
if exists(text) and not exists(text_encodings):
|
if exists(text) and not exists(text_encodings):
|
||||||
|
|||||||
Reference in New Issue
Block a user