mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
124d8577c8 | ||
|
|
2db0c9794c |
@@ -278,7 +278,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
|
||||
def embed_image(self, image):
|
||||
assert not self.cleared
|
||||
image = resize_image_to(image, self.image_size)
|
||||
image = self.clip_normalize(unnormalize_img(image))
|
||||
image = self.clip_normalize(image)
|
||||
image_embed = self.clip.encode_image(image)
|
||||
return EmbeddedImage(l2norm(image_embed.float()), None)
|
||||
|
||||
@@ -1037,7 +1037,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization'
|
||||
|
||||
if exists(image):
|
||||
image_embed, _ = self.clip.embed_image(image)
|
||||
image_embed, _ = self.clip.embed_image(unnormalize_img(image))
|
||||
|
||||
# calculate text conditionings, based on what is passed in
|
||||
|
||||
@@ -1890,6 +1890,11 @@ class Decoder(BaseGaussianDiffusion):
|
||||
# return simple loss if not using learned variance
|
||||
return loss
|
||||
|
||||
# most of the code below is transcribed from
|
||||
# https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/diffusion_utils_2.py
|
||||
# the Improved DDPM paper then further modified it so that the mean is detached (shown a couple lines before), and weighted to be smaller than the l1 or l2 "simple" loss
|
||||
# it is questionable whether this is really needed, looking at some of the figures in the paper, but may as well stay faithful to their implementation
|
||||
|
||||
# if learning the variance, also include the extra weight kl loss
|
||||
|
||||
true_mean, _, true_log_variance_clipped = self.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
|
||||
@@ -2006,7 +2011,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
if not exists(image_embed):
|
||||
assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
|
||||
image_embed, _ = self.clip.embed_image(image)
|
||||
image_embed, _ = self.clip.embed_image(unnormalize_img(image))
|
||||
|
||||
text_encodings = text_mask = None
|
||||
if exists(text) and not exists(text_encodings):
|
||||
|
||||
Reference in New Issue
Block a user