Mirror of https://github.com/lucidrains/DALLE2-pytorch.git
Take @crowsonkb's suggestion from https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132: scale the l2-normed image embedding so it is better suited to Gaussian diffusion.
@@ -756,7 +756,8 @@ class DiffusionPrior(BaseGaussianDiffusion):
         predict_x_start = True,
         beta_schedule = "cosine",
         condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
-        sampling_clamp_l2norm = False
+        sampling_clamp_l2norm = False,
+        image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
     ):
         super().__init__(
             beta_schedule = beta_schedule,
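For orientation, a usage sketch of the widened constructor, assuming the example configuration from the repository README (CLIP, DiffusionPriorNetwork); only image_embed_scale is new here, and leaving it at None defers to the sqrt(image_embed_dim) default set in the next hunk:

    import torch
    from dalle2_pytorch import CLIP, DiffusionPriorNetwork, DiffusionPrior

    # CLIP configured as in the repository README (assumed unchanged at this commit)
    clip = CLIP(
        dim_text = 512,
        dim_image = 512,
        dim_latent = 512,
        num_text_tokens = 49408,
        text_enc_depth = 6,
        text_seq_len = 256,
        text_heads = 8,
        visual_enc_depth = 6,
        visual_image_size = 256,
        visual_patch_size = 32,
        visual_heads = 8
    )

    prior_network = DiffusionPriorNetwork(
        dim = 512,
        depth = 6,
        dim_head = 64,
        heads = 8
    )

    diffusion_prior = DiffusionPrior(
        net = prior_network,
        clip = clip,
        timesteps = 100,
        cond_drop_prob = 0.2,
        image_embed_scale = None  # None -> image_embed_dim ** 0.5, per this commit
    )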
@@ -782,8 +783,11 @@ class DiffusionPrior(BaseGaussianDiffusion):
         self.cond_drop_prob = cond_drop_prob
         self.condition_on_text_encodings = condition_on_text_encodings
 
-        self.predict_x_start = predict_x_start
+        # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
+        self.predict_x_start = predict_x_start
+
+        # @crowsonkb 's suggestion - https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
+        self.image_embed_scale = default(image_embed_scale, image_embed_dim ** 0.5)
 
         # whether to force an l2norm, similar to clipping denoised, when sampling
         self.sampling_clamp_l2norm = sampling_clamp_l2norm
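The sqrt(image_embed_dim) default follows from how l2 normalization shrinks individual components: a unit-norm d-dimensional vector has per-component magnitude on the order of 1/sqrt(d), which is tiny next to the unit-variance noise that Gaussian diffusion injects. Multiplying by sqrt(d) restores roughly unit per-component scale. A quick check in plain PyTorch, nothing repo-specific:

    import torch
    import torch.nn.functional as F

    d = 512
    image_embed = F.normalize(torch.randn(4, d), dim = -1)  # l2-normed, like CLIP image embeddings

    print(image_embed.std())               # ~ 1 / sqrt(d), about 0.044 per component
    print((image_embed * d ** 0.5).std())  # ~ 1, on par with unit-variance diffusion noise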
@@ -802,7 +806,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
             x_recon.clamp_(-1., 1.)
 
         if self.predict_x_start and self.sampling_clamp_l2norm:
-            x_recon = l2norm(x_recon)
+            x_recon = l2norm(x_recon) * self.image_embed_scale
 
         model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
         return model_mean, posterior_variance, posterior_log_variance
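Without the trailing * self.image_embed_scale, the l2norm clamp during sampling would snap the predicted x0 back to unit norm while the model operates on scaled embeddings. A sketch of the corrected clamp step in isolation; l2norm is reproduced on the assumption that it matches the repository's F.normalize-based helper:

    import torch
    import torch.nn.functional as F

    def l2norm(t):
        # assumed equivalent to the dalle2_pytorch helper: unit l2 norm along the last dim
        return F.normalize(t, dim = -1)

    def clamp_l2norm(x_recon, image_embed_scale):
        # project the predicted x0 onto the scaled hypersphere the model was trained on
        return l2norm(x_recon) * image_embed_scale

    x_recon = clamp_l2norm(torch.randn(4, 512), 512 ** 0.5)
    print(x_recon.norm(dim = -1))  # every row has norm sqrt(512), about 22.6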
@@ -862,6 +866,11 @@ class DiffusionPrior(BaseGaussianDiffusion):
             text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
 
         image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
 
+        # retrieve original unscaled image embed
+
+        image_embeds /= self.image_embed_scale
+
         text_embeds = text_cond['text_embed']
         text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
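Dividing the sampled embeddings by image_embed_scale undoes the training-time scaling, so callers still receive image embeddings on CLIP's original unit-norm scale. The round trip in miniature:

    import torch
    import torch.nn.functional as F

    scale = 512 ** 0.5
    clip_image_embed = F.normalize(torch.randn(4, 512), dim = -1)

    # training multiplies by the scale; sampling divides it back out at the end
    recovered = (clip_image_embed * scale) / scale

    assert torch.allclose(recovered, clip_image_embed)
    print(recovered.norm(dim = -1))  # back to unit norm for downstream consumers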
@@ -909,6 +918,10 @@ class DiffusionPrior(BaseGaussianDiffusion):
         batch, device = image_embed.shape[0], image_embed.device
         times = torch.randint(0, self.num_timesteps, (batch,), device = device, dtype = torch.long)
 
+        # scale image embed (Katherine)
+
+        image_embed *= self.image_embed_scale
+
         # calculate forward loss
 
         return self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)
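On the training side, the image embedding is scaled up before the forward-diffusion loss, which puts the signal term of the noising equation on the same footing as the injected noise. A sketch of that balance; the alpha-cumprod value is illustrative, not taken from the repository's schedule:

    import torch
    import torch.nn.functional as F

    d, scale = 512, 512 ** 0.5
    x0 = F.normalize(torch.randn(4, d), dim = -1) * scale  # scaled image embed, per this commit
    noise = torch.randn_like(x0)

    alpha_cumprod = torch.tensor(0.5)  # an illustrative mid-trajectory value

    # standard forward diffusion: x_t = sqrt(a) * x0 + sqrt(1 - a) * noise
    x_t = alpha_cumprod.sqrt() * x0 + (1. - alpha_cumprod).sqrt() * noise

    # the two contributions are now comparable in magnitude; without the scaling,
    # the signal would be ~ sqrt(d) times smaller than the noise
    print((alpha_cumprod.sqrt() * x0).std(), ((1. - alpha_cumprod).sqrt() * noise).std())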