From f2c52d82397e19ddb47e25d713b44dccf0af5a10 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Thu, 14 Apr 2022 09:21:51 -0700 Subject: [PATCH] fix bug with classifier free guidance for prior network, even though it seems it may not be used --- dalle2_pytorch/dalle2_pytorch.py | 9 ++++++++- setup.py | 2 +- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 4840548..b9ef22b 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -316,8 +316,15 @@ class DiffusionPriorNetwork(nn.Module): text_embed, image_embed = rearrange_many((text_embed, image_embed), 'b d -> b 1 d') + # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different) + # but let's just do it right + if exists(mask): - mask = F.pad(mask, (0, 3), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query + all_masked_out = mask.any(dim = -1) + mask = torch.cat((mask, rearrange(all_masked_out, 'b -> b 1')), dim = 1) + + if exists(mask): + mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query time_embed = self.time_embeddings(diffusion_timesteps) time_embed = rearrange(time_embed, 'b d -> b 1 d') diff --git a/setup.py b/setup.py index 5949f48..a13aeb1 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.0.8', + version = '0.0.9', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',