From 5d958713c0753922b64086b2310324c1f3be64e5 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Mon, 13 Jun 2022 21:01:50 -0700
Subject: [PATCH] fix classifier free guidance for image hiddens summed to time hiddens, thanks to @xvjiarui for finding this bug

---
 dalle2_pytorch/dalle2_pytorch.py | 29 ++++++++++++++++++++---------
 dalle2_pytorch/version.py        |  2 +-
 2 files changed, 21 insertions(+), 10 deletions(-)

diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index bdab5d3..bea84d7 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -1422,6 +1422,7 @@ class Unet(nn.Module):
         # for classifier free guidance
 
         self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
+        self.null_image_hiddens = nn.Parameter(torch.randn(1, time_cond_dim))
 
         self.max_text_len = max_text_len
         self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
@@ -1559,19 +1560,28 @@ class Unet(nn.Module):
         time_tokens = self.to_time_tokens(time_hiddens)
         t = self.to_time_cond(time_hiddens)
 
-        # image embedding to be summed to time embedding
-        # discovered by @mhh0318 in the paper
-
-        if exists(image_embed) and exists(self.to_image_hiddens):
-            image_hiddens = self.to_image_hiddens(image_embed)
-            t = t + image_hiddens
-
         # conditional dropout
 
         image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
         text_keep_mask = prob_mask_like((batch_size,), 1 - text_cond_drop_prob, device = device)
 
-        image_keep_mask, text_keep_mask = rearrange_many((image_keep_mask, text_keep_mask), 'b -> b 1 1')
+        text_keep_mask = rearrange(text_keep_mask, 'b -> b 1 1')
+
+        # image embedding to be summed to time embedding
+        # discovered by @mhh0318 in the paper
+
+        if exists(image_embed) and exists(self.to_image_hiddens):
+            image_hiddens = self.to_image_hiddens(image_embed)
+            image_keep_mask_hidden = rearrange(image_keep_mask, 'b -> b 1')
+            null_image_hiddens = self.null_image_hiddens.to(image_hiddens.dtype)
+
+            image_hiddens = torch.where(
+                image_keep_mask_hidden,
+                image_hiddens,
+                null_image_hiddens
+            )
+
+            t = t + image_hiddens
 
         # mask out image embedding depending on condition dropout
         # for classifier free guidance
@@ -1579,11 +1589,12 @@ class Unet(nn.Module):
         image_tokens = None
 
         if self.cond_on_image_embeds:
+            image_keep_mask_embed = rearrange(image_keep_mask, 'b -> b 1 1')
             image_tokens = self.image_to_tokens(image_embed)
             null_image_embed = self.null_image_embed.to(image_tokens.dtype) # for some reason pytorch AMP not working
 
             image_tokens = torch.where(
-                image_keep_mask_embed,
+                image_keep_mask_embed,
                 image_tokens,
                 null_image_embed
             )
diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py
index e6ce986..a71c5c7 100644
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.6.16'
+__version__ = '0.7.0'
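
---

Context on the fix: previously, `image_hiddens` was projected from the image embedding and summed into the time conditioning `t` unconditionally, so when the image conditioning was "dropped" for classifier-free guidance, the unconditional branch still leaked image information through this pathway. The patch substitutes a learned null vector (`null_image_hiddens`) for dropped samples before the sum, mirroring the existing `null_image_embed` treatment of the image tokens. Below is a minimal, self-contained sketch of that dropout pattern; the `prob_mask_like` helper is a simplified stand-in for the repository's, and the dimensions, names, and drop probability are illustrative only.

    import torch
    import torch.nn as nn

    def prob_mask_like(shape, prob, device):
        # per-sample boolean keep mask: True with probability `prob`
        # (simplified stand-in for the repository's helper of the same name)
        return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob

    batch_size, image_embed_dim, time_cond_dim = 4, 512, 256  # illustrative sizes
    device = torch.device('cpu')

    to_image_hiddens = nn.Linear(image_embed_dim, time_cond_dim)      # stands in for self.to_image_hiddens
    null_image_hiddens = nn.Parameter(torch.randn(1, time_cond_dim))  # the learned null added by this patch

    image_embed = torch.randn(batch_size, image_embed_dim)
    t = torch.randn(batch_size, time_cond_dim)                        # time conditioning

    # keep ~90% of samples conditioned; dropped samples form the unconditional branch
    image_keep_mask = prob_mask_like((batch_size,), 1 - 0.1, device = device)

    image_hiddens = to_image_hiddens(image_embed)

    # the fix: where the keep mask is False, substitute the learned null vector
    # instead of silently keeping the real image hiddens
    image_hiddens = torch.where(
        image_keep_mask.unsqueeze(-1),    # 'b -> b 1', as the rearrange in the patch
        image_hiddens,
        null_image_hiddens.to(image_hiddens.dtype)
    )

    t = t + image_hiddens                 # only now is the (possibly nulled) conditioning summed in

This is also why the patch splits the single `rearrange_many(..., 'b -> b 1 1')` into two reshapes: the hiddens pathway needs a rank-2 mask ('b -> b 1') while the image tokens need a rank-3 mask ('b -> b 1 1'), since the two conditioned tensors have different shapes.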