This commit is contained in:
Phil Wang
2022-05-05 07:28:53 -07:00
parent d8d8b6caf1
commit 1d5dc08810
2 changed files with 17 additions and 4 deletions

View File

@@ -756,7 +756,8 @@ class DiffusionPrior(BaseGaussianDiffusion):
         predict_x_start = True,
         beta_schedule = "cosine",
         condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
-        sampling_clamp_l2norm = False
+        sampling_clamp_l2norm = False,
+        image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
     ):
         super().__init__(
             beta_schedule = beta_schedule,
@@ -782,8 +783,11 @@ class DiffusionPrior(BaseGaussianDiffusion):
         self.cond_drop_prob = cond_drop_prob
         self.condition_on_text_encodings = condition_on_text_encodings
-        self.predict_x_start = predict_x_start
         # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
+        self.predict_x_start = predict_x_start
+        # @crowsonkb 's suggestion - https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
+        self.image_embed_scale = default(image_embed_scale, image_embed_dim ** 0.5)
         # whether to force an l2norm, similar to clipping denoised, when sampling
         self.sampling_clamp_l2norm = sampling_clamp_l2norm
@@ -802,7 +806,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
             x_recon.clamp_(-1., 1.)
         if self.predict_x_start and self.sampling_clamp_l2norm:
-            x_recon = l2norm(x_recon)
+            x_recon = l2norm(x_recon) * self.image_embed_scale
         model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
         return model_mean, posterior_variance, posterior_log_variance
@@ -862,6 +866,11 @@ class DiffusionPrior(BaseGaussianDiffusion):
             text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
         image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
+        # retrieve original unscaled image embed
+        image_embeds /= self.image_embed_scale
         text_embeds = text_cond['text_embed']
         text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
@@ -909,6 +918,10 @@ class DiffusionPrior(BaseGaussianDiffusion):
         batch, device = image_embed.shape[0], image_embed.device
         times = torch.randint(0, self.num_timesteps, (batch,), device = device, dtype = torch.long)
+        # scale image embed (Katherine)
+        image_embed *= self.image_embed_scale
         # calculate forward loss
         return self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)

View File

@@ -10,7 +10,7 @@ setup(
      'dream = dalle2_pytorch.cli:dream'
    ],
  },
-  version = '0.0.102',
+  version = '0.0.104',
  license='MIT',
  description = 'DALL-E 2',
  author = 'Phil Wang',