This commit is contained in:
Phil Wang
2022-04-18 11:52:25 -07:00
parent 0332eaa6ff
commit 6cddefad26

View File

@@ -109,7 +109,7 @@ unet = Unet(
# decoder, which contains the unet and clip # decoder, which contains the unet and clip
decoder = Decoder( decoder = Decoder(
net = unet, unet = unet,
clip = clip, clip = clip,
timesteps = 100, timesteps = 100,
cond_drop_prob = 0.2 cond_drop_prob = 0.2
@@ -182,9 +182,9 @@ loss.backward()
# now the diffusion prior can generate image embeddings from the text embeddings # now the diffusion prior can generate image embeddings from the text embeddings
``` ```
In the paper, they actually used a <a href="https://cascaded-diffusion.github.io/">recently discovered technique</a>, from <a href="http://www.jonathanho.me/">Jonathan Ho</a> himself (original author of DDPMs, from which DALL-E2 is based). In the paper, they actually used a <a href="https://cascaded-diffusion.github.io/">recently discovered technique</a>, from <a href="http://www.jonathanho.me/">Jonathan Ho</a> himself (original author of DDPMs, the core technique used in DALL-E v2) for high resolution image synthesis.
This can easily be used within the framework offered in this repository as so This can easily be used within this framework as so
```python ```python
import torch import torch
@@ -218,7 +218,7 @@ unet1 = Unet(
unet2 = Unet( unet2 = Unet(
dim = 16, dim = 16,
image_embed_dim = 512, image_embed_dim = 512,
lowres_cond = True, # subsequence unets must have this turned on (and first unet must have this turned off) lowres_cond = True, # subsequent unets must have this turned on (and first unet must have this turned off)
cond_dim = 128, cond_dim = 128,
channels = 3, channels = 3,
dim_mults = (1, 2, 4, 8, 16) dim_mults = (1, 2, 4, 8, 16)
@@ -412,6 +412,8 @@ Offer training wrappers
- [x] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions - [x] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions
- [ ] use an image resolution cutoff and do cross attention conditioning only if resources allow, and MLP + sum conditioning on rest - [ ] use an image resolution cutoff and do cross attention conditioning only if resources allow, and MLP + sum conditioning on rest
- [ ] make unet more configurable - [ ] make unet more configurable
- [ ] figure out some factory methods to make cascading unet instantiations less error-prone
- [ ] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
- [ ] train on a toy task, offer in colab - [ ] train on a toy task, offer in colab
- [ ] add attention to unet - apply some personal tricks with efficient attention - use the sparse attention mechanism from https://github.com/lucidrains/vit-pytorch#maxvit - [ ] add attention to unet - apply some personal tricks with efficient attention - use the sparse attention mechanism from https://github.com/lucidrains/vit-pytorch#maxvit
- [ ] build out latent diffusion architecture in separate file, as it is not faithful to dalle-2 (but offer it as as setting) - [ ] build out latent diffusion architecture in separate file, as it is not faithful to dalle-2 (but offer it as as setting)