mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
give more surface area for attention in diffusion prior
This commit is contained in:
@@ -703,10 +703,24 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
self,
|
||||
dim,
|
||||
num_timesteps = None,
|
||||
num_time_embeds = 1,
|
||||
num_image_embeds = 1,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
|
||||
self.num_time_embeds = num_time_embeds
|
||||
self.num_image_embeds = num_image_embeds
|
||||
|
||||
self.to_time_embeds = nn.Sequential(
|
||||
nn.Embedding(num_timesteps, dim * num_time_embeds) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim * num_time_embeds)), # also offer a continuous version of timestep embeddings, with a 2 layer MLP
|
||||
Rearrange('b (n d) -> b n d', n = num_time_embeds)
|
||||
)
|
||||
|
||||
self.to_image_embeds = nn.Sequential(
|
||||
nn.Linear(dim, dim * num_image_embeds),
|
||||
Rearrange('b (n d) -> b n d', n = num_image_embeds)
|
||||
)
|
||||
|
||||
self.learned_query = nn.Parameter(torch.randn(dim))
|
||||
self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
|
||||
|
||||
@@ -736,10 +750,13 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
):
|
||||
batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
|
||||
|
||||
num_time_embeds, num_image_embeds = self.num_time_embeds, self.num_image_embeds
|
||||
|
||||
# in section 2.2, last paragraph
|
||||
# "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
|
||||
|
||||
text_embed, image_embed = rearrange_many((text_embed, image_embed), 'b d -> b 1 d')
|
||||
text_embed = rearrange(text_embed, 'b d -> b 1 d')
|
||||
image_embed = self.to_image_embeds(image_embed)
|
||||
|
||||
# make text encodings optional
|
||||
# although the paper seems to suggest it is present <--
|
||||
@@ -765,10 +782,10 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
# but let's just do it right
|
||||
|
||||
if exists(mask):
|
||||
mask = F.pad(mask, (0, 3), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
|
||||
attend_padding = 1 + num_time_embeds + num_image_embeds # 1 for learned queries + number of image embeds + time embeds
|
||||
mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
|
||||
|
||||
time_embed = self.time_embeddings(diffusion_timesteps)
|
||||
time_embed = rearrange(time_embed, 'b d -> b 1 d')
|
||||
time_embed = self.to_time_embeds(diffusion_timesteps)
|
||||
|
||||
learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user