Compare commits

..

4 Commits

Author SHA1 Message Date
Phil Wang
36c5079bd7 LazyLinear is not mature, make users pass in text_embed_dim if text conditioning is turned on 2022-05-15 18:56:52 -07:00
Phil Wang
4a4c7ac9e6 cond drop prob for diffusion prior network should default to 0 2022-05-15 18:47:45 -07:00
Phil Wang
fad7481479 todo 2022-05-15 17:00:25 -07:00
Phil Wang
123658d082 cite Ho et al, since cascading ddpm is now trainable 2022-05-15 16:56:53 -07:00
4 changed files with 14 additions and 8 deletions

View File

@@ -1065,6 +1065,7 @@ Once built, images will be saved to the same directory the command is invoked
- [ ] allow for unet to be able to condition non-cross attention style as well
- [ ] for all model classes with hyperparameters that changes the network architecture, make it requirement that they must expose a config property, and write a simple function that asserts that it restores the object correctly
- [ ] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
- [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89
## Citations
@@ -1153,4 +1154,13 @@ Once built, images will be saved to the same directory the command is invoked
}
```
```bibtex
@article{ho2021cascaded,
title = {Cascaded Diffusion Models for High Fidelity Image Generation},
author = {Ho, Jonathan and Saharia, Chitwan and Chan, William and Fleet, David J and Norouzi, Mohammad and Salimans, Tim},
journal = {arXiv preprint arXiv:2106.15282},
year = {2021}
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>

View File

@@ -794,7 +794,7 @@ class DiffusionPriorNetwork(nn.Module):
text_embed,
text_encodings = None,
mask = None,
cond_drop_prob = 0.2
cond_drop_prob = 0.
):
batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
@@ -1387,7 +1387,8 @@ class Unet(nn.Module):
self.text_to_cond = None
if cond_on_text_encodings:
self.text_to_cond = nn.LazyLinear(cond_dim) if not exists(text_embed_dim) else nn.Linear(text_embed_dim, cond_dim)
assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text_encodings is True'
self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)
# finer control over whether to condition on image embeddings and text encodings
# so one can have the latter unets in the cascading DDPMs only focus on super-resoluting

View File

@@ -335,11 +335,6 @@ class DecoderTrainer(nn.Module):
self.num_unets = len(self.decoder.unets)
self.use_ema = use_ema
if use_ema:
has_lazy_linear = any([type(module) == nn.LazyLinear for module in decoder.modules()])
assert not has_lazy_linear, 'you must set the text_embed_dim on your u-nets if you plan on doing automatic exponential moving average'
self.ema_unets = nn.ModuleList([])
self.amp = amp

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.2.35',
version = '0.2.37',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',