better naming

This commit is contained in:
Phil Wang
2022-04-14 09:24:31 -07:00
parent f2c52d8239
commit 8b5c2385b0

View File

@@ -320,8 +320,8 @@ class DiffusionPriorNetwork(nn.Module):
# but let's just do it right
if exists(mask):
all_masked_out = mask.any(dim = -1)
mask = torch.cat((mask, rearrange(all_masked_out, 'b -> b 1')), dim = 1)
not_all_masked_out = mask.any(dim = -1)
mask = torch.cat((mask, rearrange(not_all_masked_out, 'b -> b 1')), dim = 1)
if exists(mask):
mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query