mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b22ccd9dd0 |
63
README.md
63
README.md
@@ -706,7 +706,7 @@ mock_image_embed = torch.randn(1, 512).cuda()
|
||||
images = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)
|
||||
```
|
||||
|
||||
## Training wrapper
|
||||
## Training wrapper (wip)
|
||||
|
||||
### Decoder Training
|
||||
|
||||
@@ -851,57 +851,6 @@ diffusion_prior_trainer.update() # this will update the optimizer as well as th
|
||||
image_embeds = diffusion_prior_trainer.sample(text) # (4, 512) - exponential moving averaged image embeddings
|
||||
```
|
||||
|
||||
## Bonus
|
||||
|
||||
### Unconditional Training
|
||||
|
||||
The repository also contains the means to train unconditional DDPM model, or even cascading DDPMs. You simply have to set `unconditional = True` in the `Decoder`
|
||||
|
||||
ex.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from dalle2_pytorch import Unet, Decoder
|
||||
|
||||
# unet for the cascading ddpm
|
||||
|
||||
unet1 = Unet(
|
||||
dim = 128,
|
||||
dim_mults=(1, 2, 4, 8)
|
||||
).cuda()
|
||||
|
||||
unet2 = Unet(
|
||||
dim = 32,
|
||||
dim_mults = (1, 2, 4, 8, 16)
|
||||
).cuda()
|
||||
|
||||
# decoder, which contains the unets
|
||||
|
||||
decoder = Decoder(
|
||||
unet = (unet1, unet2),
|
||||
image_sizes = (256, 512), # first unet up to 256px, then second to 512px
|
||||
timesteps = 1000,
|
||||
unconditional = True
|
||||
).cuda()
|
||||
|
||||
# mock images (get a lot of this)
|
||||
|
||||
images = torch.randn(1, 3, 512, 512).cuda()
|
||||
|
||||
# feed images into decoder
|
||||
|
||||
for i in (1, 2):
|
||||
loss = decoder(images, unet_number = i)
|
||||
loss.backward()
|
||||
|
||||
# do the above for many many many many steps
|
||||
# then it will learn to generate images
|
||||
|
||||
images = decoder.sample(batch_size = 2) # (2, 3, 512, 512)
|
||||
```
|
||||
|
||||
## Dataloaders
|
||||
|
||||
### Decoder Dataloaders
|
||||
|
||||
In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.
|
||||
@@ -1065,7 +1014,6 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
- [ ] allow for unet to be able to condition non-cross attention style as well
|
||||
- [ ] for all model classes with hyperparameters that changes the network architecture, make it requirement that they must expose a config property, and write a simple function that asserts that it restores the object correctly
|
||||
- [ ] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
|
||||
- [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89
|
||||
|
||||
## Citations
|
||||
|
||||
@@ -1154,13 +1102,4 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{ho2021cascaded,
|
||||
title = {Cascaded Diffusion Models for High Fidelity Image Generation},
|
||||
author = {Ho, Jonathan and Saharia, Chitwan and Chan, William and Fleet, David J and Norouzi, Mohammad and Salimans, Tim},
|
||||
journal = {arXiv preprint arXiv:2106.15282},
|
||||
year = {2021}
|
||||
}
|
||||
```
|
||||
|
||||
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
|
||||
|
||||
@@ -794,7 +794,7 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
text_embed,
|
||||
text_encodings = None,
|
||||
mask = None,
|
||||
cond_drop_prob = 0.
|
||||
cond_drop_prob = 0.2
|
||||
):
|
||||
batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
|
||||
|
||||
@@ -1305,7 +1305,7 @@ class Unet(nn.Module):
|
||||
self,
|
||||
dim,
|
||||
*,
|
||||
image_embed_dim = None,
|
||||
image_embed_dim,
|
||||
text_embed_dim = None,
|
||||
cond_dim = None,
|
||||
num_image_tokens = 4,
|
||||
@@ -1377,7 +1377,7 @@ class Unet(nn.Module):
|
||||
self.image_to_cond = nn.Sequential(
|
||||
nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
|
||||
Rearrange('b (n d) -> b n d', n = num_image_tokens)
|
||||
) if cond_on_image_embeds and image_embed_dim != cond_dim else nn.Identity()
|
||||
) if image_embed_dim != cond_dim else nn.Identity()
|
||||
|
||||
self.norm_cond = nn.LayerNorm(cond_dim)
|
||||
self.norm_mid_cond = nn.LayerNorm(cond_dim)
|
||||
@@ -1701,7 +1701,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
self.unconditional = unconditional
|
||||
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
|
||||
|
||||
assert self.unconditional or (exists(clip) ^ exists(image_size)), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
|
||||
assert exists(clip) ^ exists(image_size), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
|
||||
|
||||
self.clip = None
|
||||
if exists(clip):
|
||||
@@ -2036,12 +2036,12 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
|
||||
|
||||
if not exists(image_embed) and not self.unconditional:
|
||||
if not exists(image_embed):
|
||||
assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
|
||||
image_embed, _ = self.clip.embed_image(image)
|
||||
|
||||
text_encodings = text_mask = None
|
||||
if exists(text) and not exists(text_encodings) and not self.unconditional:
|
||||
if exists(text) and not exists(text_encodings):
|
||||
assert exists(self.clip), 'if you are passing in raw text, you need to supply `clip` to the decoder'
|
||||
_, text_encodings, text_mask = self.clip.embed_text(text)
|
||||
|
||||
|
||||
@@ -52,17 +52,12 @@ def groupby_prefix_and_trim(prefix, d):
|
||||
def cast_torch_tensor(fn):
|
||||
@wraps(fn)
|
||||
def inner(model, *args, **kwargs):
|
||||
device = kwargs.pop('_device', next(model.parameters()).device)
|
||||
cast_device = kwargs.pop('_cast_device', True)
|
||||
|
||||
device = next(model.parameters()).device
|
||||
kwargs_keys = kwargs.keys()
|
||||
all_args = (*args, *kwargs.values())
|
||||
split_kwargs_index = len(all_args) - len(kwargs_keys)
|
||||
all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args))
|
||||
|
||||
if cast_device:
|
||||
all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
|
||||
|
||||
all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
|
||||
args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
|
||||
kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user