From 775abc4df655c2945987274a0ab5a19a9cc4d45a Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Tue, 12 Jul 2022 17:08:12 -0700
Subject: [PATCH] add setting to attend to all text encodings regardless of
 padding, for diffusion prior

---
 dalle2_pytorch/dalle2_pytorch.py | 9 +++++++--
 dalle2_pytorch/train_configs.py  | 1 +
 dalle2_pytorch/version.py        | 2 +-
 3 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 7dc015f..8c69b9b 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -806,6 +806,7 @@ class DiffusionPriorNetwork(nn.Module):
         num_time_embeds = 1,
         num_image_embeds = 1,
         num_text_embeds = 1,
+        attend_all_text_encodings = True,
         **kwargs
     ):
         super().__init__()
@@ -831,6 +832,8 @@ class DiffusionPriorNetwork(nn.Module):
         self.learned_query = nn.Parameter(torch.randn(dim))
         self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
 
+        self.attend_all_text_encodings = attend_all_text_encodings
+
     def forward_with_cond_scale(
         self,
         *args,
@@ -852,7 +855,6 @@ class DiffusionPriorNetwork(nn.Module):
         *,
         text_embed,
         text_encodings = None,
-        mask = None,
         cond_drop_prob = 0.
     ):
         batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
@@ -871,7 +873,10 @@ class DiffusionPriorNetwork(nn.Module):
         if not exists(text_encodings):
             text_encodings = torch.empty((batch, 0, dim), device = device, dtype = dtype)
 
-        mask = torch.any(text_encodings != 0., dim = -1)
+        if self.attend_all_text_encodings:
+            mask = torch.ones((batch, text_encodings.shape[-2]), device = device, dtype = torch.bool)
+        else:
+            mask = torch.any(text_encodings != 0., dim = -1)
 
         # classifier free guidance
 
diff --git a/dalle2_pytorch/train_configs.py b/dalle2_pytorch/train_configs.py
index 5f3685e..3f77fbb 100644
--- a/dalle2_pytorch/train_configs.py
+++ b/dalle2_pytorch/train_configs.py
@@ -133,6 +133,7 @@ class DiffusionPriorNetworkConfig(BaseModel):
     num_time_embeds: int = 1
     num_image_embeds: int = 1
     num_text_embeds: int = 1
+    attend_all_text_encodings: bool = True
     dim_head: int = 64
     heads: int = 8
     ff_mult: int = 4
diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py
index 66d9d1e..cc37364 100644
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.22.1'
+__version__ = '0.22.2'