fix classifier free guidance for image hiddens summed to time hiddens, thanks to @xvjiarui for finding this bug

This commit is contained in:
Phil Wang
2022-06-13 21:01:50 -07:00
parent 0f31980362
commit 5d958713c0
2 changed files with 21 additions and 10 deletions

View File

@@ -1422,6 +1422,7 @@ class Unet(nn.Module):
# for classifier free guidance # for classifier free guidance
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim)) self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
self.null_image_hiddens = nn.Parameter(torch.randn(1, time_cond_dim))
self.max_text_len = max_text_len self.max_text_len = max_text_len
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim)) self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
@@ -1559,19 +1560,28 @@ class Unet(nn.Module):
time_tokens = self.to_time_tokens(time_hiddens) time_tokens = self.to_time_tokens(time_hiddens)
t = self.to_time_cond(time_hiddens) t = self.to_time_cond(time_hiddens)
# image embedding to be summed to time embedding
# discovered by @mhh0318 in the paper
if exists(image_embed) and exists(self.to_image_hiddens):
image_hiddens = self.to_image_hiddens(image_embed)
t = t + image_hiddens
# conditional dropout # conditional dropout
image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device) image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
text_keep_mask = prob_mask_like((batch_size,), 1 - text_cond_drop_prob, device = device) text_keep_mask = prob_mask_like((batch_size,), 1 - text_cond_drop_prob, device = device)
image_keep_mask, text_keep_mask = rearrange_many((image_keep_mask, text_keep_mask), 'b -> b 1 1') text_keep_mask = rearrange(text_keep_mask, 'b -> b 1 1')
# image embedding to be summed to time embedding
# discovered by @mhh0318 in the paper
if exists(image_embed) and exists(self.to_image_hiddens):
image_hiddens = self.to_image_hiddens(image_embed)
image_keep_mask_hidden = rearrange(image_keep_mask, 'b -> b 1')
null_image_hiddens = self.null_image_hiddens.to(image_hiddens.dtype)
image_hiddens = torch.where(
image_keep_mask_hidden,
image_hiddens,
null_image_hiddens
)
t = t + image_hiddens
# mask out image embedding depending on condition dropout # mask out image embedding depending on condition dropout
# for classifier free guidance # for classifier free guidance
@@ -1579,11 +1589,12 @@ class Unet(nn.Module):
image_tokens = None image_tokens = None
if self.cond_on_image_embeds: if self.cond_on_image_embeds:
image_keep_mask_embed = rearrange(image_keep_mask, 'b -> b 1 1')
image_tokens = self.image_to_tokens(image_embed) image_tokens = self.image_to_tokens(image_embed)
null_image_embed = self.null_image_embed.to(image_tokens.dtype) # for some reason pytorch AMP not working null_image_embed = self.null_image_embed.to(image_tokens.dtype) # for some reason pytorch AMP not working
image_tokens = torch.where( image_tokens = torch.where(
image_keep_mask, image_keep_mask_embed,
image_tokens, image_tokens,
null_image_embed null_image_embed
) )

View File

@@ -1 +1 @@
__version__ = '0.6.16' __version__ = '0.7.0'