add learned padding tokens, same strategy as dalle1, for diffusion prior, and get rid of masking in causal transformer

Phil Wang
2022-07-12 17:33:14 -07:00
parent cd26c6b17d
commit 3ee3c56d2a
3 changed files with 31 additions and 15 deletions
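For readers skimming the diff: below is a minimal, self-contained sketch of the dalle1-style learned padding strategy this commit adopts. The module name `LearnedTextPadding` is hypothetical (it does not exist in the repo); the trim/pad/substitute logic mirrors the lines added to `DiffusionPriorNetwork.forward` further down.

```python
import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange

class LearnedTextPadding(nn.Module):  # hypothetical name, for illustration only
    def __init__(self, dim, max_text_len = 256):
        super().__init__()
        self.max_text_len = max_text_len
        # one learned "null" embedding per position, as in the commit
        self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim))

    def forward(self, text_encodings):
        # padding is assumed to be all-zero vectors, matching the commit's mask derivation
        mask = torch.any(text_encodings != 0., dim = -1)

        # trim to the maximum text length, then right-pad back up to it
        text_encodings = text_encodings[:, :self.max_text_len]
        mask = mask[:, :self.max_text_len]

        remainder = self.max_text_len - text_encodings.shape[-2]
        if remainder > 0:
            text_encodings = F.pad(text_encodings, (0, 0, 0, remainder), value = 0.)
            mask = F.pad(mask, (0, remainder), value = False)

        # wherever the mask is False (padding), substitute the learned per-position embedding
        return torch.where(
            rearrange(mask, 'b n -> b n 1'),
            text_encodings,
            self.null_text_embed.to(text_encodings.dtype)
        )
```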

@@ -782,17 +782,13 @@ class CausalTransformer(nn.Module):
         self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
         self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity()
 
-    def forward(
-        self,
-        x,
-        mask = None # we will need a mask here, due to variable length of the text encodings - also offer dalle1 strategy with padding token embeddings
-    ):
+    def forward(self, x):
         n, device = x.shape[1], x.device
 
         attn_bias = self.rel_pos_bias(n, n + 1, device = device)
 
         for attn, ff in self.layers:
-            x = attn(x, mask = mask, attn_bias = attn_bias) + x
+            x = attn(x, attn_bias = attn_bias) + x
             x = ff(x) + x
 
         out = self.norm(x)
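Context for the removal above: in attention blocks of this style, a key-padding mask is typically applied by filling masked positions' logits with a large negative value before the softmax. A rough sketch of what the dropped `mask` argument used to accomplish (not the repo's exact code):

```python
import torch
from einops import rearrange

# sketch of key-padding masking inside attention (the behavior removed here):
sim = torch.randn(2, 8, 10, 11)                  # (batch, heads, queries, keys) logits
mask = torch.ones(2, 11, dtype = torch.bool)     # True where a key position is real
sim = sim.masked_fill(~rearrange(mask, 'b j -> b 1 1 j'), -torch.finfo(sim.dtype).max)
attn = sim.softmax(dim = -1)                     # padded keys receive ~zero attention weight
```

Once every padded position holds a trained embedding instead of a zero vector, this exclusion is no longer needed, which is why `forward` drops the mask entirely.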
@@ -806,7 +802,7 @@ class DiffusionPriorNetwork(nn.Module):
         num_time_embeds = 1,
         num_image_embeds = 1,
         num_text_embeds = 1,
-        attend_all_text_encodings = True,
+        max_text_len = 256,
         **kwargs
     ):
         super().__init__()
@@ -832,7 +828,10 @@ class DiffusionPriorNetwork(nn.Module):
         self.learned_query = nn.Parameter(torch.randn(dim))
         self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
 
-        self.attend_all_text_encodings = attend_all_text_encodings
+        # dalle1 learned padding strategy
+
+        self.max_text_len = max_text_len
+        self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim))
 
     def forward_with_cond_scale(
         self,
@@ -872,11 +871,28 @@ class DiffusionPriorNetwork(nn.Module):
         if not exists(text_encodings):
             text_encodings = torch.empty((batch, 0, dim), device = device, dtype = dtype)
 
-        if self.attend_all_text_encodings:
-            mask = torch.ones((batch, text_encodings.shape[-2]), device = device, dtype = torch.bool)
-        else:
-            mask = torch.any(text_encodings != 0., dim = -1)
+        mask = torch.any(text_encodings != 0., dim = -1)
 
+        # replace any padding in the text encodings with learned padding tokens unique across position
+
+        text_encodings = text_encodings[:, :self.max_text_len]
+        mask = mask[:, :self.max_text_len]
+
+        text_len = text_encodings.shape[-2]
+        remainder = self.max_text_len - text_len
+
+        if remainder > 0:
+            text_encodings = F.pad(text_encodings, (0, 0, 0, remainder), value = 0.)
+            mask = F.pad(mask, (0, remainder), value = False)
+
+        null_text_embeds = self.null_text_embed.to(text_encodings.dtype)
+
+        text_encodings = torch.where(
+            rearrange(mask, 'b n -> b n 1'),
+            text_encodings,
+            null_text_embeds
+        )
 
         # classifier free guidance
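One detail worth calling out in the added lines: `F.pad`'s pad tuple is read from the last dimension backwards, so `(0, 0, 0, remainder)` leaves the feature dimension alone and appends `remainder` rows to the sequence dimension. A quick check:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 10, 512)
F.pad(x, (0, 0, 0, 6)).shape   # torch.Size([2, 16, 512]) - sequence dim grows, feature dim untouched
```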
@@ -910,7 +926,7 @@ class DiffusionPriorNetwork(nn.Module):
         # attend
 
-        tokens = self.causal_transformer(tokens, mask = mask)
+        tokens = self.causal_transformer(tokens)
 
         # get learned query, which should predict the image embedding (per DDPM timestep)
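Putting it together, here is a hypothetical usage of the `LearnedTextPadding` sketch from the top of this page: zero-vector padding (and any positions beyond `max_text_len`) come out as per-position learned embeddings, so the causal transformer can attend over the full sequence without a mask.

```python
padder = LearnedTextPadding(dim = 512, max_text_len = 256)  # hypothetical module from the sketch above
enc = torch.randn(2, 77, 512)
enc[:, 40:] = 0.            # simulate zero-padded positions in the text encodings
out = padder(enc)           # (2, 256, 512); positions 40+ now hold trained embeddings, no mask needed downstream
```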