mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-14 14:14:21 +01:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
11d4e11f10 |
53
README.md
53
README.md
@@ -706,7 +706,7 @@ mock_image_embed = torch.randn(1, 512).cuda()
|
|||||||
images = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)
|
images = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)
|
||||||
```
|
```
|
||||||
|
|
||||||
## Training wrapper (wip)
|
## Training wrapper
|
||||||
|
|
||||||
### Decoder Training
|
### Decoder Training
|
||||||
|
|
||||||
@@ -851,6 +851,57 @@ diffusion_prior_trainer.update() # this will update the optimizer as well as th
|
|||||||
image_embeds = diffusion_prior_trainer.sample(text) # (4, 512) - exponential moving averaged image embeddings
|
image_embeds = diffusion_prior_trainer.sample(text) # (4, 512) - exponential moving averaged image embeddings
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Bonus
|
||||||
|
|
||||||
|
### Unconditional Training
|
||||||
|
|
||||||
|
The repository also contains the means to train an unconditional DDPM model, or even cascading DDPMs. You simply have to set `unconditional = True` in the `Decoder`.
|
||||||
|
|
||||||
|
For example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from dalle2_pytorch import Unet, Decoder
|
||||||
|
|
||||||
|
# unet for the cascading ddpm
|
||||||
|
|
||||||
|
unet1 = Unet(
|
||||||
|
dim = 128,
|
||||||
|
dim_mults=(1, 2, 4, 8)
|
||||||
|
).cuda()
|
||||||
|
|
||||||
|
unet2 = Unet(
|
||||||
|
dim = 32,
|
||||||
|
dim_mults = (1, 2, 4, 8, 16)
|
||||||
|
).cuda()
|
||||||
|
|
||||||
|
# decoder, which contains the unets
|
||||||
|
|
||||||
|
decoder = Decoder(
|
||||||
|
unet = (unet1, unet2),
|
||||||
|
image_sizes = (256, 512), # first unet up to 256px, then second to 512px
|
||||||
|
timesteps = 1000,
|
||||||
|
unconditional = True
|
||||||
|
).cuda()
|
||||||
|
|
||||||
|
# mock images (get a lot of these)
|
||||||
|
|
||||||
|
images = torch.randn(1, 3, 512, 512).cuda()
|
||||||
|
|
||||||
|
# feed images into decoder
|
||||||
|
|
||||||
|
for i in (1, 2):
|
||||||
|
loss = decoder(images, unet_number = i)
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
# do the above for many many many many steps
|
||||||
|
# then it will learn to generate images
|
||||||
|
|
||||||
|
images = decoder.sample(batch_size = 2) # (2, 3, 512, 512)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Dataloaders
|
||||||
|
|
||||||
### Decoder Dataloaders
|
### Decoder Dataloaders
|
||||||
|
|
||||||
In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.
|
In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.
|
||||||
|
|||||||
@@ -1305,7 +1305,7 @@ class Unet(nn.Module):
|
|||||||
self,
|
self,
|
||||||
dim,
|
dim,
|
||||||
*,
|
*,
|
||||||
image_embed_dim,
|
image_embed_dim = None,
|
||||||
text_embed_dim = None,
|
text_embed_dim = None,
|
||||||
cond_dim = None,
|
cond_dim = None,
|
||||||
num_image_tokens = 4,
|
num_image_tokens = 4,
|
||||||
@@ -1377,7 +1377,7 @@ class Unet(nn.Module):
|
|||||||
self.image_to_cond = nn.Sequential(
|
self.image_to_cond = nn.Sequential(
|
||||||
nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
|
nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
|
||||||
Rearrange('b (n d) -> b n d', n = num_image_tokens)
|
Rearrange('b (n d) -> b n d', n = num_image_tokens)
|
||||||
) if image_embed_dim != cond_dim else nn.Identity()
|
) if cond_on_image_embeds and image_embed_dim != cond_dim else nn.Identity()
|
||||||
|
|
||||||
self.norm_cond = nn.LayerNorm(cond_dim)
|
self.norm_cond = nn.LayerNorm(cond_dim)
|
||||||
self.norm_mid_cond = nn.LayerNorm(cond_dim)
|
self.norm_mid_cond = nn.LayerNorm(cond_dim)
|
||||||
@@ -1701,7 +1701,7 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
self.unconditional = unconditional
|
self.unconditional = unconditional
|
||||||
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
|
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
|
||||||
|
|
||||||
assert exists(clip) ^ exists(image_size), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
|
assert self.unconditional or (exists(clip) ^ exists(image_size)), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
|
||||||
|
|
||||||
self.clip = None
|
self.clip = None
|
||||||
if exists(clip):
|
if exists(clip):
|
||||||
@@ -2036,12 +2036,12 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
|
|
||||||
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
|
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
|
||||||
|
|
||||||
if not exists(image_embed):
|
if not exists(image_embed) and not self.unconditional:
|
||||||
assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
|
assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
|
||||||
image_embed, _ = self.clip.embed_image(image)
|
image_embed, _ = self.clip.embed_image(image)
|
||||||
|
|
||||||
text_encodings = text_mask = None
|
text_encodings = text_mask = None
|
||||||
if exists(text) and not exists(text_encodings):
|
if exists(text) and not exists(text_encodings) and not self.unconditional:
|
||||||
assert exists(self.clip), 'if you are passing in raw text, you need to supply `clip` to the decoder'
|
assert exists(self.clip), 'if you are passing in raw text, you need to supply `clip` to the decoder'
|
||||||
_, text_encodings, text_mask = self.clip.embed_text(text)
|
_, text_encodings, text_mask = self.clip.embed_text(text)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user