Compare commits

...

4 Commits

| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| Phil Wang | 4a4c7ac9e6 | cond drop prob for diffusion prior network should default to 0 | 2022-05-15 18:47:45 -07:00 |
| Phil Wang | fad7481479 | todo | 2022-05-15 17:00:25 -07:00 |
| Phil Wang | 123658d082 | cite Ho et al, since cascading ddpm is now trainable | 2022-05-15 16:56:53 -07:00 |
| Phil Wang | 11d4e11f10 | allow for training unconditional ddpm or cascading ddpms | 2022-05-15 16:54:56 -07:00 |
3 changed files with 69 additions and 8 deletions

README.md

@@ -706,7 +706,7 @@ mock_image_embed = torch.randn(1, 512).cuda()
 images = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)
 ```
 
-## Training wrapper (wip)
+## Training wrapper
 
 ### Decoder Training
@@ -851,6 +851,57 @@ diffusion_prior_trainer.update() # this will update the optimizer as well as th
 image_embeds = diffusion_prior_trainer.sample(text) # (4, 512) - exponential moving averaged image embeddings
 ```
+
+## Bonus
+
+### Unconditional Training
+
+The repository also contains the means to train an unconditional DDPM, or even cascading DDPMs. You simply have to set `unconditional = True` in the `Decoder`, ex.
+
+```python
+import torch
+from dalle2_pytorch import Unet, Decoder
+
+# unets for the cascading ddpm
+
+unet1 = Unet(
+    dim = 128,
+    dim_mults = (1, 2, 4, 8)
+).cuda()
+
+unet2 = Unet(
+    dim = 32,
+    dim_mults = (1, 2, 4, 8, 16)
+).cuda()
+
+# decoder, which contains the unets
+
+decoder = Decoder(
+    unet = (unet1, unet2),
+    image_sizes = (256, 512), # first unet up to 256px, then second to 512px
+    timesteps = 1000,
+    unconditional = True
+).cuda()
+
+# mock images (get a lot of these)
+
+images = torch.randn(1, 3, 512, 512).cuda()
+
+# feed images into the decoder, training each unet in the cascade in turn
+
+for i in (1, 2):
+    loss = decoder(images, unet_number = i)
+    loss.backward()
+
+# do the above for many many many many steps
+# then it will learn to generate images
+
+images = decoder.sample(batch_size = 2) # (2, 3, 512, 512)
+```
+
 ## Dataloaders
 
 ### Decoder Dataloaders
 
 In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.
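The added example ends at "do the above for many many many many steps", leaving the outer loop to the reader. Below is a hypothetical sketch of that loop; the plain `Adam` optimizer, learning rate, and batch size are assumptions and not part of this diff (the training wrapper covered in the README's "Training wrapper" section handles this more completely).

```python
# hypothetical outer training loop for the unconditional example above;
# optimizer choice and hyperparameters are assumptions, not repository API
import torch
from torch.optim import Adam

# `decoder`, `unet1`, `unet2` as constructed in the example above
opt = Adam(decoder.parameters(), lr = 3e-4)

for step in range(100_000):
    images = torch.randn(4, 3, 512, 512).cuda()  # stand-in for a real image batch
    for unet_number in (1, 2):
        loss = decoder(images, unet_number = unet_number)
        loss.backward()
    opt.step()
    opt.zero_grad()
```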
@@ -1014,6 +1065,7 @@ Once built, images will be saved to the same directory the command is invoked
 - [ ] allow for unet to be able to condition non-cross attention style as well
 - [ ] for all model classes with hyperparameters that change the network architecture, make it a requirement that they expose a config property, and write a simple function that asserts that it restores the object correctly
 - [ ] for both diffusion prior and decoder, all exponential moving averaged models need to be saved and restored as well (as well as the step number)
+- [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89
 
 ## Citations
@@ -1102,4 +1154,13 @@ Once built, images will be saved to the same directory the command is invoked
 }
 ```
 
+```bibtex
+@article{ho2021cascaded,
+    title   = {Cascaded Diffusion Models for High Fidelity Image Generation},
+    author  = {Ho, Jonathan and Saharia, Chitwan and Chan, William and Fleet, David J and Norouzi, Mohammad and Salimans, Tim},
+    journal = {arXiv preprint arXiv:2106.15282},
+    year    = {2021}
+}
+```
+
 *Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>

dalle2_pytorch/dalle2_pytorch.py

@@ -794,7 +794,7 @@ class DiffusionPriorNetwork(nn.Module):
         text_embed,
         text_encodings = None,
         mask = None,
-        cond_drop_prob = 0.2
+        cond_drop_prob = 0.
     ):
         batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
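For context on what this default changes: `cond_drop_prob` is the classifier-free guidance dropout, i.e. the probability that the conditioning is masked out during a training step so the network also learns an unconditional score it can be guided against at sampling time. Defaulting it to `0.` makes the dropout opt-in rather than silently applied (the first commit in this compare). A minimal sketch of the masking idea, modeled on the repository's `prob_mask_like` helper; the exact signature here is an assumption:

```python
import torch

def prob_mask_like(shape, prob, device):
    # boolean mask that is True with probability `prob`; rows where it is False
    # have their conditioning swapped for learned null embeddings inside the network
    return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob

batch = 4
cond_drop_prob = 0.2  # must now be passed explicitly during training; the default is 0.
keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = 'cpu')
print(keep_mask)  # e.g. tensor([ True,  True, False,  True])
```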
@@ -1305,7 +1305,7 @@ class Unet(nn.Module):
         self,
         dim,
         *,
-        image_embed_dim,
+        image_embed_dim = None,
         text_embed_dim = None,
         cond_dim = None,
         num_image_tokens = 4,
@@ -1377,7 +1377,7 @@ class Unet(nn.Module):
         self.image_to_cond = nn.Sequential(
             nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
             Rearrange('b (n d) -> b n d', n = num_image_tokens)
-        ) if image_embed_dim != cond_dim else nn.Identity()
+        ) if cond_on_image_embeds and image_embed_dim != cond_dim else nn.Identity()
 
         self.norm_cond = nn.LayerNorm(cond_dim)
         self.norm_mid_cond = nn.LayerNorm(cond_dim)
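For orientation, the `image_to_cond` projection being gated here maps a single image embedding into `num_image_tokens` conditioning tokens for cross attention; the added `cond_on_image_embeds` guard means it is only built when image-embedding conditioning is actually wanted. A toy, self-contained sketch of just that reshape (einops' `Rearrange`, with made-up dimensions):

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange  # einops is a dependency of this repo

image_embed_dim, cond_dim, num_image_tokens = 512, 128, 4  # made-up dimensions

# project one embedding of size 512 into 4 conditioning tokens of size 128
to_tokens = nn.Sequential(
    nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
    Rearrange('b (n d) -> b n d', n = num_image_tokens)
)

tokens = to_tokens(torch.randn(2, image_embed_dim))
print(tokens.shape)  # torch.Size([2, 4, 128])
```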
@@ -1701,7 +1701,7 @@ class Decoder(BaseGaussianDiffusion):
         self.unconditional = unconditional
         assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
-        assert exists(clip) ^ exists(image_size), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
+        assert self.unconditional or (exists(clip) ^ exists(image_size)), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
 
         self.clip = None
         if exists(clip):
@@ -2036,12 +2036,12 @@
         times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
 
-        if not exists(image_embed):
+        if not exists(image_embed) and not self.unconditional:
             assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
             image_embed, _ = self.clip.embed_image(image)
 
         text_encodings = text_mask = None
 
-        if exists(text) and not exists(text_encodings):
+        if exists(text) and not exists(text_encodings) and not self.unconditional:
             assert exists(self.clip), 'if you are passing in raw text, you need to supply `clip` to the decoder'
             _, text_encodings, text_mask = self.clip.embed_text(text)
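Taken together, these four changes are what let a `Decoder` be constructed and trained with no `clip`, no image embeddings, and no text: the constructor assert is skipped and `forward` no longer tries to derive CLIP embeddings when `self.unconditional` is set. A minimal sketch, assuming a single unet and a one-element `image_sizes` tuple are accepted the same way as the cascading tuple in the README example above:

```python
import torch
from dalle2_pytorch import Unet, Decoder

unet = Unet(
    dim = 64,
    dim_mults = (1, 2, 4)   # image_embed_dim omitted, now that it defaults to None
).cuda()

decoder = Decoder(
    unet = unet,
    image_sizes = (128,),   # no clip supplied; valid once unconditional = True
    timesteps = 100,
    unconditional = True
).cuda()

images = torch.randn(4, 3, 128, 128).cuda()
loss = decoder(images, unet_number = 1)  # forward skips all CLIP embedding derivation
loss.backward()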

setup.py

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.2.34',
+  version = '0.2.36',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',