Compare commits

...

2 Commits

2 changed files with 9 additions and 4 deletions

View File

@@ -278,7 +278,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
def embed_image(self, image):
assert not self.cleared
image = resize_image_to(image, self.image_size)
image = self.clip_normalize(unnormalize_img(image))
image = self.clip_normalize(image)
image_embed = self.clip.encode_image(image)
return EmbeddedImage(l2norm(image_embed.float()), None)
@@ -1037,7 +1037,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization'
if exists(image):
image_embed, _ = self.clip.embed_image(image)
image_embed, _ = self.clip.embed_image(unnormalize_img(image))
# calculate text conditionings, based on what is passed in
@@ -1890,6 +1890,11 @@ class Decoder(BaseGaussianDiffusion):
# return simple loss if not using learned variance
return loss
# most of the code below is transcribed from
# https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/diffusion_utils_2.py
# the Improved DDPM paper then further modified it so that the mean is detached (shown a couple lines before), and weighted to be smaller than the l1 or l2 "simple" loss
# it is questionable whether this is really needed, looking at some of the figures in the paper, but may as well stay faithful to their implementation
# if learning the variance, also include the extra weight kl loss
true_mean, _, true_log_variance_clipped = self.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
@@ -2006,7 +2011,7 @@ class Decoder(BaseGaussianDiffusion):
if not exists(image_embed):
assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
image_embed, _ = self.clip.embed_image(image)
image_embed, _ = self.clip.embed_image(unnormalize_img(image))
text_encodings = text_mask = None
if exists(text) and not exists(text_encodings):

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.2.14',
version = '0.2.15',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',