fix a potentially huge bug thanks to @CiaoHe https://github.com/lucidrains/DALLE2-pytorch/issues/71

This commit is contained in:
Phil Wang
2022-05-07 05:05:46 -07:00
parent fd53fa17db
commit 85ed77d512
2 changed files with 3 additions and 2 deletions

View File

@@ -765,7 +765,7 @@ class DiffusionPriorNetwork(nn.Module):
# but let's just do it right
if exists(mask):
mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
mask = F.pad(mask, (0, 3), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
time_embed = self.time_embeddings(diffusion_timesteps)
time_embed = rearrange(time_embed, 'b d -> b 1 d')
@@ -776,6 +776,7 @@ class DiffusionPriorNetwork(nn.Module):
text_encodings,
text_embed,
time_embed,
image_embed,
learned_queries
), dim = -2)

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.1.5',
version = '0.1.6',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',