mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-19 15:14:27 +01:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3115fa17b3 |
@@ -114,10 +114,10 @@ def resize_image_to(image, target_image_size):
|
|||||||
# ddpms expect images to be in the range of -1 to 1
|
# ddpms expect images to be in the range of -1 to 1
|
||||||
# but CLIP may otherwise
|
# but CLIP may otherwise
|
||||||
|
|
||||||
def normalize_img(img):
|
def normalize_neg_one_to_one(img):
|
||||||
return img * 2 - 1
|
return img * 2 - 1
|
||||||
|
|
||||||
def unnormalize_img(normed_img):
|
def unnormalize_zero_to_one(normed_img):
|
||||||
return (normed_img + 1) * 0.5
|
return (normed_img + 1) * 0.5
|
||||||
|
|
||||||
# clip related adapters
|
# clip related adapters
|
||||||
@@ -1037,7 +1037,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
|||||||
assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization'
|
assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization'
|
||||||
|
|
||||||
if exists(image):
|
if exists(image):
|
||||||
image_embed, _ = self.clip.embed_image(unnormalize_img(image))
|
image_embed, _ = self.clip.embed_image(image)
|
||||||
|
|
||||||
# calculate text conditionings, based on what is passed in
|
# calculate text conditionings, based on what is passed in
|
||||||
|
|
||||||
@@ -1821,7 +1821,7 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
# eq 15 - https://arxiv.org/abs/2102.09672
|
# eq 15 - https://arxiv.org/abs/2102.09672
|
||||||
min_log = extract(self.posterior_log_variance_clipped, t, x.shape)
|
min_log = extract(self.posterior_log_variance_clipped, t, x.shape)
|
||||||
max_log = extract(torch.log(self.betas), t, x.shape)
|
max_log = extract(torch.log(self.betas), t, x.shape)
|
||||||
var_interp_frac = unnormalize_img(var_interp_frac_unnormalized)
|
var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)
|
||||||
|
|
||||||
posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
|
posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
|
||||||
posterior_variance = posterior_log_variance.exp()
|
posterior_variance = posterior_log_variance.exp()
|
||||||
@@ -1859,11 +1859,21 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
clip_denoised = clip_denoised
|
clip_denoised = clip_denoised
|
||||||
)
|
)
|
||||||
|
|
||||||
return img
|
unnormalize_img = unnormalize_zero_to_one(img)
|
||||||
|
return unnormalize_img
|
||||||
|
|
||||||
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False):
|
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False):
|
||||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||||
|
|
||||||
|
# normalize to [-1, 1]
|
||||||
|
|
||||||
|
x_start = normalize_neg_one_to_one(x_start)
|
||||||
|
|
||||||
|
if exists(lowres_cond_img):
|
||||||
|
lowres_cond_img = normalize_neg_one_to_one(lowres_cond_img)
|
||||||
|
|
||||||
|
# get x_t
|
||||||
|
|
||||||
x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)
|
x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)
|
||||||
|
|
||||||
model_output = unet(
|
model_output = unet(
|
||||||
@@ -2011,7 +2021,7 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
|
|
||||||
if not exists(image_embed):
|
if not exists(image_embed):
|
||||||
assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
|
assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
|
||||||
image_embed, _ = self.clip.embed_image(unnormalize_img(image))
|
image_embed, _ = self.clip.embed_image(image)
|
||||||
|
|
||||||
text_encodings = text_mask = None
|
text_encodings = text_mask = None
|
||||||
if exists(text) and not exists(text_encodings):
|
if exists(text) and not exists(text_encodings):
|
||||||
|
|||||||
Reference in New Issue
Block a user