give more surface area for attention in diffusion prior

This commit is contained in:
Phil Wang
2022-05-09 08:08:11 -07:00
parent dde51fd362
commit 53c189e46a
3 changed files with 24 additions and 7 deletions

View File

@@ -703,10 +703,24 @@ class DiffusionPriorNetwork(nn.Module):
self,
dim,
num_timesteps = None,
num_time_embeds = 1,
num_image_embeds = 1,
**kwargs
):
super().__init__()
self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
self.num_time_embeds = num_time_embeds
self.num_image_embeds = num_image_embeds
self.to_time_embeds = nn.Sequential(
nn.Embedding(num_timesteps, dim * num_time_embeds) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim * num_time_embeds)), # also offer a continuous version of timestep embeddings, with a 2 layer MLP
Rearrange('b (n d) -> b n d', n = num_time_embeds)
)
self.to_image_embeds = nn.Sequential(
nn.Linear(dim, dim * num_image_embeds),
Rearrange('b (n d) -> b n d', n = num_image_embeds)
)
self.learned_query = nn.Parameter(torch.randn(dim))
self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
@@ -736,10 +750,13 @@ class DiffusionPriorNetwork(nn.Module):
):
batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
num_time_embeds, num_image_embeds = self.num_time_embeds, self.num_image_embeds
# in section 2.2, last paragraph
# "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
text_embed, image_embed = rearrange_many((text_embed, image_embed), 'b d -> b 1 d')
text_embed = rearrange(text_embed, 'b d -> b 1 d')
image_embed = self.to_image_embeds(image_embed)
# make text encodings optional
# although the paper seems to suggest it is present <--
@@ -765,10 +782,10 @@ class DiffusionPriorNetwork(nn.Module):
# but let's just do it right
if exists(mask):
mask = F.pad(mask, (0, 3), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
attend_padding = 1 + num_time_embeds + num_image_embeds # 1 for learned queries + number of image embeds + time embeds
mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
time_embed = self.time_embeddings(diffusion_timesteps)
time_embed = rearrange(time_embed, 'b d -> b 1 d')
time_embed = self.to_time_embeds(diffusion_timesteps)
learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)