mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
Fix classifier-free guidance for the diffusion prior; thanks to @jaykim9870 for spotting the issue
This commit is contained in:
@@ -978,7 +978,10 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
# dalle1 learned padding strategy
|
||||
|
||||
self.max_text_len = max_text_len
|
||||
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim))
|
||||
|
||||
self.null_text_encodings = nn.Parameter(torch.randn(1, max_text_len, dim))
|
||||
self.null_text_embeds = nn.Parameter(torch.randn(1, num_text_embeds, dim))
|
||||
self.null_image_embed = nn.Parameter(torch.randn(1, dim))
|
||||
|
||||
# whether to use self-conditioning, a new DDPM technique from Hinton's group
|
||||
|
||||
@@ -995,7 +998,7 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
if cond_scale == 1:
|
||||
return logits
|
||||
|
||||
null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
|
||||
null_logits = self.forward(*args, text_cond_drop_prob = 1., image_cond_drop_prob = 1, **kwargs)
|
||||
return null_logits + (logits - null_logits) * cond_scale
|
||||
|
||||
def forward(
|
||||
@@ -1006,7 +1009,8 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
text_embed,
|
||||
text_encodings = None,
|
||||
self_cond = None,
|
||||
cond_drop_prob = 0.
|
||||
text_cond_drop_prob = 0.,
|
||||
image_cond_drop_prob = 0.
|
||||
):
|
||||
batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
|
||||
|
||||
@@ -1024,6 +1028,14 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
text_embed = self.to_text_embeds(text_embed)
|
||||
image_embed = self.to_image_embeds(image_embed)
|
||||
|
||||
# classifier free guidance masks
|
||||
|
||||
text_keep_mask = prob_mask_like((batch,), 1 - text_cond_drop_prob, device = device)
|
||||
text_keep_mask = rearrange(text_keep_mask, 'b -> b 1 1')
|
||||
|
||||
image_keep_mask = prob_mask_like((batch,), 1 - image_cond_drop_prob, device = device)
|
||||
image_keep_mask = rearrange(image_keep_mask, 'b -> b 1 1')
|
||||
|
||||
# make text encodings optional
|
||||
# although the paper seems to suggest it is present <--
|
||||
|
||||
@@ -1044,32 +1056,39 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
text_encodings = F.pad(text_encodings, (0, 0, 0, remainder), value = 0.)
|
||||
mask = F.pad(mask, (0, remainder), value = False)
|
||||
|
||||
null_text_embeds = self.null_text_embed.to(text_encodings.dtype)
|
||||
# mask out text encodings with null encodings
|
||||
|
||||
null_text_encodings = self.null_text_encodings.to(text_encodings.dtype)
|
||||
|
||||
text_encodings = torch.where(
|
||||
rearrange(mask, 'b n -> b n 1').clone(),
|
||||
rearrange(mask, 'b n -> b n 1').clone() & text_keep_mask,
|
||||
text_encodings,
|
||||
null_text_encodings
|
||||
)
|
||||
|
||||
# mask out text embeddings with null text embeddings
|
||||
|
||||
null_text_embeds = self.null_text_embeds.to(text_embed.dtype)
|
||||
|
||||
text_embeds = torch.where(
|
||||
text_keep_mask,
|
||||
text_embed,
|
||||
null_text_embeds
|
||||
)
|
||||
|
||||
# classifier free guidance
|
||||
# mask out image embeddings with null image embeddings
|
||||
|
||||
keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = device)
|
||||
keep_mask = rearrange(keep_mask, 'b -> b 1')
|
||||
null_image_embed = self.null_image_embed.to(image_embed.dtype)
|
||||
|
||||
mask &= keep_mask
|
||||
|
||||
# whether text embedding is masked or not depends on the classifier free guidance conditional masking
|
||||
|
||||
keep_mask = repeat(keep_mask, 'b 1 -> b n', n = num_text_embeds)
|
||||
mask = torch.cat((mask, keep_mask), dim = 1)
|
||||
image_embed = torch.where(
|
||||
image_keep_mask,
|
||||
image_embed,
|
||||
null_image_embed
|
||||
)
|
||||
|
||||
# whether the text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier-free guidance, even though it seems from the paper that it was not used in the prior DDPM, as the objective is different)
|
||||
# but let's just do it right
|
||||
|
||||
attend_padding = 1 + num_time_embeds + num_image_embeds + int(self.self_cond) # 1 for learned queries + number of image embeds + time embeds
|
||||
mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
|
||||
|
||||
time_embed = self.to_time_embeds(diffusion_timesteps)
|
||||
|
||||
learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
|
||||
@@ -1107,6 +1126,8 @@ class DiffusionPrior(nn.Module):
|
||||
timesteps = 1000,
|
||||
sample_timesteps = None,
|
||||
cond_drop_prob = 0.,
|
||||
text_cond_drop_prob = None,
|
||||
image_cond_drop_prob = None,
|
||||
loss_type = "l2",
|
||||
predict_x_start = True,
|
||||
beta_schedule = "cosine",
|
||||
@@ -1147,8 +1168,10 @@ class DiffusionPrior(nn.Module):
|
||||
self.image_embed_dim = default(image_embed_dim, lambda: clip.dim_latent)
|
||||
self.channels = default(image_channels, lambda: clip.image_channels)
|
||||
|
||||
self.cond_drop_prob = cond_drop_prob
|
||||
self.can_classifier_guidance = cond_drop_prob > 0.
|
||||
self.text_cond_drop_prob = default(text_cond_drop_prob, cond_drop_prob)
|
||||
self.image_cond_drop_prob = default(image_cond_drop_prob, cond_drop_prob)
|
||||
|
||||
self.can_classifier_guidance = self.text_cond_drop_prob > 0. and self.image_cond_drop_prob > 0.
|
||||
self.condition_on_text_encodings = condition_on_text_encodings
|
||||
|
||||
# in the paper, they do not predict the noise, but predict x0 directly for the image embedding, claiming empirically better results. I'll just offer both.
|
||||
@@ -1308,7 +1331,8 @@ class DiffusionPrior(nn.Module):
|
||||
image_embed_noisy,
|
||||
times,
|
||||
self_cond = self_cond,
|
||||
cond_drop_prob = self.cond_drop_prob,
|
||||
text_cond_drop_prob = self.text_cond_drop_prob,
|
||||
image_cond_drop_prob = self.image_cond_drop_prob,
|
||||
**text_cond
|
||||
)
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = '1.8.4'
|
||||
__version__ = '1.9.0'
|
||||
|
||||
Reference in New Issue
Block a user