allow for training unconditional ddpm or cascading ddpms

This commit is contained in:
Phil Wang
2022-05-15 16:54:56 -07:00
parent 99778e12de
commit 11d4e11f10
3 changed files with 58 additions and 7 deletions

View File

@@ -706,7 +706,7 @@ mock_image_embed = torch.randn(1, 512).cuda()
images = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)
```
## Training wrapper (wip)
## Training wrapper
### Decoder Training
@@ -851,6 +851,57 @@ diffusion_prior_trainer.update() # this will update the optimizer as well as th
image_embeds = diffusion_prior_trainer.sample(text) # (4, 512) - exponential moving averaged image embeddings
```
## Bonus
### Unconditional Training
The repository also contains the means to train unconditional DDPM model, or even cascading DDPMs. You simply have to set `unconditional = True` in the `Decoder`
ex.
```python
import torch
from dalle2_pytorch import Unet, Decoder
# unet for the cascading ddpm
unet1 = Unet(
dim = 128,
dim_mults=(1, 2, 4, 8)
).cuda()
unet2 = Unet(
dim = 32,
dim_mults = (1, 2, 4, 8, 16)
).cuda()
# decoder, which contains the unets
decoder = Decoder(
unet = (unet1, unet2),
image_sizes = (256, 512), # first unet up to 256px, then second to 512px
timesteps = 1000,
unconditional = True
).cuda()
# mock images (get a lot of this)
images = torch.randn(1, 3, 512, 512).cuda()
# feed images into decoder
for i in (1, 2):
loss = decoder(images, unet_number = i)
loss.backward()
# do the above for many many many many steps
# then it will learn to generate images
images = decoder.sample(batch_size = 2) # (2, 3, 512, 512)
```
## Dataloaders
### Decoder Dataloaders
In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.

View File

@@ -1305,7 +1305,7 @@ class Unet(nn.Module):
self,
dim,
*,
image_embed_dim,
image_embed_dim = None,
text_embed_dim = None,
cond_dim = None,
num_image_tokens = 4,
@@ -1377,7 +1377,7 @@ class Unet(nn.Module):
self.image_to_cond = nn.Sequential(
nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
Rearrange('b (n d) -> b n d', n = num_image_tokens)
) if image_embed_dim != cond_dim else nn.Identity()
) if cond_on_image_embeds and image_embed_dim != cond_dim else nn.Identity()
self.norm_cond = nn.LayerNorm(cond_dim)
self.norm_mid_cond = nn.LayerNorm(cond_dim)
@@ -1701,7 +1701,7 @@ class Decoder(BaseGaussianDiffusion):
self.unconditional = unconditional
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
assert exists(clip) ^ exists(image_size), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
assert self.unconditional or (exists(clip) ^ exists(image_size)), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
self.clip = None
if exists(clip):
@@ -2036,12 +2036,12 @@ class Decoder(BaseGaussianDiffusion):
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
if not exists(image_embed):
if not exists(image_embed) and not self.unconditional:
assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
image_embed, _ = self.clip.embed_image(image)
text_encodings = text_mask = None
if exists(text) and not exists(text_encodings):
if exists(text) and not exists(text_encodings) and not self.unconditional:
assert exists(self.clip), 'if you are passing in raw text, you need to supply `clip` to the decoder'
_, text_encodings, text_mask = self.clip.embed_text(text)

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.2.34',
version = '0.2.35',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',