fix classifier free guidance for diffusion prior, thanks to @jaykim9870 for spotting the issue

Phil Wang
2022-08-23 08:28:54 -07:00
parent 916ece164c
commit 59fa101c4d
2 changed files with 45 additions and 21 deletions


```diff
@@ -978,7 +978,10 @@ class DiffusionPriorNetwork(nn.Module):
         # dalle1 learned padding strategy

         self.max_text_len = max_text_len

-        self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim))
+        self.null_text_encodings = nn.Parameter(torch.randn(1, max_text_len, dim))
+        self.null_text_embeds = nn.Parameter(torch.randn(1, num_text_embeds, dim))
         self.null_image_embed = nn.Parameter(torch.randn(1, dim))

         # whether to use self conditioning, Hinton's group's new ddpm technique
```
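The old single `null_text_embed` (shaped for the padded text encodings) did double duty; the fix gives each text condition its own learned null: `null_text_encodings` for the per-token encodings and `null_text_embeds` for the pooled text embedding(s). A minimal shape sketch (dimension values are illustrative, not from the repo):

```python
import torch
import torch.nn as nn

batch, max_text_len, num_text_embeds, dim = 4, 256, 1, 512  # illustrative sizes

# one learned null vector per padded text position
null_text_encodings = nn.Parameter(torch.randn(1, max_text_len, dim))
# a learned null for each pooled text embedding token
null_text_embeds = nn.Parameter(torch.randn(1, num_text_embeds, dim))
# a learned null for the image embedding
null_image_embed = nn.Parameter(torch.randn(1, dim))

# the leading 1 lets each null broadcast across the batch inside torch.where
text_encodings = torch.randn(batch, max_text_len, dim)
assert torch.broadcast_shapes(null_text_encodings.shape, text_encodings.shape) == text_encodings.shape
```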
```diff
@@ -995,7 +998,7 @@ class DiffusionPriorNetwork(nn.Module):
         if cond_scale == 1:
             return logits

-        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
+        null_logits = self.forward(*args, text_cond_drop_prob = 1., image_cond_drop_prob = 1., **kwargs)
         return null_logits + (logits - null_logits) * cond_scale

     def forward(
```
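The guidance arithmetic itself is unchanged: the network runs once with conditioning and once with both conditions fully dropped, and the two predictions are extrapolated, which is also why `forward_with_cond_scale` can return early when `cond_scale == 1`. A minimal sketch of that combination (the function name here is mine, not the repo's):

```python
import torch

def guided_pred(cond_pred, null_pred, cond_scale):
    # classifier free guidance: start from the unconditional prediction and
    # extrapolate toward the conditional one; cond_scale = 1 is a no-op
    return null_pred + (cond_pred - null_pred) * cond_scale

cond_pred, null_pred = torch.randn(2, 512), torch.randn(2, 512)
assert torch.allclose(guided_pred(cond_pred, null_pred, 1.), cond_pred)
```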
```diff
@@ -1006,7 +1009,8 @@ class DiffusionPriorNetwork(nn.Module):
         text_embed,
         text_encodings = None,
         self_cond = None,
-        cond_drop_prob = 0.
+        text_cond_drop_prob = 0.,
+        image_cond_drop_prob = 0.
     ):
         batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
```
```diff
@@ -1024,6 +1028,14 @@ class DiffusionPriorNetwork(nn.Module):
         text_embed = self.to_text_embeds(text_embed)
         image_embed = self.to_image_embeds(image_embed)

+        # classifier free guidance masks
+
+        text_keep_mask = prob_mask_like((batch,), 1 - text_cond_drop_prob, device = device)
+        text_keep_mask = rearrange(text_keep_mask, 'b -> b 1 1')
+
+        image_keep_mask = prob_mask_like((batch,), 1 - image_cond_drop_prob, device = device)
+        image_keep_mask = rearrange(image_keep_mask, 'b -> b 1 1')
+
         # make text encodings optional
         # although the paper seems to suggest it is present <--
```
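`prob_mask_like` is the repo's Bernoulli keep-mask helper. A sketch of how such a helper behaves (treat the exact body as an assumption; the `prob == 0` / `prob == 1` short-circuits matter here because `forward_with_cond_scale` passes drop probabilities of exactly 1):

```python
import torch
from einops import rearrange

def prob_mask_like(shape, prob, device = None):
    # boolean keep-mask, True with probability `prob` per element;
    # the 0 and 1 cases skip sampling so they are exact, not approximate
    if prob == 1:
        return torch.ones(shape, device = device, dtype = torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device = device, dtype = torch.bool)
    return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob

# keep text conditioning for ~80% of the batch, then shape for broadcasting
text_keep_mask = prob_mask_like((4,), 1 - 0.2)            # (batch,)
text_keep_mask = rearrange(text_keep_mask, 'b -> b 1 1')  # (batch, 1, 1)
```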
```diff
@@ -1044,32 +1056,39 @@ class DiffusionPriorNetwork(nn.Module):
             text_encodings = F.pad(text_encodings, (0, 0, 0, remainder), value = 0.)
             mask = F.pad(mask, (0, remainder), value = False)

-        null_text_embeds = self.null_text_embed.to(text_encodings.dtype)
+        # mask out text encodings with null encodings
+
+        null_text_encodings = self.null_text_encodings.to(text_encodings.dtype)

         text_encodings = torch.where(
-            rearrange(mask, 'b n -> b n 1').clone(),
+            rearrange(mask, 'b n -> b n 1').clone() & text_keep_mask,
             text_encodings,
-            null_text_embeds
+            null_text_encodings
         )

-        # classifier free guidance
+        # mask out text embeddings with null text embeddings

-        keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = device)
-        keep_mask = rearrange(keep_mask, 'b -> b 1')
+        null_text_embeds = self.null_text_embeds.to(text_embed.dtype)

-        mask &= keep_mask
+        text_embed = torch.where(
+            text_keep_mask,
+            text_embed,
+            null_text_embeds
+        )

-        # whether text embedding is masked or not depends on the classifier free guidance conditional masking
-        keep_mask = repeat(keep_mask, 'b 1 -> b n', n = num_text_embeds)
-        mask = torch.cat((mask, keep_mask), dim = 1)
+        # mask out image embeddings with null image embeddings
+
+        null_image_embed = self.null_image_embed.to(image_embed.dtype)
+
+        image_embed = torch.where(
+            image_keep_mask,
+            image_embed,
+            null_image_embed
+        )

         # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
         # but let's just do it right

         attend_padding = 1 + num_time_embeds + num_image_embeds + int(self.self_cond) # 1 for learned queries + number of image embeds + time embeds
         mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query

         time_embed = self.to_time_embeds(diffusion_timesteps)

         learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
```
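Each condition is now dropped independently by swapping whole rows for their learned null via `torch.where` with a broadcast boolean mask; for the text encodings, the padding mask is additionally combined with the keep mask via `&`, so a dropped sample falls back to the null even at valid token positions. A self-contained sketch of the mechanism (toy sizes, hypothetical names):

```python
import torch
from einops import rearrange

batch, seq, dim = 4, 8, 16
text_encodings = torch.randn(batch, seq, dim)
null_text_encodings = torch.randn(1, seq, dim)   # stands in for the learned null

keep = torch.tensor([True, True, False, True])   # sample 2 gets its text dropped
keep = rearrange(keep, 'b -> b 1 1')             # broadcast over seq and dim

masked = torch.where(keep, text_encodings, null_text_encodings)
assert torch.equal(masked[2], null_text_encodings[0])  # dropped row is the null
assert torch.equal(masked[0], text_encodings[0])       # kept rows pass through
```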
```diff
@@ -1107,6 +1126,8 @@ class DiffusionPrior(nn.Module):
         timesteps = 1000,
         sample_timesteps = None,
         cond_drop_prob = 0.,
+        text_cond_drop_prob = None,
+        image_cond_drop_prob = None,
         loss_type = "l2",
         predict_x_start = True,
         beta_schedule = "cosine",
```
```diff
@@ -1147,8 +1168,10 @@ class DiffusionPrior(nn.Module):
         self.image_embed_dim = default(image_embed_dim, lambda: clip.dim_latent)
         self.channels = default(image_channels, lambda: clip.image_channels)

-        self.cond_drop_prob = cond_drop_prob
-        self.can_classifier_guidance = cond_drop_prob > 0.
+        self.text_cond_drop_prob = default(text_cond_drop_prob, cond_drop_prob)
+        self.image_cond_drop_prob = default(image_cond_drop_prob, cond_drop_prob)
+        self.can_classifier_guidance = self.text_cond_drop_prob > 0. and self.image_cond_drop_prob > 0.
         self.condition_on_text_encodings = condition_on_text_encodings

         # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
```
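The per-modality probabilities fall back to the legacy `cond_drop_prob` through the repo's `default` helper; a sketch of that fallback (the body shown is the usual pattern in this codebase, but treat it as an assumption):

```python
def exists(val):
    return val is not None

def default(val, d):
    # return `val` unless it is None, otherwise the fallback
    # (calling it first if it is a lazy callable)
    if exists(val):
        return val
    return d() if callable(d) else d

cond_drop_prob = 0.2
text_cond_drop_prob = default(None, cond_drop_prob)   # inherits 0.2
image_cond_drop_prob = default(0.5, cond_drop_prob)   # explicit override wins
```

Note that `can_classifier_guidance` now requires both probabilities to be positive: guidance at sampling time needs a network that was actually trained on the fully unconditional case for each modality.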
```diff
@@ -1308,7 +1331,8 @@ class DiffusionPrior(nn.Module):
             image_embed_noisy,
             times,
             self_cond = self_cond,
-            cond_drop_prob = self.cond_drop_prob,
+            text_cond_drop_prob = self.text_cond_drop_prob,
+            image_cond_drop_prob = self.image_cond_drop_prob,
             **text_cond
         )
```
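End to end, the new keyword arguments would be used roughly like this; a hypothetical sketch following the README's construction pattern, with arguments outside this diff abbreviated and possibly inexact:

```python
import torch
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter

clip = OpenAIClipAdapter()

prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
)

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 100,
    text_cond_drop_prob = 0.2,   # drop text conditioning for 20% of samples
    image_cond_drop_prob = 0.1   # drop image conditioning for 10% of samples
)

# mock data, as in the README
text = torch.randint(0, 49408, (4, 256))
images = torch.randn(4, 3, 256, 256)

loss = diffusion_prior(text, images)
loss.backward()
```

With both drop probabilities positive, `can_classifier_guidance` holds, and sampling with a `cond_scale` above 1 exercises the `forward_with_cond_scale` path patched above.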