mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-20 18:24:19 +01:00
better naming
This commit is contained in:
@@ -320,8 +320,8 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
# but let's just do it right
|
||||
|
||||
if exists(mask):
|
||||
all_masked_out = mask.any(dim = -1)
|
||||
mask = torch.cat((mask, rearrange(all_masked_out, 'b -> b 1')), dim = 1)
|
||||
not_all_masked_out = mask.any(dim = -1)
|
||||
mask = torch.cat((mask, rearrange(not_all_masked_out, 'b -> b 1')), dim = 1)
|
||||
|
||||
if exists(mask):
|
||||
mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
|
||||
|
||||
Reference in New Issue
Block a user