mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2025-12-19 17:54:20 +01:00
fix a bug with classifier free guidance, thanks to @xiankgx again!
@@ -688,14 +688,14 @@ class DiffusionPriorNetwork(nn.Module):

         # classifier free guidance

-        cond_prob_mask = prob_mask_like((batch,), cond_drop_prob, device = device)
-        cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1')
+        keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = device)
+        keep_mask = rearrange(keep_mask, 'b -> b 1')

-        mask &= cond_prob_mask
+        mask &= keep_mask

         # whether text embedding is masked or not depends on the classifier free guidance conditional masking

-        mask = torch.cat((mask, cond_prob_mask), dim = 1)
+        mask = torch.cat((mask, keep_mask), dim = 1)

         # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
         # but let's just do it right
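The fix turns on the semantics of prob_mask_like: it returns a boolean mask whose entries are True with the given probability. The old code sampled the mask from cond_drop_prob and then used it as a keep mask, so conditioning was kept with probability cond_drop_prob and dropped the rest of the time, the inverse of the intent. Sampling from 1 - cond_drop_prob makes keep_mask behave as its name says. For context, a sketch of the helper, close to (but not guaranteed identical to) the repo's definition:

    import torch

    def prob_mask_like(shape, prob, device):
        # boolean mask, each entry True with probability `prob`
        if prob == 1:
            return torch.ones(shape, device = device, dtype = torch.bool)
        elif prob == 0:
            return torch.zeros(shape, device = device, dtype = torch.bool)
        return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob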
@@ -1208,8 +1208,8 @@ class Unet(nn.Module):

         # conditional dropout

-        cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device)
-        cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1 1')
+        keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device = device)
+        keep_mask = rearrange(keep_mask, 'b -> b 1 1')

         # mask out image embedding depending on condition dropout
         # for classifier free guidance
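In the Unet the mask is reshaped to (b, 1, 1) so it broadcasts over both the token and feature dimensions when torch.where selects between real conditioning tokens and the learned null embedding. A minimal, self-contained sketch of the pattern (the sizes and the zero-valued null embedding are illustrative, not the repo's values):

    import torch
    from einops import rearrange

    batch_size, num_tokens, dim = 4, 77, 512    # illustrative sizes
    cond_drop_prob = 0.2

    keep_mask = torch.zeros(batch_size).uniform_(0, 1) < (1 - cond_drop_prob)
    keep_mask = rearrange(keep_mask, 'b -> b 1 1')       # (b, 1, 1) broadcasts against (b, n, d)

    tokens = torch.randn(batch_size, num_tokens, dim)
    null_embed = torch.zeros(1, num_tokens, dim)         # stands in for the learned null parameter
    tokens = torch.where(keep_mask, tokens, null_embed)  # dropped rows are replaced wholesale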
@@ -1220,7 +1220,7 @@ class Unet(nn.Module):
         image_tokens = self.image_to_cond(image_embed)

         image_tokens = torch.where(
-            cond_prob_mask,
+            keep_mask,
             image_tokens,
             self.null_image_embed
         )
@@ -1232,7 +1232,7 @@ class Unet(nn.Module):
         if exists(text_encodings) and self.cond_on_text_encodings:
             text_tokens = self.text_to_cond(text_encodings)
             text_tokens = torch.where(
-                cond_prob_mask,
+                keep_mask,
                 text_tokens,
                 self.null_text_embed[:, :text_tokens.shape[1]]
             )
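The keep mask only matters at training time; at sampling time the same network is run twice and the two predictions are blended, which is the point of classifier free guidance. A hedged sketch of the standard blend (the forward signature and guidance_scale shown here are illustrative; the repo exposes its own cond-scale helper):

    import torch

    @torch.no_grad()
    def guided_pred(unet, x, t, image_embed, guidance_scale = 3.):
        # conditional pass: nothing dropped
        cond = unet(x, t, image_embed = image_embed, cond_drop_prob = 0.)
        # unconditional pass: conditioning forced to the null embeddings
        null = unet(x, t, image_embed = image_embed, cond_drop_prob = 1.)
        # standard classifier free guidance combination
        return null + guidance_scale * (cond - null)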