mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
allow for training unconditional ddpm or cascading ddpms
This commit is contained in:
53
README.md
53
README.md
@@ -706,7 +706,7 @@ mock_image_embed = torch.randn(1, 512).cuda()
|
||||
images = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)
|
||||
```
|
||||
|
||||
## Training wrapper (wip)
|
||||
## Training wrapper
|
||||
|
||||
### Decoder Training
|
||||
|
||||
@@ -851,6 +851,57 @@ diffusion_prior_trainer.update() # this will update the optimizer as well as th
|
||||
image_embeds = diffusion_prior_trainer.sample(text) # (4, 512) - exponential moving averaged image embeddings
|
||||
```
|
||||
|
||||
## Bonus
|
||||
|
||||
### Unconditional Training
|
||||
|
||||
The repository also contains the means to train unconditional DDPM model, or even cascading DDPMs. You simply have to set `unconditional = True` in the `Decoder`
|
||||
|
||||
ex.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from dalle2_pytorch import Unet, Decoder
|
||||
|
||||
# unet for the cascading ddpm
|
||||
|
||||
unet1 = Unet(
|
||||
dim = 128,
|
||||
dim_mults=(1, 2, 4, 8)
|
||||
).cuda()
|
||||
|
||||
unet2 = Unet(
|
||||
dim = 32,
|
||||
dim_mults = (1, 2, 4, 8, 16)
|
||||
).cuda()
|
||||
|
||||
# decoder, which contains the unets
|
||||
|
||||
decoder = Decoder(
|
||||
unet = (unet1, unet2),
|
||||
image_sizes = (256, 512), # first unet up to 256px, then second to 512px
|
||||
timesteps = 1000,
|
||||
unconditional = True
|
||||
).cuda()
|
||||
|
||||
# mock images (get a lot of this)
|
||||
|
||||
images = torch.randn(1, 3, 512, 512).cuda()
|
||||
|
||||
# feed images into decoder
|
||||
|
||||
for i in (1, 2):
|
||||
loss = decoder(images, unet_number = i)
|
||||
loss.backward()
|
||||
|
||||
# do the above for many many many many steps
|
||||
# then it will learn to generate images
|
||||
|
||||
images = decoder.sample(batch_size = 2) # (2, 3, 512, 512)
|
||||
```
|
||||
|
||||
## Dataloaders
|
||||
|
||||
### Decoder Dataloaders
|
||||
|
||||
In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.
|
||||
|
||||
@@ -1305,7 +1305,7 @@ class Unet(nn.Module):
|
||||
self,
|
||||
dim,
|
||||
*,
|
||||
image_embed_dim,
|
||||
image_embed_dim = None,
|
||||
text_embed_dim = None,
|
||||
cond_dim = None,
|
||||
num_image_tokens = 4,
|
||||
@@ -1377,7 +1377,7 @@ class Unet(nn.Module):
|
||||
self.image_to_cond = nn.Sequential(
|
||||
nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
|
||||
Rearrange('b (n d) -> b n d', n = num_image_tokens)
|
||||
) if image_embed_dim != cond_dim else nn.Identity()
|
||||
) if cond_on_image_embeds and image_embed_dim != cond_dim else nn.Identity()
|
||||
|
||||
self.norm_cond = nn.LayerNorm(cond_dim)
|
||||
self.norm_mid_cond = nn.LayerNorm(cond_dim)
|
||||
@@ -1701,7 +1701,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
self.unconditional = unconditional
|
||||
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
|
||||
|
||||
assert exists(clip) ^ exists(image_size), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
|
||||
assert self.unconditional or (exists(clip) ^ exists(image_size)), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
|
||||
|
||||
self.clip = None
|
||||
if exists(clip):
|
||||
@@ -2036,12 +2036,12 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
|
||||
|
||||
if not exists(image_embed):
|
||||
if not exists(image_embed) and not self.unconditional:
|
||||
assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
|
||||
image_embed, _ = self.clip.embed_image(image)
|
||||
|
||||
text_encodings = text_mask = None
|
||||
if exists(text) and not exists(text_encodings):
|
||||
if exists(text) and not exists(text_encodings) and not self.unconditional:
|
||||
assert exists(self.clip), 'if you are passing in raw text, you need to supply `clip` to the decoder'
|
||||
_, text_encodings, text_mask = self.clip.embed_text(text)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user