mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2025-12-20 10:14:19 +01:00
readme
README.md (10 changed lines)
@@ -109,7 +109,7 @@ unet = Unet(
 # decoder, which contains the unet and clip

 decoder = Decoder(
-    net = unet,
+    unet = unet,
     clip = clip,
     timesteps = 100,
     cond_drop_prob = 0.2
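The change above renames the constructor keyword from `net` to `unet`, which is a breaking change for any caller passing the argument by keyword. A minimal sketch of the effect, using a hypothetical stand-in class rather than the real dalle2-pytorch `Decoder`:

```python
# Hypothetical stand-in for the renamed constructor argument; this is NOT the
# real dalle2-pytorch Decoder, only an illustration of the keyword rename.
class Decoder:
    def __init__(self, unet, clip, timesteps=100, cond_drop_prob=0.2):
        self.unet = unet
        self.clip = clip
        self.timesteps = timesteps
        self.cond_drop_prob = cond_drop_prob

# after the commit, callers must use the new keyword
decoder = Decoder(
    unet = "unet-placeholder",   # was `net = ...` before this commit
    clip = "clip-placeholder",
    timesteps = 100,
    cond_drop_prob = 0.2
)
```

Calling the new constructor with the old `net = ...` keyword would raise a `TypeError`, which is why the README example had to be updated in lockstep with the library.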
@@ -182,9 +182,9 @@ loss.backward()
 # now the diffusion prior can generate image embeddings from the text embeddings
 ```

-In the paper, they actually used a <a href="https://cascaded-diffusion.github.io/">recently discovered technique</a>, from <a href="http://www.jonathanho.me/">Jonathan Ho</a> himself (original author of DDPMs, from which DALL-E2 is based).
+In the paper, they actually used a <a href="https://cascaded-diffusion.github.io/">recently discovered technique</a>, from <a href="http://www.jonathanho.me/">Jonathan Ho</a> himself (original author of DDPMs, the core technique used in DALL-E v2) for high resolution image synthesis.

-This can easily be used within the framework offered in this repository as so
+This can easily be used within this framework as so

 ```python
 import torch
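The rewritten sentence refers to cascaded diffusion: a base model samples at low resolution, and each subsequent model upsamples conditioned on the previous stage's output. The control flow (not the diffusion math) can be sketched in plain Python with toy stand-ins, where an "image" is just its side length; all function names here are illustrative, not dalle2-pytorch API:

```python
# Toy sketch of a diffusion cascade: each super-resolution stage doubles the
# image side length, conditioned on the previous stage's output.

def base_stage():
    # stands in for the base unet, generating at the lowest resolution
    return 64

def super_res_stage(lowres_side):
    # stands in for a subsequent unet, conditioned on the low-res result
    return lowres_side * 2

def cascade(num_super_res_stages):
    side = base_stage()
    for _ in range(num_super_res_stages):
        side = super_res_stage(side)
    return side
```

With two super-resolution stages, `cascade(2)` goes 64 → 128 → 256, which mirrors how the `Decoder` class below chains multiple unets at increasing resolutions.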
@@ -218,7 +218,7 @@ unet1 = Unet(
 unet2 = Unet(
     dim = 16,
     image_embed_dim = 512,
-    lowres_cond = True, # subsequence unets must have this turned on (and first unet must have this turned off)
+    lowres_cond = True, # subsequent unets must have this turned on (and first unet must have this turned off)
     cond_dim = 128,
     channels = 3,
     dim_mults = (1, 2, 4, 8, 16)
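The corrected comment encodes an invariant: the first unet in the cascade has `lowres_cond = False`, and every later unet has it `True`. That invariant can be expressed as a small validation helper; `check_lowres_cond` is a hypothetical name for illustration, not part of dalle2-pytorch:

```python
# Hypothetical helper validating the lowres_cond invariant from the comment:
# takes each unet's lowres_cond flag, in cascade order.
def check_lowres_cond(flags):
    if not flags:
        raise ValueError("need at least one unet")
    if flags[0]:
        raise ValueError("first unet must have lowres_cond = False")
    for position, flag in enumerate(flags[1:], start=2):
        if not flag:
            raise ValueError(f"unet {position} must have lowres_cond = True")
    return True
```

A check like this is exactly the kind of guard the "factory methods to make cascading unet instantiations less error-prone" todo item further down seems aimed at.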
@@ -412,6 +412,8 @@ Offer training wrappers
 - [x] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions
 - [ ] use an image resolution cutoff and do cross attention conditioning only if resources allow, and MLP + sum conditioning on rest
 - [ ] make unet more configurable
+- [ ] figure out some factory methods to make cascading unet instantiations less error-prone
+- [ ] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
 - [ ] train on a toy task, offer in colab
 - [ ] add attention to unet - apply some personal tricks with efficient attention - use the sparse attention mechanism from https://github.com/lucidrains/vit-pytorch#maxvit
 - [ ] build out latent diffusion architecture in separate file, as it is not faithful to dalle-2 (but offer it as as setting)
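The new "offload unets not being trained" item can be sketched without any real accelerator: keep only the active unet on the device and move the rest to CPU, in the style of PyTorch's `Module.to`. Everything here (`ToyUnet`, `activate_unet`, the device strings) is a hypothetical stand-in, not dalle2-pytorch code:

```python
# Toy sketch of the offloading todo item: only the unet currently being
# trained lives on the accelerator; all others are moved to "cpu".
class ToyUnet:
    def __init__(self, name):
        self.name = name
        self.device = "cpu"

    def to(self, device):
        # mimics torch.nn.Module.to, which moves parameters between devices
        self.device = device
        return self

def activate_unet(unets, index, device="cuda"):
    for i, unet in enumerate(unets):
        unet.to(device if i == index else "cpu")
    return unets[index]
```

Training each resolution's unet separately then only ever pays accelerator memory for one unet at a time, which is the stated motivation of the todo item.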