From 8b5c2385b05e74e28d37655cf6e9be5f7c3bc7ef Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Thu, 14 Apr 2022 09:24:31 -0700 Subject: [PATCH] better naming --- dalle2_pytorch/dalle2_pytorch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index b9ef22b..e704e74 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -320,8 +320,8 @@ class DiffusionPriorNetwork(nn.Module): # but let's just do it right if exists(mask): - all_masked_out = mask.any(dim = -1) - mask = torch.cat((mask, rearrange(all_masked_out, 'b -> b 1')), dim = 1) + not_all_masked_out = mask.any(dim = -1) + mask = torch.cat((mask, rearrange(not_all_masked_out, 'b -> b 1')), dim = 1) if exists(mask): mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query