mirror of https://github.com/lucidrains/DALLE2-pytorch.git
fix classifier free guidance for image hiddens summed to time hiddens, thanks to @xvjiarui for finding this bug
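At sampling time, classifier-free guidance blends a conditional and an unconditional forward pass, so any conditioning signal that escapes the dropout mask leaks into the unconditional branch and weakens the guidance. A minimal sketch of the blend, not the repo's exact method (the call signature and kwargs here are assumptions modeled on the forward arguments visible in the diff below):

    import torch

    @torch.no_grad()
    def guided_pred(model, x, time, image_embed, cond_scale = 3.):
        # conditional pass: keep all conditioning
        cond = model(x, time, image_embed = image_embed,
                     image_cond_drop_prob = 0., text_cond_drop_prob = 0.)
        # unconditional pass: drop all conditioning; for the guidance
        # math to hold, dropped pathways must fall back to learned nulls
        uncond = model(x, time, image_embed = image_embed,
                       image_cond_drop_prob = 1., text_cond_drop_prob = 1.)
        return uncond + (cond - uncond) * cond_scale

Before this commit, the image hiddens summed into the time conditioning bypassed the dropout mask entirely, so the unconditional pass still carried the image signal.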
@@ -1422,6 +1422,7 @@ class Unet(nn.Module):
         # for classifier free guidance
 
         self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
+        self.null_image_hiddens = nn.Parameter(torch.randn(1, time_cond_dim))
 
         self.max_text_len = max_text_len
         self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
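The new null_image_hiddens follows the same learned-null pattern as the existing null_image_embed and null_text_embed: each conditioning pathway gets a trainable fallback of the matching shape, substituted wherever the keep mask is False. A self-contained sketch of the pattern (class name and dimension values are placeholders, not from the repo):

    import torch
    import torch.nn as nn

    class NullConds(nn.Module):
        def __init__(self, cond_dim = 512, time_cond_dim = 2048,
                     num_image_tokens = 4, max_text_len = 256):
            super().__init__()
            # token-shaped nulls broadcast against (b, n, d) token sequences
            self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
            self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
            # vector-shaped null for the hidden summed into the time embedding
            self.null_image_hiddens = nn.Parameter(torch.randn(1, time_cond_dim))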
@@ -1559,19 +1560,28 @@ class Unet(nn.Module):
         time_tokens = self.to_time_tokens(time_hiddens)
         t = self.to_time_cond(time_hiddens)
 
-        # image embedding to be summed to time embedding
-        # discovered by @mhh0318 in the paper
-
-        if exists(image_embed) and exists(self.to_image_hiddens):
-            image_hiddens = self.to_image_hiddens(image_embed)
-            t = t + image_hiddens
-
         # conditional dropout
 
         image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
         text_keep_mask = prob_mask_like((batch_size,), 1 - text_cond_drop_prob, device = device)
 
-        image_keep_mask, text_keep_mask = rearrange_many((image_keep_mask, text_keep_mask), 'b -> b 1 1')
+        text_keep_mask = rearrange(text_keep_mask, 'b -> b 1 1')
+
+        # image embedding to be summed to time embedding
+        # discovered by @mhh0318 in the paper
+
+        if exists(image_embed) and exists(self.to_image_hiddens):
+            image_hiddens = self.to_image_hiddens(image_embed)
+            image_keep_mask_hidden = rearrange(image_keep_mask, 'b -> b 1')
+            null_image_hiddens = self.null_image_hiddens.to(image_hiddens.dtype)
+
+            image_hiddens = torch.where(
+                image_keep_mask_hidden,
+                image_hiddens,
+                null_image_hiddens
+            )
+
+            t = t + image_hiddens
 
         # mask out image embedding depending on condition dropout
         # for classifier free guidance
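This hunk is the core of the fix: the sum into t now happens after the dropout mask exists, and dropped samples are routed to the learned null vector first. A runnable sketch of that masking step, with the repo's prob_mask_like helper inlined in simplified form (dimension values and the drop probability are illustrative):

    import torch
    from einops import rearrange

    def prob_mask_like(shape, prob, device):
        # True (keep) with probability `prob`, per batch element; simplified
        return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob

    batch_size, time_cond_dim, device = 4, 2048, 'cpu'
    t = torch.randn(batch_size, time_cond_dim)          # time conditioning
    image_hiddens = torch.randn(batch_size, time_cond_dim)
    null_image_hiddens = torch.randn(1, time_cond_dim)  # stands in for the learned parameter

    image_keep_mask = prob_mask_like((batch_size,), 1 - 0.1, device = device)
    image_keep_mask_hidden = rearrange(image_keep_mask, 'b -> b 1')

    # dropped rows receive the null vector, so the unconditional branch
    # of classifier free guidance carries no image signal
    image_hiddens = torch.where(image_keep_mask_hidden, image_hiddens, null_image_hiddens)
    t = t + image_hiddens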
@@ -1579,11 +1589,12 @@ class Unet(nn.Module):
         image_tokens = None
 
         if self.cond_on_image_embeds:
+            image_keep_mask_embed = rearrange(image_keep_mask, 'b -> b 1 1')
             image_tokens = self.image_to_tokens(image_embed)
             null_image_embed = self.null_image_embed.to(image_tokens.dtype) # for some reason pytorch AMP not working
 
             image_tokens = torch.where(
-                image_keep_mask,
+                image_keep_mask_embed,
                 image_tokens,
                 null_image_embed
             )
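The companion change: image_keep_mask is no longer reshaped once up front via rearrange_many; it stays shape (b,) and each use site rearranges it to match the tensor it masks. A short sketch of why the two shapes differ:

    import torch
    from einops import rearrange

    image_keep_mask = torch.tensor([True, False, True, True])  # shape (b,)

    # (b, 1, 1) broadcasts over (b, num_image_tokens, cond_dim) image tokens
    image_keep_mask_embed = rearrange(image_keep_mask, 'b -> b 1 1')

    # (b, 1) broadcasts over (b, time_cond_dim) image hiddens
    image_keep_mask_hidden = rearrange(image_keep_mask, 'b -> b 1')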
@@ -1 +1 @@
-__version__ = '0.6.16'
+__version__ = '0.7.0'