diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 0b781a0..c9cf125 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -765,7 +765,7 @@ class DiffusionPriorNetwork(nn.Module): # but let's just do it right if exists(mask): - mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query + mask = F.pad(mask, (0, 3), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query time_embed = self.time_embeddings(diffusion_timesteps) time_embed = rearrange(time_embed, 'b d -> b 1 d') @@ -776,6 +776,7 @@ class DiffusionPriorNetwork(nn.Module): text_encodings, text_embed, time_embed, + image_embed, learned_queries ), dim = -2) diff --git a/setup.py b/setup.py index 3613174..6d86262 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.1.5', + version = '0.1.6', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',