From 5ffc3410610e744e205e2074c23f18b0ed32fb77 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Tue, 12 Jul 2022 17:32:09 -0700 Subject: [PATCH] add learned padding tokens, same strategy as dalle1, for diffusion prior, and get rid of masking in causal transformer --- dalle2_pytorch/dalle2_pytorch.py | 42 ++++++++++++++++++++++---------- dalle2_pytorch/train_configs.py | 2 +- dalle2_pytorch/version.py | 2 +- 3 files changed, 31 insertions(+), 15 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 8c69b9b..d97db32 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -782,17 +782,13 @@ class CausalTransformer(nn.Module): self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity() - def forward( - self, - x, - mask = None # we will need a mask here, due to variable length of the text encodings - also offer dalle1 strategy with padding token embeddings - ): + def forward(self, x): n, device = x.shape[1], x.device attn_bias = self.rel_pos_bias(n, n + 1, device = device) for attn, ff in self.layers: - x = attn(x, mask = mask, attn_bias = attn_bias) + x + x = attn(x, attn_bias = attn_bias) + x x = ff(x) + x out = self.norm(x) @@ -806,7 +802,7 @@ class DiffusionPriorNetwork(nn.Module): num_time_embeds = 1, num_image_embeds = 1, num_text_embeds = 1, - attend_all_text_encodings = True, + max_text_len = 256, **kwargs ): super().__init__() @@ -832,7 +828,10 @@ class DiffusionPriorNetwork(nn.Module): self.learned_query = nn.Parameter(torch.randn(dim)) self.causal_transformer = CausalTransformer(dim = dim, **kwargs) - self.attend_all_text_encodings = attend_all_text_encodings + # dalle1 learned padding strategy + + self.max_text_len = max_text_len + self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim)) def forward_with_cond_scale( self, @@ -872,11 +871,28 @@ class DiffusionPriorNetwork(nn.Module): if not exists(text_encodings): text_encodings = torch.empty((batch, 0, dim), device = device, dtype = dtype) + + mask = torch.any(text_encodings != 0., dim = -1) - if self.attend_all_text_encodings: - mask = torch.ones((batch, text_encodings.shape[-2]), device = device, dtype = torch.bool) - else: - mask = torch.any(text_encodings != 0., dim = -1) + # replace any padding in the text encodings with learned padding tokens unique across position + + text_encodings = text_encodings[:, :self.max_text_len] + mask = mask[:, :self.max_text_len] + + text_len = text_encodings.shape[-2] + remainder = self.max_text_len - text_len + + if remainder > 0: + text_encodings = F.pad(text_encodings, (0, 0, 0, remainder), value = 0.) + mask = F.pad(mask, (0, remainder), value = 0.) + + null_text_embeds = self.null_text_embed.to(text_encodings.dtype) + + text_encodings = torch.where( + rearrange(mask, 'b n -> b n 1'), + text_encodings, + null_text_embeds + ) # classifier free guidance @@ -910,7 +926,7 @@ class DiffusionPriorNetwork(nn.Module): # attend - tokens = self.causal_transformer(tokens, mask = mask) + tokens = self.causal_transformer(tokens) # get learned query, which should predict the image embedding (per DDPM timestep) diff --git a/dalle2_pytorch/train_configs.py b/dalle2_pytorch/train_configs.py index 3f77fbb..d2f502e 100644 --- a/dalle2_pytorch/train_configs.py +++ b/dalle2_pytorch/train_configs.py @@ -129,11 +129,11 @@ class AdapterConfig(BaseModel): class DiffusionPriorNetworkConfig(BaseModel): dim: int depth: int + max_text_len: int = None num_timesteps: int = None num_time_embeds: int = 1 num_image_embeds: int = 1 num_text_embeds: int = 1 - attend_all_text_encodings: bool = True dim_head: int = 64 heads: int = 8 ff_mult: int = 4 diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index eed48b7..08a9dbf 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.22.3' +__version__ = '0.23.0'