deal the diffusion prior problem yet another blow

Phil Wang
2022-05-09 11:08:46 -07:00
parent 8647cb5e76
commit e46eaec817
2 changed files with 12 additions and 4 deletions


@@ -743,11 +743,18 @@ class DiffusionPriorNetwork(nn.Module):
         num_timesteps = None,
         num_time_embeds = 1,
         num_image_embeds = 1,
+        num_text_embeds = 1,
         **kwargs
     ):
         super().__init__()
         self.num_time_embeds = num_time_embeds
         self.num_image_embeds = num_image_embeds
+        self.num_text_embeds = num_text_embeds
 
+        self.to_text_embeds = nn.Sequential(
+            nn.Linear(dim, dim * num_text_embeds) if num_text_embeds > 1 else nn.Identity(),
+            Rearrange('b (n d) -> b n d', n = num_text_embeds)
+        )
+
         self.to_time_embeds = nn.Sequential(
             nn.Embedding(num_timesteps, dim * num_time_embeds) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim * num_time_embeds)), # also offer a continuous version of timestep embeddings, with a 2 layer MLP
@@ -755,7 +762,7 @@ class DiffusionPriorNetwork(nn.Module):
         )
 
         self.to_image_embeds = nn.Sequential(
-            nn.Linear(dim, dim * num_image_embeds),
+            nn.Linear(dim, dim * num_image_embeds) if num_image_embeds > 1 else nn.Identity(),
             Rearrange('b (n d) -> b n d', n = num_image_embeds)
         )
@@ -788,12 +795,12 @@ class DiffusionPriorNetwork(nn.Module):
     ):
         batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
 
-        num_time_embeds, num_image_embeds = self.num_time_embeds, self.num_image_embeds
+        num_time_embeds, num_image_embeds, num_text_embeds = self.num_time_embeds, self.num_image_embeds, self.num_text_embeds
 
         # in section 2.2, last paragraph
         # "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
 
-        text_embed = rearrange(text_embed, 'b d -> b 1 d')
+        text_embed = self.to_text_embeds(text_embed)
         image_embed = self.to_image_embeds(image_embed)
 
         # make text encodings optional
@@ -814,6 +821,7 @@ class DiffusionPriorNetwork(nn.Module):
             # whether text embedding is masked or not depends on the classifier free guidance conditional masking
 
+            keep_mask = repeat(keep_mask, 'b 1 -> b n', n = num_text_embeds)
             mask = torch.cat((mask, keep_mask), dim = 1)
 
         # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
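For reference, a minimal sketch (not part of this commit) of the projection pattern the prior network now applies to the text embedding: when num_text_embeds > 1, the single CLIP text embedding is expanded by a Linear layer and reshaped into that many tokens, and the classifier-free-guidance keep mask is repeated so every one of those tokens shares the same keep/drop decision. The dim, num_text_embeds, and batch values below are illustrative only.

import torch
from torch import nn
from einops import repeat
from einops.layers.torch import Rearrange

dim, num_text_embeds, batch = 512, 4, 2  # illustrative sizes, not the repo defaults

# same Linear + Rearrange pattern as to_text_embeds / to_image_embeds above
to_text_embeds = nn.Sequential(
    nn.Linear(dim, dim * num_text_embeds) if num_text_embeds > 1 else nn.Identity(),
    Rearrange('b (n d) -> b n d', n = num_text_embeds)
)

text_embed = torch.randn(batch, dim)        # single CLIP text embedding per sample
text_tokens = to_text_embeds(text_embed)    # -> (batch, num_text_embeds, dim)

keep_mask = torch.ones(batch, 1).bool()     # one keep/drop decision per sample
keep_mask = repeat(keep_mask, 'b 1 -> b n', n = num_text_embeds)  # broadcast over the new tokens

print(text_tokens.shape, keep_mask.shape)   # torch.Size([2, 4, 512]) torch.Size([2, 4])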


@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.2.0',
+  version = '0.2.1',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',