mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-13 03:54:35 +01:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8518684ae9 | ||
|
|
1d5dc08810 |
@@ -652,14 +652,12 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
self,
|
||||
dim,
|
||||
num_timesteps = None,
|
||||
l2norm_output = False, # whether to restrict image embedding output with l2norm at the end (may make it easier to learn?)
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(Rearrange('b -> b 1'), MLP(1, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
|
||||
self.learned_query = nn.Parameter(torch.randn(dim))
|
||||
self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
|
||||
self.l2norm_output = l2norm_output
|
||||
|
||||
def forward_with_cond_scale(
|
||||
self,
|
||||
@@ -738,8 +736,7 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
|
||||
pred_image_embed = tokens[..., -1, :]
|
||||
|
||||
output_fn = l2norm if self.l2norm_output else identity
|
||||
return output_fn(pred_image_embed)
|
||||
return pred_image_embed
|
||||
|
||||
class DiffusionPrior(BaseGaussianDiffusion):
|
||||
def __init__(
|
||||
@@ -756,7 +753,8 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
predict_x_start = True,
|
||||
beta_schedule = "cosine",
|
||||
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
|
||||
sampling_clamp_l2norm = False
|
||||
sampling_clamp_l2norm = False,
|
||||
image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
|
||||
):
|
||||
super().__init__(
|
||||
beta_schedule = beta_schedule,
|
||||
@@ -782,8 +780,11 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
self.cond_drop_prob = cond_drop_prob
|
||||
self.condition_on_text_encodings = condition_on_text_encodings
|
||||
|
||||
self.predict_x_start = predict_x_start
|
||||
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
|
||||
self.predict_x_start = predict_x_start
|
||||
|
||||
# @crowsonkb 's suggestion - https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
|
||||
self.image_embed_scale = default(image_embed_scale, image_embed_dim ** 0.5)
|
||||
|
||||
# whether to force an l2norm, similar to clipping denoised, when sampling
|
||||
self.sampling_clamp_l2norm = sampling_clamp_l2norm
|
||||
@@ -802,7 +803,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
x_recon.clamp_(-1., 1.)
|
||||
|
||||
if self.predict_x_start and self.sampling_clamp_l2norm:
|
||||
x_recon = l2norm(x_recon)
|
||||
x_recon = l2norm(x_recon) * self.image_embed_scale
|
||||
|
||||
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
||||
return model_mean, posterior_variance, posterior_log_variance
|
||||
@@ -862,6 +863,11 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
|
||||
|
||||
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
|
||||
|
||||
# retrieve original unscaled image embed
|
||||
|
||||
image_embeds /= self.image_embed_scale
|
||||
|
||||
text_embeds = text_cond['text_embed']
|
||||
|
||||
text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
|
||||
@@ -909,6 +915,10 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
batch, device = image_embed.shape[0], image_embed.device
|
||||
times = torch.randint(0, self.num_timesteps, (batch,), device = device, dtype = torch.long)
|
||||
|
||||
# scale image embed (Katherine)
|
||||
|
||||
image_embed *= self.image_embed_scale
|
||||
|
||||
# calculate forward loss
|
||||
|
||||
return self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user