mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
fix a potentially huge bug thanks to @CiaoHe https://github.com/lucidrains/DALLE2-pytorch/issues/71
This commit is contained in:
@@ -765,7 +765,7 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
# but let's just do it right
|
||||
|
||||
if exists(mask):
|
||||
mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
|
||||
mask = F.pad(mask, (0, 3), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
|
||||
|
||||
time_embed = self.time_embeddings(diffusion_timesteps)
|
||||
time_embed = rearrange(time_embed, 'b d -> b 1 d')
|
||||
@@ -776,6 +776,7 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
text_encodings,
|
||||
text_embed,
|
||||
time_embed,
|
||||
image_embed,
|
||||
learned_queries
|
||||
), dim = -2)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user