diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index b9ef22b..e704e74 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -320,8 +320,8 @@ class DiffusionPriorNetwork(nn.Module): # but let's just do it right if exists(mask): - all_masked_out = mask.any(dim = -1) - mask = torch.cat((mask, rearrange(all_masked_out, 'b -> b 1')), dim = 1) + not_all_masked_out = mask.any(dim = -1) + mask = torch.cat((mask, rearrange(not_all_masked_out, 'b -> b 1')), dim = 1) if exists(mask): mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query