fix a bug with classifier free guidance, thanks to @xiankgx again!

This commit is contained in:
Phil Wang
2022-04-30 06:34:18 -07:00
parent a389f81138
commit 0d1c07c803
2 changed files with 9 additions and 9 deletions

View File

@@ -688,14 +688,14 @@ class DiffusionPriorNetwork(nn.Module):
# classifier free guidance
cond_prob_mask = prob_mask_like((batch,), cond_drop_prob, device = device)
cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1')
keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = device)
keep_mask = rearrange(keep_mask, 'b -> b 1')
mask &= cond_prob_mask
mask &= keep_mask
# whether text embedding is masked or not depends on the classifier free guidance conditional masking
mask = torch.cat((mask, cond_prob_mask), dim = 1)
mask = torch.cat((mask, keep_mask), dim = 1)
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
# but let's just do it right
@@ -1208,8 +1208,8 @@ class Unet(nn.Module):
# conditional dropout
cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device)
cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1 1')
keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device = device)
keep_mask = rearrange(keep_mask, 'b -> b 1 1')
# mask out image embedding depending on condition dropout
# for classifier free guidance
@@ -1220,7 +1220,7 @@ class Unet(nn.Module):
image_tokens = self.image_to_cond(image_embed)
image_tokens = torch.where(
cond_prob_mask,
keep_mask,
image_tokens,
self.null_image_embed
)
@@ -1232,7 +1232,7 @@ class Unet(nn.Module):
if exists(text_encodings) and self.cond_on_text_encodings:
text_tokens = self.text_to_cond(text_encodings)
text_tokens = torch.where(
cond_prob_mask,
keep_mask,
text_tokens,
self.null_text_embed[:, :text_tokens.shape[1]]
)

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.71',
version = '0.0.72',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',