mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
take @crowsonkb 's suggestion at https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
This commit is contained in:
@@ -756,7 +756,8 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
|||||||
predict_x_start = True,
|
predict_x_start = True,
|
||||||
beta_schedule = "cosine",
|
beta_schedule = "cosine",
|
||||||
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
|
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
|
||||||
sampling_clamp_l2norm = False
|
sampling_clamp_l2norm = False,
|
||||||
|
image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
beta_schedule = beta_schedule,
|
beta_schedule = beta_schedule,
|
||||||
@@ -782,8 +783,11 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
|||||||
self.cond_drop_prob = cond_drop_prob
|
self.cond_drop_prob = cond_drop_prob
|
||||||
self.condition_on_text_encodings = condition_on_text_encodings
|
self.condition_on_text_encodings = condition_on_text_encodings
|
||||||
|
|
||||||
self.predict_x_start = predict_x_start
|
|
||||||
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
|
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
|
||||||
|
self.predict_x_start = predict_x_start
|
||||||
|
|
||||||
|
# @crowsonkb 's suggestion - https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
|
||||||
|
self.image_embed_scale = default(image_embed_scale, image_embed_dim ** 0.5)
|
||||||
|
|
||||||
# whether to force an l2norm, similar to clipping denoised, when sampling
|
# whether to force an l2norm, similar to clipping denoised, when sampling
|
||||||
self.sampling_clamp_l2norm = sampling_clamp_l2norm
|
self.sampling_clamp_l2norm = sampling_clamp_l2norm
|
||||||
@@ -802,7 +806,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
|||||||
x_recon.clamp_(-1., 1.)
|
x_recon.clamp_(-1., 1.)
|
||||||
|
|
||||||
if self.predict_x_start and self.sampling_clamp_l2norm:
|
if self.predict_x_start and self.sampling_clamp_l2norm:
|
||||||
x_recon = l2norm(x_recon)
|
x_recon = l2norm(x_recon) * self.image_embed_scale
|
||||||
|
|
||||||
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
||||||
return model_mean, posterior_variance, posterior_log_variance
|
return model_mean, posterior_variance, posterior_log_variance
|
||||||
@@ -862,6 +866,11 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
|||||||
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
|
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
|
||||||
|
|
||||||
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
|
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
|
||||||
|
|
||||||
|
# retrieve original unscaled image embed
|
||||||
|
|
||||||
|
image_embeds /= self.image_embed_scale
|
||||||
|
|
||||||
text_embeds = text_cond['text_embed']
|
text_embeds = text_cond['text_embed']
|
||||||
|
|
||||||
text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
|
text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
|
||||||
@@ -909,6 +918,10 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
|||||||
batch, device = image_embed.shape[0], image_embed.device
|
batch, device = image_embed.shape[0], image_embed.device
|
||||||
times = torch.randint(0, self.num_timesteps, (batch,), device = device, dtype = torch.long)
|
times = torch.randint(0, self.num_timesteps, (batch,), device = device, dtype = torch.long)
|
||||||
|
|
||||||
|
# scale image embed (Katherine)
|
||||||
|
|
||||||
|
image_embed *= self.image_embed_scale
|
||||||
|
|
||||||
# calculate forward loss
|
# calculate forward loss
|
||||||
|
|
||||||
return self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)
|
return self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)
|
||||||
|
|||||||
Reference in New Issue
Block a user