Compare commits

...

3 Commits
0.1.9 ... 0.2.0

Author SHA1 Message Date
Phil Wang
53c189e46a give more surface area for attention in diffusion prior 2022-05-09 08:08:11 -07:00
Phil Wang
dde51fd362 revert restriction for classifier free guidance for diffusion prior, given @crowsonkb advice 2022-05-07 20:55:41 -07:00
Nasir Khalid
2eac7996fa Additional image_embed metric (#75)
Added metric to track image_embed vs predicted_image_embed
2022-05-07 14:32:33 -07:00
4 changed files with 29 additions and 8 deletions

View File

@@ -966,6 +966,7 @@ Once built, images will be saved to the same directory the command is invoked
- [x] add convnext backbone for vqgan-vae (in addition to vit [vit-vqgan] + resnet)
- [x] make sure DDPMs can be run with traditional resnet blocks (but leave convnext as an option for experimentation)
- [x] make sure for the latter unets in the cascade, one can train on crops for learning super resolution (constrain the unet to be only convolutions in that case, or allow conv-like attention with rel pos bias)
- [x] offer setting in diffusion prior to split time and image embeddings into multiple tokens, configurable, for more surface area during attention
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
- [ ] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
@@ -981,7 +982,6 @@ Once built, images will be saved to the same directory the command is invoked
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
- [ ] make sure resnet hyperparameters can be configurable across unet depth (groups and expansion factor)
- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
- [ ] offer setting in diffusion prior to split time and image embeddings into multiple tokens, configurable, for more surface area during attention
- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
## Citations

View File

@@ -703,10 +703,24 @@ class DiffusionPriorNetwork(nn.Module):
self,
dim,
num_timesteps = None,
num_time_embeds = 1,
num_image_embeds = 1,
**kwargs
):
super().__init__()
self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
self.num_time_embeds = num_time_embeds
self.num_image_embeds = num_image_embeds
self.to_time_embeds = nn.Sequential(
nn.Embedding(num_timesteps, dim * num_time_embeds) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim * num_time_embeds)), # also offer a continuous version of timestep embeddings, with a 2 layer MLP
Rearrange('b (n d) -> b n d', n = num_time_embeds)
)
self.to_image_embeds = nn.Sequential(
nn.Linear(dim, dim * num_image_embeds),
Rearrange('b (n d) -> b n d', n = num_image_embeds)
)
self.learned_query = nn.Parameter(torch.randn(dim))
self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
@@ -736,10 +750,13 @@ class DiffusionPriorNetwork(nn.Module):
):
batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
num_time_embeds, num_image_embeds = self.num_time_embeds, self.num_image_embeds
# in section 2.2, last paragraph
# "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
text_embed, image_embed = rearrange_many((text_embed, image_embed), 'b d -> b 1 d')
text_embed = rearrange(text_embed, 'b d -> b 1 d')
image_embed = self.to_image_embeds(image_embed)
# make text encodings optional
# although the paper seems to suggest it is present <--
@@ -765,10 +782,10 @@ class DiffusionPriorNetwork(nn.Module):
# but let's just do it right
if exists(mask):
mask = F.pad(mask, (0, 3), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
attend_padding = 1 + num_time_embeds + num_image_embeds # 1 for learned queries + number of image embeds + time embeds
mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
time_embed = self.time_embeddings(diffusion_timesteps)
time_embed = rearrange(time_embed, 'b d -> b 1 d')
time_embed = self.to_time_embeds(diffusion_timesteps)
learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
@@ -834,7 +851,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
self.image_embed_dim = default(image_embed_dim, lambda: clip.dim_latent)
self.channels = default(image_channels, lambda: clip.image_channels)
self.cond_drop_prob = cond_drop_prob if not predict_x_start else 0.
self.cond_drop_prob = cond_drop_prob
self.condition_on_text_encodings = condition_on_text_encodings
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.1.9',
version = '0.2.0',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',

View File

@@ -93,6 +93,8 @@ def report_cosine_sims(diffusion_prior, image_reader, text_reader, train_set_siz
text_embed, predicted_image_embeddings).cpu().numpy()
unrelated_similarity = cos(
text_embed, predicted_unrelated_embeddings).cpu().numpy()
predicted_img_similarity = cos(
test_image_embeddings, predicted_image_embeddings).cpu().numpy()
wandb.log(
{"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity)})
@@ -100,6 +102,8 @@ def report_cosine_sims(diffusion_prior, image_reader, text_reader, train_set_siz
predicted_similarity)})
wandb.log({"CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(
unrelated_similarity)})
wandb.log({"CosineSimilarity(image_embed,predicted_image_embed)": np.mean(
predicted_img_similarity)})
return np.mean(predicted_similarity - original_similarity)