Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2026-02-12 03:24:22 +01:00

Compare commits

6 Commits
| Author | SHA1 | Date |
|---|---|---|
|  | dde51fd362 |  |
|  | 2eac7996fa |  |
|  | 4010aec033 |  |
|  | c87b84a259 |  |
|  | 8b05468653 |  |
|  | 830afd3c15 |  |
```diff
@@ -981,6 +981,8 @@ Once built, images will be saved to the same directory the command is invoked
 - [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
 - [ ] make sure resnet hyperparameters can be configurable across unet depth (groups and expansion factor)
 - [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
+- [ ] offer setting in diffusion prior to split time and image embeddings into multiple tokens, configurable, for more surface area during attention
+- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
 
 ## Citations
 
```
```diff
@@ -706,7 +706,7 @@ class DiffusionPriorNetwork(nn.Module):
         **kwargs
     ):
         super().__init__()
-        self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(Rearrange('b -> b 1'), MLP(1, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
+        self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
         self.learned_query = nn.Parameter(torch.randn(dim))
         self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
 
```
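This change replaces the raw-scalar input to the MLP (`Rearrange('b -> b 1')` into `MLP(1, dim)`) with sinusoidal features of the timestep, which gives the continuous-time branch a richer, scale-stable encoding before the 2-layer MLP. For reference, a minimal sketch of a `SinusoidalPosEmb` module of the shape this line assumes; the repository's own implementation may differ in detail:

```python
import math
import torch
from torch import nn

class SinusoidalPosEmb(nn.Module):
    # transformer-style sinusoidal embedding of a scalar timestep (sketch)
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        # t: (batch,) tensor of timesteps
        half_dim = self.dim // 2
        freqs = torch.exp(torch.arange(half_dim, device = t.device) * -(math.log(10000) / (half_dim - 1)))
        args = t[:, None].float() * freqs[None, :]            # (batch, half_dim)
        return torch.cat((args.sin(), args.cos()), dim = -1)  # (batch, dim)

# usage mirroring the new line above, where MLP is assumed to be the repo's 2-layer MLP helper:
# time_embeddings = nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim))
```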
```diff
@@ -800,7 +800,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
         image_size = None,
         image_channels = 3,
         timesteps = 1000,
-        cond_drop_prob = 0.2,
+        cond_drop_prob = 0.,
         loss_type = "l1",
         predict_x_start = True,
         beta_schedule = "cosine",
```
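`cond_drop_prob` is the probability of dropping the conditioning signal during training, which is what enables classifier-free guidance at sampling time; moving the default from `0.2` to `0.` means conditioning is never dropped unless the user opts in. A minimal sketch of the usual mechanism, with illustrative helper names rather than the repository's exact code:

```python
import torch

def prob_mask_like(shape, prob, device):
    # boolean mask, True with probability `prob` per element
    return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob

def drop_text_condition(text_embed, null_embed, cond_drop_prob):
    # replace the conditioning with a learned null embedding for a random
    # subset of the batch, so the model learns p(x) alongside p(x | c)
    batch = text_embed.shape[0]
    drop_mask = prob_mask_like((batch,), cond_drop_prob, text_embed.device)
    return torch.where(drop_mask[:, None], null_embed, text_embed)

# illustrative usage: with cond_drop_prob = 0., nothing is ever dropped
text_embed = torch.randn(8, 512)
null_embed = torch.zeros(512)
dropped = drop_text_condition(text_embed, null_embed, cond_drop_prob = 0.2)
```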
setup.py
```diff
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.1.7',
+  version = '0.1.10',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
```
```diff
@@ -93,6 +93,8 @@ def report_cosine_sims(diffusion_prior, image_reader, text_reader, train_set_siz
         text_embed, predicted_image_embeddings).cpu().numpy()
     unrelated_similarity = cos(
         text_embed, predicted_unrelated_embeddings).cpu().numpy()
+    predicted_img_similarity = cos(
+        test_image_embeddings, predicted_image_embeddings).cpu().numpy()
 
     wandb.log(
         {"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity)})
@@ -100,6 +102,8 @@ def report_cosine_sims(diffusion_prior, image_reader, text_reader, train_set_siz
         predicted_similarity)})
     wandb.log({"CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(
         unrelated_similarity)})
+    wandb.log({"CosineSimilarity(image_embed,predicted_image_embed)": np.mean(
+        predicted_img_similarity)})
 
     return np.mean(predicted_similarity - original_similarity)
```
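The added lines log one more diagnostic: the cosine similarity between the ground-truth test image embeddings and the image embeddings predicted by the diffusion prior. A self-contained sketch of that metric, using stand-in tensors in place of the script's embeddings:

```python
import torch
import numpy as np

cos = torch.nn.CosineSimilarity(dim = 1, eps = 1e-6)

# illustrative stand-ins for the embeddings produced in the training script
test_image_embeddings = torch.randn(32, 512)
predicted_image_embeddings = torch.randn(32, 512)

# row-wise cosine similarity between true and prior-predicted image embeddings,
# averaged to a single scalar for logging
predicted_img_similarity = cos(test_image_embeddings, predicted_image_embeddings).cpu().numpy()
print(float(np.mean(predicted_img_similarity)))
```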