mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
118 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d0c11b30b0 | ||
|
|
86e2d5ba84 | ||
|
|
0d82dff9c5 | ||
|
|
8bbc956ff1 | ||
|
|
22019fddeb | ||
|
|
6fb7e91343 | ||
|
|
ba58ae0bf2 | ||
|
|
1cc5d0afa7 | ||
|
|
59fa101c4d | ||
|
|
916ece164c | ||
|
|
cbaadb6931 | ||
|
|
083508ff8e | ||
|
|
7762edd0ff | ||
|
|
de5e628773 | ||
|
|
1b4046b039 | ||
|
|
27f19ba7fa | ||
|
|
8f38339c2b | ||
|
|
6b9b4b9e5e | ||
|
|
44e09d5a4d | ||
|
|
34806663e3 | ||
|
|
dc816b1b6e | ||
|
|
05192ffac4 | ||
|
|
9440411954 | ||
|
|
981d407792 | ||
|
|
7c5477b26d | ||
|
|
be3bb868bf | ||
|
|
451de34871 | ||
|
|
f22e8c8741 | ||
|
|
87432e93ad | ||
|
|
d167378401 | ||
|
|
2d67d5821e | ||
|
|
748c7fe7af | ||
|
|
80046334ad | ||
|
|
36fb46a95e | ||
|
|
07abfcf45b | ||
|
|
2e35a9967d | ||
|
|
406e75043f | ||
|
|
9646dfc0e6 | ||
|
|
62043acb2f | ||
|
|
417ff808e6 | ||
|
|
f3d7e226ba | ||
|
|
48a1302428 | ||
|
|
ccaa46b81b | ||
|
|
76d08498cc | ||
|
|
f9423d308b | ||
|
|
06c65b60d2 | ||
|
|
4145474bab | ||
|
|
4b912a38c6 | ||
|
|
f97e55ec6b | ||
|
|
291377bb9c | ||
|
|
7f120a8b56 | ||
|
|
8c003ab1e1 | ||
|
|
723bf0abba | ||
|
|
d88c7ba56c | ||
|
|
3676a8ce78 | ||
|
|
da8e99ada0 | ||
|
|
6afb886cf4 | ||
|
|
c7fe4f2f44 | ||
|
|
a2ee3fa3cc | ||
|
|
a58a370d75 | ||
|
|
1662bbf226 | ||
|
|
5be1f57448 | ||
|
|
c52ce58e10 | ||
|
|
a34f60962a | ||
|
|
0b40cbaa54 | ||
|
|
f141144a6d | ||
|
|
f988207718 | ||
|
|
b2073219f0 | ||
|
|
cc0f7a935c | ||
|
|
95a512cb65 | ||
|
|
972ee973bc | ||
|
|
79e2a3bc77 | ||
|
|
544cdd0b29 | ||
|
|
349aaca56f | ||
|
|
3ee3c56d2a | ||
|
|
cd26c6b17d | ||
|
|
775abc4df6 | ||
|
|
11b1d533a0 | ||
|
|
e76e89f9eb | ||
|
|
bb3ff0ac67 | ||
|
|
1ec4dbe64f | ||
|
|
e0835acca9 | ||
|
|
e055793e5d | ||
|
|
1d9ef99288 | ||
|
|
bdd62c24b3 | ||
|
|
1f1557c614 | ||
|
|
1a217e99e3 | ||
|
|
7ea314e2f0 | ||
|
|
4173e88121 | ||
|
|
3dae43fa0e | ||
|
|
a598820012 | ||
|
|
4878762627 | ||
|
|
47ae17b36e | ||
|
|
b7e22f7da0 | ||
|
|
68de937aac | ||
|
|
097afda606 | ||
|
|
5c520db825 | ||
|
|
3070610231 | ||
|
|
870aeeca62 | ||
|
|
f28dc6dc01 | ||
|
|
081d8d3484 | ||
|
|
a71f693a26 | ||
|
|
d7bc5fbedd | ||
|
|
8c823affff | ||
|
|
ec7cab01d9 | ||
|
|
46be8c32d3 | ||
|
|
900f086a6d | ||
|
|
b3e646fd3b | ||
|
|
6a59c7093d | ||
|
|
a6cdbe0b9c | ||
|
|
e928ae5c34 | ||
|
|
1bd8a7835a | ||
|
|
f33453df9f | ||
|
|
1e4bb2bafb | ||
|
|
ee75515c7d | ||
|
|
ec68243479 | ||
|
|
3afdcdfe86 | ||
|
|
b9a908ff75 |
2
.github/FUNDING.yml
vendored
2
.github/FUNDING.yml
vendored
@@ -1 +1 @@
|
||||
github: [lucidrains]
|
||||
github: [nousr, Veldrovive, lucidrains]
|
||||
|
||||
33
.github/workflows/ci.yml
vendored
Normal file
33
.github/workflows/ci.yml
vendored
Normal file
@@ -0,0 +1,33 @@
|
||||
name: Continuous integration
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
tests:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [3.8]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install
|
||||
run: |
|
||||
python3 -m venv .env
|
||||
source .env/bin/activate
|
||||
make install
|
||||
- name: Tests
|
||||
run: |
|
||||
source .env/bin/activate
|
||||
make test
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -136,3 +136,5 @@ dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
.tracker_data
|
||||
*.pth
|
||||
|
||||
6
Makefile
Normal file
6
Makefile
Normal file
@@ -0,0 +1,6 @@
|
||||
install:
|
||||
pip install -U pip
|
||||
pip install -e .
|
||||
|
||||
test:
|
||||
CUDA_VISIBLE_DEVICES= python train_decoder.py --config_file configs/train_decoder_config.test.json
|
||||
228
README.md
228
README.md
@@ -44,9 +44,12 @@ This library would not have gotten to this working state without the help of
|
||||
- <a href="https://github.com/krish240574">Kumar</a> for working on the initial diffusion training script
|
||||
- <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management
|
||||
- <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying of critical bugs
|
||||
- <a href="https://github.com/marunine">Marunine</a> for identifying issues with resizing of the low resolution conditioner, when training the upsampler, in addition to various other bug fixes
|
||||
- <a href="https://github.com/malumadev">MalumaDev</a> for proposing the use of pixel shuffle upsampler for fixing checkboard artifacts
|
||||
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice
|
||||
- <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship
|
||||
- <a href="https://huggingface.co">🤗 Huggingface</a> and in particular <a href="https://github.com/sgugger">Sylvain</a> for the <a href="https://github.com/huggingface/accelerate">Accelerate</a> library
|
||||
- <a href="https://github.com/arogozhnikov">Alex</a> for <a href="https://github.com/arogozhnikov/einops">einops</a>, indispensable tool for tensor manipulation
|
||||
|
||||
... and many others. Thank you! 🙏
|
||||
|
||||
@@ -354,7 +357,8 @@ prior_network = DiffusionPriorNetwork(
|
||||
diffusion_prior = DiffusionPrior(
|
||||
net = prior_network,
|
||||
clip = clip,
|
||||
timesteps = 100,
|
||||
timesteps = 1000,
|
||||
sample_timesteps = 64,
|
||||
cond_drop_prob = 0.2
|
||||
).cuda()
|
||||
|
||||
@@ -368,6 +372,7 @@ loss.backward()
|
||||
unet1 = Unet(
|
||||
dim = 128,
|
||||
image_embed_dim = 512,
|
||||
text_embed_dim = 512,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
dim_mults=(1, 2, 4, 8),
|
||||
@@ -392,7 +397,7 @@ decoder = Decoder(
|
||||
).cuda()
|
||||
|
||||
for unet_number in (1, 2):
|
||||
loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
|
||||
loss = decoder(images, text = text, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
|
||||
loss.backward()
|
||||
|
||||
# do above for many steps
|
||||
@@ -418,7 +423,7 @@ For the layperson, no worries, training will all be automated into a CLI tool, a
|
||||
|
||||
## Training on Preprocessed CLIP Embeddings
|
||||
|
||||
It is likely, when scaling up, that you would first preprocess your images and text into corresponding embeddings before training the prior network. You can do so easily by simply passing in `image_embed`, `text_embed`, and optionally `text_encodings` and `text_mask`
|
||||
It is likely, when scaling up, that you would first preprocess your images and text into corresponding embeddings before training the prior network. You can do so easily by simply passing in `image_embed`, `text_embed`, and optionally `text_encodings`
|
||||
|
||||
Working example below
|
||||
|
||||
@@ -581,7 +586,9 @@ unet1 = Unet(
|
||||
image_embed_dim = 512,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
dim_mults=(1, 2, 4, 8)
|
||||
dim_mults=(1, 2, 4, 8),
|
||||
text_embed_dim = 512,
|
||||
cond_on_text_encodings = True # set to True for any unets that need to be conditioned on text encodings (ex. first unet in cascade)
|
||||
).cuda()
|
||||
|
||||
unet2 = Unet(
|
||||
@@ -596,14 +603,14 @@ decoder = Decoder(
|
||||
unet = (unet1, unet2),
|
||||
image_sizes = (128, 256),
|
||||
clip = clip,
|
||||
timesteps = 100,
|
||||
timesteps = 1000,
|
||||
sample_timesteps = (250, 27),
|
||||
image_cond_drop_prob = 0.1,
|
||||
text_cond_drop_prob = 0.5,
|
||||
condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
|
||||
text_cond_drop_prob = 0.5
|
||||
).cuda()
|
||||
|
||||
for unet_number in (1, 2):
|
||||
loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
|
||||
loss = decoder(images, text = text, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
|
||||
loss.backward()
|
||||
|
||||
# do above for many steps
|
||||
@@ -621,8 +628,96 @@ images = dalle2(
|
||||
# save your image (in this example, of size 256x256)
|
||||
```
|
||||
|
||||
Alternatively, you can also use <a href="https://github.com/mlfoundations/open_clip">Open Clip</a>
|
||||
|
||||
```bash
|
||||
$ pip install open-clip-torch
|
||||
```
|
||||
|
||||
```python
|
||||
from dalle2_pytorch import OpenClipAdapter
|
||||
|
||||
clip = OpenClipAdapter()
|
||||
```
|
||||
|
||||
Now you'll just have to worry about training the Prior and the Decoder!
|
||||
|
||||
## Inpainting
|
||||
|
||||
Inpainting is also built into the `Decoder`. You simply have to pass in the `inpaint_image` and `inpaint_mask` (boolean tensor where `True` indicates which regions of the inpaint image to keep)
|
||||
|
||||
This repository uses the formulation put forth by <a href="https://arxiv.org/abs/2201.09865">Lugmayr et al. in Repaint</a>
|
||||
|
||||
```python
|
||||
import torch
|
||||
from dalle2_pytorch import Unet, Decoder, CLIP
|
||||
|
||||
# trained clip from step 1
|
||||
|
||||
clip = CLIP(
|
||||
dim_text = 512,
|
||||
dim_image = 512,
|
||||
dim_latent = 512,
|
||||
num_text_tokens = 49408,
|
||||
text_enc_depth = 6,
|
||||
text_seq_len = 256,
|
||||
text_heads = 8,
|
||||
visual_enc_depth = 6,
|
||||
visual_image_size = 256,
|
||||
visual_patch_size = 32,
|
||||
visual_heads = 8
|
||||
).cuda()
|
||||
|
||||
# 2 unets for the decoder (a la cascading DDPM)
|
||||
|
||||
unet = Unet(
|
||||
dim = 16,
|
||||
image_embed_dim = 512,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
dim_mults = (1, 1, 1, 1)
|
||||
).cuda()
|
||||
|
||||
|
||||
# decoder, which contains the unet(s) and clip
|
||||
|
||||
decoder = Decoder(
|
||||
clip = clip,
|
||||
unet = (unet,), # insert both unets in order of low resolution to highest resolution (you can have as many stages as you want here)
|
||||
image_sizes = (256,), # resolutions, 256 for first unet, 512 for second. these must be unique and in ascending order (matches with the unets passed in)
|
||||
timesteps = 1000,
|
||||
image_cond_drop_prob = 0.1,
|
||||
text_cond_drop_prob = 0.5
|
||||
).cuda()
|
||||
|
||||
# mock images (get a lot of this)
|
||||
|
||||
images = torch.randn(4, 3, 256, 256).cuda()
|
||||
|
||||
# feed images into decoder, specifying which unet you want to train
|
||||
# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme
|
||||
|
||||
loss = decoder(images, unet_number = 1)
|
||||
loss.backward()
|
||||
|
||||
# do the above for many steps for both unets
|
||||
|
||||
mock_image_embed = torch.randn(1, 512).cuda()
|
||||
|
||||
# then to do inpainting
|
||||
|
||||
inpaint_image = torch.randn(1, 3, 256, 256).cuda() # (batch, channels, height, width)
|
||||
inpaint_mask = torch.ones(1, 256, 256).bool().cuda() # (batch, height, width)
|
||||
|
||||
inpainted_images = decoder.sample(
|
||||
image_embed = mock_image_embed,
|
||||
inpaint_image = inpaint_image, # just pass in the inpaint image
|
||||
inpaint_mask = inpaint_mask # and the mask
|
||||
)
|
||||
|
||||
inpainted_images.shape # (1, 3, 256, 256)
|
||||
```
|
||||
|
||||
## Experimental
|
||||
|
||||
### DALL-E2 with Latent Diffusion
|
||||
@@ -779,25 +874,23 @@ unet1 = Unet(
|
||||
text_embed_dim = 512,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
dim_mults=(1, 2, 4, 8)
|
||||
dim_mults=(1, 2, 4, 8),
|
||||
cond_on_text_encodings = True,
|
||||
).cuda()
|
||||
|
||||
unet2 = Unet(
|
||||
dim = 16,
|
||||
image_embed_dim = 512,
|
||||
text_embed_dim = 512,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
dim_mults = (1, 2, 4, 8, 16),
|
||||
cond_on_text_encodings = True
|
||||
).cuda()
|
||||
|
||||
decoder = Decoder(
|
||||
unet = (unet1, unet2),
|
||||
image_sizes = (128, 256),
|
||||
clip = clip,
|
||||
timesteps = 1000,
|
||||
condition_on_text_encodings = True
|
||||
timesteps = 1000
|
||||
).cuda()
|
||||
|
||||
decoder_trainer = DecoderTrainer(
|
||||
@@ -822,8 +915,8 @@ for unet_number in (1, 2):
|
||||
# after much training
|
||||
# you can sample from the exponentially moving averaged unets as so
|
||||
|
||||
mock_image_embed = torch.randn(4, 512).cuda()
|
||||
images = decoder_trainer.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
|
||||
mock_image_embed = torch.randn(32, 512).cuda()
|
||||
images = decoder_trainer.sample(image_embed = mock_image_embed, text = text) # (4, 3, 256, 256)
|
||||
```
|
||||
|
||||
### Diffusion Prior Training
|
||||
@@ -986,52 +1079,11 @@ dataset = ImageEmbeddingDataset(
|
||||
)
|
||||
```
|
||||
|
||||
### Scripts (wip)
|
||||
### Scripts
|
||||
|
||||
#### `train_diffusion_prior.py`
|
||||
|
||||
This script allows training the DiffusionPrior on pre-computed text and image embeddings. The working example below elucidates this process.
|
||||
Please note that the script internally passes text_embed and image_embed to the DiffusionPrior, unlike the example below.
|
||||
|
||||
#### Usage
|
||||
|
||||
```bash
|
||||
$ python train_diffusion_prior.py
|
||||
```
|
||||
|
||||
The most significant parameters for the script are as follows:
|
||||
|
||||
- `image-embed-url`, default = `"https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/"`
|
||||
|
||||
- `text-embed-url`, default = `"https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/"`
|
||||
|
||||
- `image-embed-dim`, default = `768` - 768 corresponds to the ViT iL/14 embedding size,change it to what your chosen ViT generates
|
||||
|
||||
- `learning-rate`, default = `1.1e-4`
|
||||
|
||||
- `weight-decay`, default = `6.02e-2`
|
||||
|
||||
- `max-grad-norm`, default = `0.5`
|
||||
|
||||
- `batch-size`, default = `10 ** 4`
|
||||
|
||||
- `num-epochs`, default = `5`
|
||||
|
||||
- `clip`, default = `None` # Signals the prior to use pre-computed embeddings
|
||||
|
||||
## CLI (wip)
|
||||
|
||||
```bash
|
||||
$ dream 'sharing a sunset at the summit of mount everest with my dog'
|
||||
```
|
||||
|
||||
Once built, images will be saved to the same directory the command is invoked
|
||||
|
||||
<a href="https://github.com/lucidrains/big-sleep">template</a>
|
||||
|
||||
## Training CLI (wip)
|
||||
|
||||
<a href="https://github.com/lucidrains/stylegan2-pytorch">template</a>
|
||||
For detailed information on training the diffusion prior, please refer to the [dedicated readme](prior.md)
|
||||
|
||||
## Todo
|
||||
|
||||
@@ -1070,11 +1122,12 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
- [x] bring in skip-layer excitations (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training (doesnt work well)
|
||||
- [x] test out grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697 (keeping, seems to be fine)
|
||||
- [x] allow for unet to be able to condition non-cross attention style as well
|
||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
|
||||
- [ ] speed up inference, read up on papers (ddim or diffusion-gan, etc)
|
||||
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
|
||||
- [x] speed up inference, read up on papers (ddim)
|
||||
- [x] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
|
||||
- [x] add the final combination of upsample feature maps, used in unet squared, seems to have an effect in local experiments
|
||||
- [ ] consider elucidated dalle2 https://arxiv.org/abs/2206.00364
|
||||
- [ ] add simple outpainting, text-guided 2x size the image for starters
|
||||
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
|
||||
- [ ] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
|
||||
|
||||
## Citations
|
||||
|
||||
@@ -1192,4 +1245,55 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{Lugmayr2022RePaintIU,
|
||||
title = {RePaint: Inpainting using Denoising Diffusion Probabilistic Models},
|
||||
author = {Andreas Lugmayr and Martin Danelljan and Andr{\'e}s Romero and Fisher Yu and Radu Timofte and Luc Van Gool},
|
||||
journal = {ArXiv},
|
||||
year = {2022},
|
||||
volume = {abs/2201.09865}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{chen2022analog,
|
||||
title = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning},
|
||||
author = {Ting Chen and Ruixiang Zhang and Geoffrey Hinton},
|
||||
year = {2022},
|
||||
eprint = {2208.04202},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{Qiao2019WeightS,
|
||||
title = {Weight Standardization},
|
||||
author = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Loddon Yuille},
|
||||
journal = {ArXiv},
|
||||
year = {2019},
|
||||
volume = {abs/1903.10520}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@inproceedings{rogozhnikov2022einops,
|
||||
title = {Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation},
|
||||
author = {Alex Rogozhnikov},
|
||||
booktitle = {International Conference on Learning Representations},
|
||||
year = {2022},
|
||||
url = {https://openreview.net/forum?id=oapKSVM2bcj}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{Sunkara2022NoMS,
|
||||
title = {No More Strided Convolutions or Pooling: A New CNN Building Block for Low-Resolution Images and Small Objects},
|
||||
author = {Raja Sunkara and Tie Luo},
|
||||
journal = {ArXiv},
|
||||
year = {2022},
|
||||
volume = {abs/2208.03641}
|
||||
}
|
||||
```
|
||||
|
||||
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
|
||||
|
||||
@@ -30,6 +30,7 @@ Defines the configuration options for the decoder model. The unets defined above
|
||||
| `loss_type` | No | `l2` | The loss function. Options are `l1`, `huber`, or `l2`. |
|
||||
| `beta_schedule` | No | `cosine` | The noising schedule. Options are `cosine`, `linear`, `quadratic`, `jsd`, or `sigmoid`. |
|
||||
| `learned_variance` | No | `True` | Whether to learn the variance. |
|
||||
| `clip` | No | `None` | The clip model to use if embeddings are being generated on the fly. Takes keys `make` and `model` with defaults `openai` and `ViT-L/14`. |
|
||||
|
||||
Any parameter from the `Decoder` constructor can also be given here.
|
||||
|
||||
@@ -39,7 +40,8 @@ Settings for creation of the dataloaders.
|
||||
| Option | Required | Default | Description |
|
||||
| ------ | -------- | ------- | ----------- |
|
||||
| `webdataset_base_url` | Yes | N/A | The url of a shard in the webdataset with the shard replaced with `{}`[^1]. |
|
||||
| `embeddings_url` | No | N/A | The url of the folder containing embeddings shards. Not required if embeddings are in webdataset. |
|
||||
| `img_embeddings_url` | No | `None` | The url of the folder containing image embeddings shards. Not required if embeddings are in webdataset or clip is being used. |
|
||||
| `text_embeddings_url` | No | `None` | The url of the folder containing text embeddings shards. Not required if embeddings are in webdataset or clip is being used. |
|
||||
| `num_workers` | No | `4` | The number of workers used in the dataloader. |
|
||||
| `batch_size` | No | `64` | The batch size. |
|
||||
| `start_shard` | No | `0` | Defines the start of the shard range the dataset will recall. |
|
||||
@@ -67,14 +69,12 @@ Settings for controlling the training hyperparameters.
|
||||
| `wd` | No | `0.01` | The weight decay. |
|
||||
| `max_grad_norm`| No | `0.5` | The grad norm clipping. |
|
||||
| `save_every_n_samples` | No | `100000` | Samples will be generated and a checkpoint will be saved every `save_every_n_samples` samples. |
|
||||
| `cond_scale` | No | `1.0` | Conditioning scale to use for sampling. Can also be an array of values, one for each unet. |
|
||||
| `device` | No | `cuda:0` | The device to train on. |
|
||||
| `epoch_samples` | No | `None` | Limits the number of samples iterated through in each epoch. This must be set if resampling. None means no limit. |
|
||||
| `validation_samples` | No | `None` | The number of samples to use for validation. None mean the entire validation set. |
|
||||
| `use_ema` | No | `True` | Whether to use exponential moving average models for sampling. |
|
||||
| `ema_beta` | No | `0.99` | The ema coefficient. |
|
||||
| `save_all` | No | `False` | If True, preserves a checkpoint for every epoch. |
|
||||
| `save_latest` | No | `True` | If True, overwrites the `latest.pth` every time the model is saved. |
|
||||
| `save_best` | No | `True` | If True, overwrites the `best.pth` every time the model has a lower validation loss than all previous models. |
|
||||
| `unet_training_mask` | No | `None` | A boolean array of the same length as the number of unets. If false, the unet is frozen. A value of `None` trains all unets. |
|
||||
|
||||
**<ins>Evaluate</ins>:**
|
||||
@@ -106,6 +106,13 @@ Tracking is split up into three sections:
|
||||
|
||||
**Logging:**
|
||||
|
||||
All loggers have the following keys:
|
||||
| Option | Required | Default | Description |
|
||||
| ------ | -------- | ------- | ----------- |
|
||||
| `log_type` | Yes | N/A | The type of logger class to use. |
|
||||
| `resume` | No | `False` | For loggers that have the option to resume an old run, resume it using maually input parameters. |
|
||||
| `auto_resume` | No | `False` | If true, the logger will attempt to resume an old run using parameters from that previous run. |
|
||||
|
||||
If using `console` there is no further configuration than setting `log_type` to `console`.
|
||||
| Option | Required | Default | Description |
|
||||
| ------ | -------- | ------- | ----------- |
|
||||
@@ -119,10 +126,15 @@ If using `wandb`
|
||||
| `wandb_project` | Yes | N/A | The wandb project save the run to. |
|
||||
| `wandb_run_name` | No | `None` | The wandb run name. |
|
||||
| `wandb_run_id` | No | `None` | The wandb run id. Used if resuming an old run. |
|
||||
| `wandb_resume` | No | `False` | Whether to resume an old run. |
|
||||
|
||||
**Loading:**
|
||||
|
||||
All loaders have the following keys:
|
||||
| Option | Required | Default | Description |
|
||||
| ------ | -------- | ------- | ----------- |
|
||||
| `load_from` | Yes | N/A | The type of loader class to use. |
|
||||
| `only_auto_resume` | No | `False` | If true, the loader will only load the model if the run is being auto resumed. |
|
||||
|
||||
If using `local`
|
||||
| Option | Required | Default | Description |
|
||||
| ------ | -------- | ------- | ----------- |
|
||||
@@ -149,9 +161,10 @@ All save locations have these configuration options
|
||||
| Option | Required | Default | Description |
|
||||
| ------ | -------- | ------- | ----------- |
|
||||
| `save_to` | Yes | N/A | Must be `local`, `huggingface`, or `wandb`. |
|
||||
| `save_latest_to` | No | `latest.pth` | Sets the relative path to save the latest model to. |
|
||||
| `save_best_to` | No | `best.pth` | Sets the relative path to save the best model to every time the model has a lower validation loss than all previous models. |
|
||||
| `save_type` | No | `'checkpoint'` | The type of save. `'checkpoint'` saves a checkpoint, `'model'` saves a model without any fluff (Saves with ema if ema is enabled). |
|
||||
| `save_latest_to` | No | `None` | Sets the relative path to save the latest model to. |
|
||||
| `save_best_to` | No | `None` | Sets the relative path to save the best model to every time the model has a lower validation loss than all previous models. |
|
||||
| `save_meta_to` | No | `None` | The path to save metadata files in. This includes the config files used to start the training. |
|
||||
| `save_type` | No | `checkpoint` | The type of save. `checkpoint` saves a checkpoint, `model` saves a model without any fluff (Saves with ema if ema is enabled). |
|
||||
|
||||
If using `local`
|
||||
| Option | Required | Default | Description |
|
||||
@@ -163,7 +176,6 @@ If using `huggingface`
|
||||
| ------ | -------- | ------- | ----------- |
|
||||
| `save_to` | Yes | N/A | Must be `huggingface`. |
|
||||
| `huggingface_repo` | Yes | N/A | The huggingface repository to save to. |
|
||||
| `huggingface_base_path` | Yes | N/A | The base path that checkpoints will be saved under. |
|
||||
| `token_path` | No | `None` | If logging in with the huggingface cli is not possible, point to a token file instead. |
|
||||
|
||||
If using `wandb`
|
||||
|
||||
@@ -20,7 +20,7 @@
|
||||
},
|
||||
"data": {
|
||||
"webdataset_base_url": "pipe:s3cmd get s3://bucket/path/{}.tar -",
|
||||
"embeddings_url": "s3://bucket/embeddings/path/",
|
||||
"img_embeddings_url": "s3://bucket/img_embeddings/path/",
|
||||
"num_workers": 4,
|
||||
"batch_size": 64,
|
||||
"start_shard": 0,
|
||||
@@ -56,9 +56,6 @@
|
||||
"use_ema": true,
|
||||
"ema_beta": 0.99,
|
||||
"amp": false,
|
||||
"save_all": false,
|
||||
"save_latest": true,
|
||||
"save_best": true,
|
||||
"unet_training_mask": [true]
|
||||
},
|
||||
"evaluate": {
|
||||
@@ -96,14 +93,15 @@
|
||||
},
|
||||
|
||||
"save": [{
|
||||
"save_to": "wandb"
|
||||
"save_to": "wandb",
|
||||
"save_latest_to": "latest.pth"
|
||||
}, {
|
||||
"save_to": "huggingface",
|
||||
"huggingface_repo": "Veldrovive/test_model",
|
||||
|
||||
"save_all": true,
|
||||
"save_latest": true,
|
||||
"save_best": true,
|
||||
"save_latest_to": "path/to/model_dir/latest.pth",
|
||||
"save_best_to": "path/to/model_dir/best.pth",
|
||||
"save_meta_to": "path/to/directory/for/assorted/files",
|
||||
|
||||
"save_type": "model"
|
||||
}]
|
||||
|
||||
100
configs/train_decoder_config.test.json
Normal file
100
configs/train_decoder_config.test.json
Normal file
@@ -0,0 +1,100 @@
|
||||
{
|
||||
"decoder": {
|
||||
"unets": [
|
||||
{
|
||||
"dim": 16,
|
||||
"image_embed_dim": 768,
|
||||
"cond_dim": 16,
|
||||
"channels": 3,
|
||||
"dim_mults": [1, 2, 4, 8],
|
||||
"attn_dim_head": 16,
|
||||
"attn_heads": 4,
|
||||
"self_attn": [false, true, true, true]
|
||||
}
|
||||
],
|
||||
"clip": {
|
||||
"make": "openai",
|
||||
"model": "ViT-L/14"
|
||||
},
|
||||
|
||||
"timesteps": 10,
|
||||
"image_sizes": [64],
|
||||
"channels": 3,
|
||||
"loss_type": "l2",
|
||||
"beta_schedule": ["cosine"],
|
||||
"learned_variance": true
|
||||
},
|
||||
"data": {
|
||||
"webdataset_base_url": "test_data/{}.tar",
|
||||
"num_workers": 4,
|
||||
"batch_size": 4,
|
||||
"start_shard": 0,
|
||||
"end_shard": 9,
|
||||
"shard_width": 1,
|
||||
"index_width": 1,
|
||||
"splits": {
|
||||
"train": 0.75,
|
||||
"val": 0.15,
|
||||
"test": 0.1
|
||||
},
|
||||
"shuffle_train": false,
|
||||
"resample_train": true,
|
||||
"preprocessing": {
|
||||
"RandomResizedCrop": {
|
||||
"size": [224, 224],
|
||||
"scale": [0.75, 1.0],
|
||||
"ratio": [1.0, 1.0]
|
||||
},
|
||||
"ToTensor": true
|
||||
}
|
||||
},
|
||||
"train": {
|
||||
"epochs": 1,
|
||||
"lr": 1e-16,
|
||||
"wd": 0.01,
|
||||
"max_grad_norm": 0.5,
|
||||
"save_every_n_samples": 100,
|
||||
"n_sample_images": 1,
|
||||
"device": "cpu",
|
||||
"epoch_samples": 50,
|
||||
"validation_samples": 5,
|
||||
"use_ema": true,
|
||||
"ema_beta": 0.99,
|
||||
"amp": false,
|
||||
"unet_training_mask": [true]
|
||||
},
|
||||
"evaluate": {
|
||||
"n_evaluation_samples": 2,
|
||||
"FID": {
|
||||
"feature": 64
|
||||
},
|
||||
"IS": {
|
||||
"feature": 64,
|
||||
"splits": 10
|
||||
},
|
||||
"KID": {
|
||||
"feature": 64,
|
||||
"subset_size": 2
|
||||
},
|
||||
"LPIPS": {
|
||||
"net_type": "vgg",
|
||||
"reduction": "mean"
|
||||
}
|
||||
},
|
||||
"tracker": {
|
||||
"overwrite_data_path": true,
|
||||
|
||||
"log": {
|
||||
"log_type": "console"
|
||||
},
|
||||
|
||||
"load": {
|
||||
"load_from": null
|
||||
},
|
||||
|
||||
"save": [{
|
||||
"save_to": "local",
|
||||
"save_latest_to": "latest.pth"
|
||||
}]
|
||||
}
|
||||
}
|
||||
@@ -1,18 +1,14 @@
|
||||
{
|
||||
"prior": {
|
||||
"clip": {
|
||||
"make": "x-clip",
|
||||
"model": "ViT-L/14",
|
||||
"base_model_kwargs": {
|
||||
"dim_text": 768,
|
||||
"dim_image": 768,
|
||||
"dim_latent": 768
|
||||
}
|
||||
"make": "openai",
|
||||
"model": "ViT-L/14"
|
||||
},
|
||||
"net": {
|
||||
"dim": 768,
|
||||
"depth": 12,
|
||||
"num_timesteps": 1000,
|
||||
"max_text_len": 77,
|
||||
"num_time_embeds": 1,
|
||||
"num_image_embeds": 1,
|
||||
"num_text_embeds": 1,
|
||||
@@ -20,8 +16,8 @@
|
||||
"heads": 12,
|
||||
"ff_mult": 4,
|
||||
"norm_out": true,
|
||||
"attn_dropout": 0.0,
|
||||
"ff_dropout": 0.0,
|
||||
"attn_dropout": 0.05,
|
||||
"ff_dropout": 0.05,
|
||||
"final_proj": true,
|
||||
"normformer": true,
|
||||
"rotary_emb": true
|
||||
@@ -30,6 +26,7 @@
|
||||
"image_size": 224,
|
||||
"image_channels": 3,
|
||||
"timesteps": 1000,
|
||||
"sample_timesteps": 64,
|
||||
"cond_drop_prob": 0.1,
|
||||
"loss_type": "l2",
|
||||
"predict_x_start": true,
|
||||
@@ -37,34 +34,48 @@
|
||||
"condition_on_text_encodings": true
|
||||
},
|
||||
"data": {
|
||||
"image_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/",
|
||||
"text_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/",
|
||||
"meta_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/laion2B-en-metadata/",
|
||||
"batch_size": 256,
|
||||
"batch_size": 128,
|
||||
"num_data_points": 100000,
|
||||
"eval_every_seconds": 1600,
|
||||
"image_url": "<path to your images>",
|
||||
"meta_url": "<path to your metadata>",
|
||||
"splits": {
|
||||
"train": 0.9,
|
||||
"val": 1e-7,
|
||||
"test": 0.0999999
|
||||
"train": 0.8,
|
||||
"val": 0.1,
|
||||
"test": 0.1
|
||||
}
|
||||
},
|
||||
"train": {
|
||||
"epochs": 1,
|
||||
"epochs": 5,
|
||||
"lr": 1.1e-4,
|
||||
"wd": 6.02e-2,
|
||||
"max_grad_norm": 0.5,
|
||||
"use_ema": true,
|
||||
"ema_beta": 0.9999,
|
||||
"ema_update_after_step": 50,
|
||||
"warmup_steps": 50,
|
||||
"amp": false,
|
||||
"save_every": 10000
|
||||
},
|
||||
"load": {
|
||||
"source": null,
|
||||
"resume": false
|
||||
"save_every_seconds": 3600,
|
||||
"eval_timesteps": [64, 1000],
|
||||
"random_seed": 84513
|
||||
},
|
||||
"tracker": {
|
||||
"tracker_type": "wandb",
|
||||
"data_path": "./prior_checkpoints",
|
||||
"wandb_entity": "laion",
|
||||
"wandb_project": "diffusion-prior",
|
||||
"verbose": true
|
||||
"data_path": ".prior",
|
||||
"overwrite_data_path": true,
|
||||
"log": {
|
||||
"log_type": "wandb",
|
||||
"wandb_entity": "<your wandb username>",
|
||||
"wandb_project": "prior_debugging",
|
||||
"wandb_resume": false,
|
||||
"verbose": true
|
||||
},
|
||||
"save": [
|
||||
{
|
||||
"save_to": "local",
|
||||
"save_type": "checkpoint",
|
||||
"save_latest_to": ".prior/latest_checkpoint.pth",
|
||||
"save_best_to": ".prior/best_checkpoint.pth"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
import webdataset as wds
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
import numpy as np
|
||||
import fsspec
|
||||
import shutil
|
||||
@@ -255,7 +256,7 @@ def create_image_embedding_dataloader(
|
||||
)
|
||||
if shuffle_num is not None and shuffle_num > 0:
|
||||
ds.shuffle(1000)
|
||||
return wds.WebLoader(
|
||||
return DataLoader(
|
||||
ds,
|
||||
num_workers=num_workers,
|
||||
batch_size=batch_size,
|
||||
|
||||
@@ -67,6 +67,15 @@ class PriorEmbeddingDataset(IterableDataset):
|
||||
def __str__(self):
|
||||
return f"<PriorEmbeddingDataset: start: {self.start}, stop: {self.stop}, len: {self.__len__()}>"
|
||||
|
||||
def set_start(self, start):
|
||||
"""
|
||||
Adjust the starting point within the reader, useful for resuming an epoch
|
||||
"""
|
||||
self.start = start
|
||||
|
||||
def get_start(self):
|
||||
return self.start
|
||||
|
||||
def get_sample(self):
|
||||
"""
|
||||
pre-proocess data from either reader into a common format
|
||||
|
||||
@@ -1,15 +1,18 @@
|
||||
import urllib.request
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
from itertools import zip_longest
|
||||
from typing import Optional, List, Union
|
||||
from typing import Any, Optional, List, Union
|
||||
from pydantic import BaseModel
|
||||
|
||||
import torch
|
||||
|
||||
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
|
||||
from dalle2_pytorch.utils import import_or_print_error
|
||||
from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
|
||||
from dalle2_pytorch.version import __version__
|
||||
from packaging import version
|
||||
|
||||
# constants
|
||||
|
||||
@@ -20,16 +23,6 @@ DEFAULT_DATA_PATH = './.tracker-data'
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
# load file functions
|
||||
|
||||
def load_wandb_file(run_path, file_path, **kwargs):
|
||||
wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function')
|
||||
file_reference = wandb.restore(file_path, run_path=run_path)
|
||||
return file_reference.name
|
||||
|
||||
def load_local_file(file_path, **kwargs):
|
||||
return file_path
|
||||
|
||||
class BaseLogger:
|
||||
"""
|
||||
An abstract class representing an object that can log data.
|
||||
@@ -37,14 +30,17 @@ class BaseLogger:
|
||||
data_path (str): A file path for storing temporary data.
|
||||
verbose (bool): Whether of not to always print logs to the console.
|
||||
"""
|
||||
def __init__(self, data_path: str, verbose: bool = False, **kwargs):
|
||||
def __init__(self, data_path: str, resume: bool = False, auto_resume: bool = False, verbose: bool = False, **kwargs):
|
||||
self.data_path = Path(data_path)
|
||||
self.resume = resume
|
||||
self.auto_resume = auto_resume
|
||||
self.verbose = verbose
|
||||
|
||||
def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
|
||||
"""
|
||||
Initializes the logger.
|
||||
Errors if the logger is invalid.
|
||||
full_config is the config file dict while extra_config is anything else from the script that is not defined the config file.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -60,6 +56,14 @@ class BaseLogger:
|
||||
def log_error(self, error_string, **kwargs) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_resume_data(self, **kwargs) -> dict:
|
||||
"""
|
||||
Sets tracker attributes that along with { "resume": True } will be used to resume training.
|
||||
It is assumed that after init is called this data will be complete.
|
||||
If the logger does not have any resume functionality, it should return an empty dict.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
class ConsoleLogger(BaseLogger):
|
||||
def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
|
||||
print("Logging to console")
|
||||
@@ -76,6 +80,9 @@ class ConsoleLogger(BaseLogger):
|
||||
def log_error(self, error_string, **kwargs) -> None:
|
||||
print(error_string)
|
||||
|
||||
def get_resume_data(self, **kwargs) -> dict:
|
||||
return {}
|
||||
|
||||
class WandbLogger(BaseLogger):
|
||||
"""
|
||||
Logs to a wandb run.
|
||||
@@ -85,7 +92,6 @@ class WandbLogger(BaseLogger):
|
||||
wandb_project (str): The wandb project to log to.
|
||||
wandb_run_id (str): The wandb run id to resume.
|
||||
wandb_run_name (str): The wandb run name to use.
|
||||
wandb_resume (bool): Whether to resume a wandb run.
|
||||
"""
|
||||
def __init__(self,
|
||||
data_path: str,
|
||||
@@ -93,7 +99,6 @@ class WandbLogger(BaseLogger):
|
||||
wandb_project: str,
|
||||
wandb_run_id: Optional[str] = None,
|
||||
wandb_run_name: Optional[str] = None,
|
||||
wandb_resume: bool = False,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(data_path, **kwargs)
|
||||
@@ -101,7 +106,6 @@ class WandbLogger(BaseLogger):
|
||||
self.project = wandb_project
|
||||
self.run_id = wandb_run_id
|
||||
self.run_name = wandb_run_name
|
||||
self.resume = wandb_resume
|
||||
|
||||
def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
|
||||
assert self.entity is not None, "wandb_entity must be specified for wandb logger"
|
||||
@@ -149,6 +153,14 @@ class WandbLogger(BaseLogger):
|
||||
print(error_string)
|
||||
self.wandb.log({"error": error_string, **kwargs}, step=step)
|
||||
|
||||
def get_resume_data(self, **kwargs) -> dict:
|
||||
# In order to resume, we need wandb_entity, wandb_project, and wandb_run_id
|
||||
return {
|
||||
"entity": self.entity,
|
||||
"project": self.project,
|
||||
"run_id": self.wandb.run.id
|
||||
}
|
||||
|
||||
logger_type_map = {
|
||||
'console': ConsoleLogger,
|
||||
'wandb': WandbLogger,
|
||||
@@ -168,8 +180,9 @@ class BaseLoader:
|
||||
Parameters:
|
||||
data_path (str): A file path for storing temporary data.
|
||||
"""
|
||||
def __init__(self, data_path: str, **kwargs):
|
||||
def __init__(self, data_path: str, only_auto_resume: bool = False, **kwargs):
|
||||
self.data_path = Path(data_path)
|
||||
self.only_auto_resume = only_auto_resume
|
||||
|
||||
def init(self, logger: BaseLogger, **kwargs) -> None:
|
||||
raise NotImplementedError
|
||||
@@ -213,7 +226,7 @@ class LocalLoader(BaseLoader):
|
||||
|
||||
def init(self, logger: BaseLogger, **kwargs) -> None:
|
||||
# Makes sure the file exists to be loaded
|
||||
if not self.file_path.exists():
|
||||
if not self.file_path.exists() and not self.only_auto_resume:
|
||||
raise FileNotFoundError(f'Model not found at {self.file_path}')
|
||||
|
||||
def recall(self) -> dict:
|
||||
@@ -262,9 +275,9 @@ def create_loader(loader_type: str, data_path: str, **kwargs) -> BaseLoader:
|
||||
class BaseSaver:
|
||||
def __init__(self,
|
||||
data_path: str,
|
||||
save_latest_to: Optional[Union[str, bool]] = 'latest.pth',
|
||||
save_best_to: Optional[Union[str, bool]] = 'best.pth',
|
||||
save_meta_to: str = './',
|
||||
save_latest_to: Optional[Union[str, bool]] = None,
|
||||
save_best_to: Optional[Union[str, bool]] = None,
|
||||
save_meta_to: Optional[str] = None,
|
||||
save_type: str = 'checkpoint',
|
||||
**kwargs
|
||||
):
|
||||
@@ -274,10 +287,10 @@ class BaseSaver:
|
||||
self.save_best_to = save_best_to
|
||||
self.saving_best = save_best_to is not None and save_best_to is not False
|
||||
self.save_meta_to = save_meta_to
|
||||
self.saving_meta = save_meta_to is not None
|
||||
self.save_type = save_type
|
||||
assert save_type in ['checkpoint', 'model'], '`save_type` must be one of `checkpoint` or `model`'
|
||||
assert self.save_meta_to is not None, '`save_meta_to` must be provided'
|
||||
assert self.saving_latest or self.saving_best, '`save_latest_to` or `save_best_to` must be provided'
|
||||
assert self.saving_latest or self.saving_best or self.saving_meta, 'At least one saving option must be specified'
|
||||
|
||||
def init(self, logger: BaseLogger, **kwargs) -> None:
|
||||
raise NotImplementedError
|
||||
@@ -304,6 +317,10 @@ class LocalSaver(BaseSaver):
|
||||
def save_file(self, local_path: str, save_path: str, **kwargs) -> None:
|
||||
# Copy the file to save_path
|
||||
save_path_file_name = Path(save_path).name
|
||||
# Make sure parent directory exists
|
||||
save_path_parent = Path(save_path).parent
|
||||
if not save_path_parent.exists():
|
||||
save_path_parent.mkdir(parents=True)
|
||||
print(f"Saving {save_path_file_name} {self.save_type} to local path {save_path}")
|
||||
shutil.copy(local_path, save_path)
|
||||
|
||||
@@ -385,11 +402,7 @@ class Tracker:
|
||||
def __init__(self, data_path: Optional[str] = DEFAULT_DATA_PATH, overwrite_data_path: bool = False, dummy_mode: bool = False):
|
||||
self.data_path = Path(data_path)
|
||||
if not dummy_mode:
|
||||
if overwrite_data_path:
|
||||
if self.data_path.exists():
|
||||
shutil.rmtree(self.data_path)
|
||||
self.data_path.mkdir(parents=True)
|
||||
else:
|
||||
if not overwrite_data_path:
|
||||
assert not self.data_path.exists(), f'Data path {self.data_path} already exists. Set overwrite_data_path to True to overwrite.'
|
||||
if not self.data_path.exists():
|
||||
self.data_path.mkdir(parents=True)
|
||||
@@ -398,7 +411,51 @@ class Tracker:
|
||||
self.savers: List[BaseSaver]= []
|
||||
self.dummy_mode = dummy_mode
|
||||
|
||||
def _load_auto_resume(self) -> bool:
|
||||
# If the file does not exist, we return False. If autoresume is enabled we print a warning so that the user can know that this is the first run.
|
||||
if not self.auto_resume_path.exists():
|
||||
if self.logger.auto_resume:
|
||||
print("Auto_resume is enabled but no auto_resume.json file exists. Assuming this is the first run.")
|
||||
return False
|
||||
|
||||
# Now we know that the autoresume file exists, but if we are not auto resuming we should remove it so that we don't accidentally load it next time
|
||||
if not self.logger.auto_resume:
|
||||
print(f'Removing auto_resume.json because auto_resume is not enabled in the config')
|
||||
self.auto_resume_path.unlink()
|
||||
return False
|
||||
|
||||
# Otherwise we read the json into a dictionary will will override parts of logger.__dict__
|
||||
with open(self.auto_resume_path, 'r') as f:
|
||||
auto_resume_dict = json.load(f)
|
||||
# Check if the logger is of the same type as the autoresume save
|
||||
if auto_resume_dict["logger_type"] != self.logger.__class__.__name__:
|
||||
raise Exception(f'The logger type in the auto_resume file is {auto_resume_dict["logger_type"]} but the current logger is {self.logger.__class__.__name__}. Either use the original logger type, set `auto_resume` to `False`, or delete your existing tracker-data folder.')
|
||||
# Then we are ready to override the logger with the autoresume save
|
||||
self.logger.__dict__["resume"] = True
|
||||
print(f"Updating {self.logger.__dict__} with {auto_resume_dict}")
|
||||
self.logger.__dict__.update(auto_resume_dict)
|
||||
return True
|
||||
|
||||
def _save_auto_resume(self):
|
||||
# Gets the autoresume dict from the logger and adds "logger_type" to it then saves it to the auto_resume file
|
||||
auto_resume_dict = self.logger.get_resume_data()
|
||||
auto_resume_dict['logger_type'] = self.logger.__class__.__name__
|
||||
with open(self.auto_resume_path, 'w') as f:
|
||||
json.dump(auto_resume_dict, f)
|
||||
|
||||
def init(self, full_config: BaseModel, extra_config: dict):
|
||||
self.auto_resume_path = self.data_path / 'auto_resume.json'
|
||||
# Check for resuming the run
|
||||
self.did_auto_resume = self._load_auto_resume()
|
||||
if self.did_auto_resume:
|
||||
print(f'\n\nWARNING: RUN HAS BEEN AUTO-RESUMED WITH THE LOGGER TYPE {self.logger.__class__.__name__}.\nIf this was not your intention, stop this run and set `auto_resume` to `False` in the config.\n\n')
|
||||
print(f"New logger config: {self.logger.__dict__}")
|
||||
|
||||
self.save_metadata = dict(
|
||||
version = version.parse(__version__)
|
||||
) # Data that will be saved alongside the checkpoint or model
|
||||
self.blacklisted_checkpoint_metadata_keys = ['scaler', 'optimizer', 'model', 'version', 'step', 'steps'] # These keys would cause us to error if we try to save them as metadata
|
||||
|
||||
assert self.logger is not None, '`logger` must be set before `init` is called'
|
||||
if self.dummy_mode:
|
||||
# The only thing we need is a loader
|
||||
@@ -406,12 +463,17 @@ class Tracker:
|
||||
self.loader.init(self.logger)
|
||||
return
|
||||
assert len(self.savers) > 0, '`savers` must be set before `init` is called'
|
||||
|
||||
self.logger.init(full_config, extra_config)
|
||||
if self.loader is not None:
|
||||
self.loader.init(self.logger)
|
||||
for saver in self.savers:
|
||||
saver.init(self.logger)
|
||||
|
||||
if self.logger.auto_resume:
|
||||
# Then we need to save the autoresume file. It is assumed after logger.init is called that the logger is ready to be saved.
|
||||
self._save_auto_resume()
|
||||
|
||||
def add_logger(self, logger: BaseLogger):
|
||||
self.logger = logger
|
||||
|
||||
@@ -442,8 +504,15 @@ class Tracker:
|
||||
# Save the config under config_name in the root folder of data_path
|
||||
shutil.copy(current_config_path, self.data_path / config_name)
|
||||
for saver in self.savers:
|
||||
remote_path = Path(saver.save_meta_to) / config_name
|
||||
saver.save_file(current_config_path, str(remote_path))
|
||||
if saver.saving_meta:
|
||||
remote_path = Path(saver.save_meta_to) / config_name
|
||||
saver.save_file(current_config_path, str(remote_path))
|
||||
|
||||
def add_save_metadata(self, state_dict_key: str, metadata: Any):
|
||||
"""
|
||||
Adds a new piece of metadata that will be saved along with the model or decoder.
|
||||
"""
|
||||
self.save_metadata[state_dict_key] = metadata
|
||||
|
||||
def _save_state_dict(self, trainer: Union[DiffusionPriorTrainer, DecoderTrainer], save_type: str, file_path: str, **kwargs) -> Path:
|
||||
"""
|
||||
@@ -453,24 +522,38 @@ class Tracker:
|
||||
"""
|
||||
assert save_type in ['checkpoint', 'model']
|
||||
if save_type == 'checkpoint':
|
||||
trainer.save(file_path, overwrite=True, **kwargs)
|
||||
# Create a metadata dict without the blacklisted keys so we do not error when we create the state dict
|
||||
metadata = {k: v for k, v in self.save_metadata.items() if k not in self.blacklisted_checkpoint_metadata_keys}
|
||||
trainer.save(file_path, overwrite=True, **kwargs, **metadata)
|
||||
elif save_type == 'model':
|
||||
if isinstance(trainer, DiffusionPriorTrainer):
|
||||
prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior
|
||||
state_dict = trainer.unwrap_model(prior).state_dict()
|
||||
torch.save(state_dict, file_path)
|
||||
prior: DiffusionPrior = trainer.accelerator.unwrap_model(prior)
|
||||
# Remove CLIP if it is part of the model
|
||||
original_clip = prior.clip
|
||||
prior.clip = None
|
||||
model_state_dict = prior.state_dict()
|
||||
prior.clip = original_clip
|
||||
elif isinstance(trainer, DecoderTrainer):
|
||||
decoder = trainer.accelerator.unwrap_model(trainer.decoder)
|
||||
decoder: Decoder = trainer.accelerator.unwrap_model(trainer.decoder)
|
||||
# Remove CLIP if it is part of the model
|
||||
original_clip = decoder.clip
|
||||
decoder.clip = None
|
||||
if trainer.use_ema:
|
||||
trainable_unets = decoder.unets
|
||||
decoder.unets = trainer.unets # Swap EMA unets in
|
||||
state_dict = decoder.state_dict()
|
||||
model_state_dict = decoder.state_dict()
|
||||
decoder.unets = trainable_unets # Swap back
|
||||
else:
|
||||
state_dict = decoder.state_dict()
|
||||
torch.save(state_dict, file_path)
|
||||
model_state_dict = decoder.state_dict()
|
||||
decoder.clip = original_clip
|
||||
else:
|
||||
raise NotImplementedError('Saving this type of model with EMA mode enabled is not yet implemented. Actually, how did you get here?')
|
||||
state_dict = {
|
||||
**self.save_metadata,
|
||||
'model': model_state_dict
|
||||
}
|
||||
torch.save(state_dict, file_path)
|
||||
return Path(file_path)
|
||||
|
||||
def save(self, trainer, is_best: bool, is_latest: bool, **kwargs):
|
||||
@@ -503,11 +586,16 @@ class Tracker:
|
||||
self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs)
|
||||
print(f'Error saving checkpoint: {e}')
|
||||
|
||||
@property
|
||||
def can_recall(self):
|
||||
# Defines whether a recall can be performed.
|
||||
return self.loader is not None and (not self.loader.only_auto_resume or self.did_auto_resume)
|
||||
|
||||
def recall(self):
|
||||
if self.loader is not None:
|
||||
if self.can_recall:
|
||||
return self.loader.recall()
|
||||
else:
|
||||
raise ValueError('No loader specified')
|
||||
raise ValueError('Tried to recall, but no loader was set or auto-resume was not performed.')
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
from torchvision import transforms as T
|
||||
from pydantic import BaseModel, validator, root_validator
|
||||
from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
|
||||
from typing import List, Optional, Union, Tuple, Dict, Any, TypeVar
|
||||
|
||||
from x_clip import CLIP as XCLIP
|
||||
from coca_pytorch import CoCa
|
||||
@@ -25,11 +25,9 @@ def exists(val):
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
def ListOrTuple(inner_type):
|
||||
return Union[List[inner_type], Tuple[inner_type]]
|
||||
|
||||
def SingularOrIterable(inner_type):
|
||||
return Union[inner_type, ListOrTuple(inner_type)]
|
||||
InnerType = TypeVar('InnerType')
|
||||
ListOrTuple = Union[List[InnerType], Tuple[InnerType]]
|
||||
SingularOrIterable = Union[InnerType, ListOrTuple[InnerType]]
|
||||
|
||||
# general pydantic classes
|
||||
|
||||
@@ -47,6 +45,8 @@ class TrainSplitConfig(BaseModel):
|
||||
|
||||
class TrackerLogConfig(BaseModel):
|
||||
log_type: str = 'console'
|
||||
resume: bool = False # For logs that are saved to unique locations, resume a previous run
|
||||
auto_resume: bool = False # If the process crashes and restarts, resume from the run that crashed
|
||||
verbose: bool = False
|
||||
|
||||
class Config:
|
||||
@@ -59,6 +59,7 @@ class TrackerLogConfig(BaseModel):
|
||||
|
||||
class TrackerLoadConfig(BaseModel):
|
||||
load_from: Optional[str] = None
|
||||
only_auto_resume: bool = False # Only attempt to load if the logger is auto-resuming
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
@@ -126,6 +127,7 @@ class AdapterConfig(BaseModel):
|
||||
class DiffusionPriorNetworkConfig(BaseModel):
|
||||
dim: int
|
||||
depth: int
|
||||
max_text_len: int = None
|
||||
num_timesteps: int = None
|
||||
num_time_embeds: int = 1
|
||||
num_image_embeds: int = 1
|
||||
@@ -133,6 +135,7 @@ class DiffusionPriorNetworkConfig(BaseModel):
|
||||
dim_head: int = 64
|
||||
heads: int = 8
|
||||
ff_mult: int = 4
|
||||
norm_in: bool = False
|
||||
norm_out: bool = True
|
||||
attn_dropout: float = 0.
|
||||
ff_dropout: float = 0.
|
||||
@@ -140,6 +143,9 @@ class DiffusionPriorNetworkConfig(BaseModel):
|
||||
normformer: bool = False
|
||||
rotary_emb: bool = True
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
def create(self):
|
||||
kwargs = self.dict()
|
||||
return DiffusionPriorNetwork(**kwargs)
|
||||
@@ -151,6 +157,7 @@ class DiffusionPriorConfig(BaseModel):
|
||||
image_size: int
|
||||
image_channels: int = 3
|
||||
timesteps: int = 1000
|
||||
sample_timesteps: Optional[int] = None
|
||||
cond_drop_prob: float = 0.
|
||||
loss_type: str = 'l2'
|
||||
predict_x_start: bool = True
|
||||
@@ -181,23 +188,26 @@ class DiffusionPriorTrainConfig(BaseModel):
|
||||
use_ema: bool = True
|
||||
ema_beta: float = 0.99
|
||||
amp: bool = False
|
||||
save_every: int = 10000 # what steps to save on
|
||||
warmup_steps: int = None # number of warmup steps
|
||||
save_every_seconds: int = 3600 # how often to save
|
||||
eval_timesteps: List[int] = [64] # which sampling timesteps to evaluate with
|
||||
best_validation_loss: float = 1e9 # the current best valudation loss observed
|
||||
current_epoch: int = 0 # the current epoch
|
||||
num_samples_seen: int = 0 # the current number of samples seen
|
||||
random_seed: int = 0 # manual seed for torch
|
||||
|
||||
class DiffusionPriorDataConfig(BaseModel):
|
||||
image_url: str # path to embeddings folder
|
||||
meta_url: str # path to metadata (captions) for images
|
||||
splits: TrainSplitConfig
|
||||
batch_size: int = 64
|
||||
|
||||
class DiffusionPriorLoadConfig(BaseModel):
|
||||
source: str = None
|
||||
resume: bool = False
|
||||
image_url: str # path to embeddings folder
|
||||
meta_url: str # path to metadata (captions) for images
|
||||
splits: TrainSplitConfig # define train, validation, test splits for your dataset
|
||||
batch_size: int # per-gpu batch size used to train the model
|
||||
num_data_points: int = 25e7 # total number of datapoints to train on
|
||||
eval_every_seconds: int = 3600 # validation statistics will be performed this often
|
||||
|
||||
class TrainDiffusionPriorConfig(BaseModel):
|
||||
prior: DiffusionPriorConfig
|
||||
data: DiffusionPriorDataConfig
|
||||
train: DiffusionPriorTrainConfig
|
||||
load: DiffusionPriorLoadConfig
|
||||
tracker: TrackerConfig
|
||||
|
||||
@classmethod
|
||||
@@ -210,29 +220,31 @@ class TrainDiffusionPriorConfig(BaseModel):
|
||||
|
||||
class UnetConfig(BaseModel):
|
||||
dim: int
|
||||
dim_mults: ListOrTuple(int)
|
||||
dim_mults: ListOrTuple[int]
|
||||
image_embed_dim: int = None
|
||||
text_embed_dim: int = None
|
||||
cond_on_text_encodings: bool = None
|
||||
cond_dim: int = None
|
||||
channels: int = 3
|
||||
self_attn: ListOrTuple(int)
|
||||
self_attn: ListOrTuple[int]
|
||||
attn_dim_head: int = 32
|
||||
attn_heads: int = 16
|
||||
init_cross_embed: bool = True
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
class DecoderConfig(BaseModel):
|
||||
unets: ListOrTuple(UnetConfig)
|
||||
unets: ListOrTuple[UnetConfig]
|
||||
image_size: int = None
|
||||
image_sizes: ListOrTuple(int) = None
|
||||
image_sizes: ListOrTuple[int] = None
|
||||
clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
|
||||
channels: int = 3
|
||||
timesteps: int = 1000
|
||||
sample_timesteps: Optional[SingularOrIterable[Optional[int]]] = None
|
||||
loss_type: str = 'l2'
|
||||
beta_schedule: ListOrTuple(str) = 'cosine'
|
||||
learned_variance: bool = True
|
||||
beta_schedule: ListOrTuple[str] = None # None means all cosine
|
||||
learned_variance: SingularOrIterable[bool] = True
|
||||
image_cond_drop_prob: float = 0.1
|
||||
text_cond_drop_prob: float = 0.5
|
||||
|
||||
@@ -291,19 +303,22 @@ class DecoderDataConfig(BaseModel):
|
||||
|
||||
class DecoderTrainConfig(BaseModel):
|
||||
epochs: int = 20
|
||||
lr: SingularOrIterable(float) = 1e-4
|
||||
wd: SingularOrIterable(float) = 0.01
|
||||
lr: SingularOrIterable[float] = 1e-4
|
||||
wd: SingularOrIterable[float] = 0.01
|
||||
warmup_steps: Optional[SingularOrIterable[int]] = None
|
||||
find_unused_parameters: bool = True
|
||||
max_grad_norm: SingularOrIterable(float) = 0.5
|
||||
max_grad_norm: SingularOrIterable[float] = 0.5
|
||||
save_every_n_samples: int = 100000
|
||||
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
|
||||
cond_scale: Union[float, List[float]] = 1.0
|
||||
device: str = 'cuda:0'
|
||||
epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
|
||||
validation_samples: int = None # Same as above but for validation.
|
||||
save_immediately: bool = False
|
||||
use_ema: bool = True
|
||||
ema_beta: float = 0.999
|
||||
amp: bool = False
|
||||
unet_training_mask: ListOrTuple(bool) = None # If None, use all unets
|
||||
unet_training_mask: ListOrTuple[bool] = None # If None, use all unets
|
||||
|
||||
class DecoderEvaluateConfig(BaseModel):
|
||||
n_evaluation_samples: int = 1000
|
||||
@@ -312,12 +327,6 @@ class DecoderEvaluateConfig(BaseModel):
|
||||
KID: Dict[str, Any] = None
|
||||
LPIPS: Dict[str, Any] = None
|
||||
|
||||
class DecoderLoadConfig(BaseModel):
|
||||
source: str = None # Supports file and wandb
|
||||
run_path: str = '' # Used only if source is wandb
|
||||
file_path: str = '' # The local filepath if source is file. If source is wandb, the relative path to the model file in wandb.
|
||||
resume: bool = False # If using wandb, whether to resume the run
|
||||
|
||||
class TrainDecoderConfig(BaseModel):
|
||||
decoder: DecoderConfig
|
||||
data: DecoderDataConfig
|
||||
|
||||
@@ -3,10 +3,13 @@ import copy
|
||||
from pathlib import Path
|
||||
from math import ceil
|
||||
from functools import partial, wraps
|
||||
from contextlib import nullcontext
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
|
||||
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
|
||||
@@ -14,9 +17,11 @@ from dalle2_pytorch.optimizer import get_optimizer
|
||||
from dalle2_pytorch.version import __version__
|
||||
from packaging import version
|
||||
|
||||
import pytorch_warmup as warmup
|
||||
|
||||
from ema_pytorch import EMA
|
||||
|
||||
from accelerate import Accelerator
|
||||
from accelerate import Accelerator, DistributedType
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -71,6 +76,7 @@ def cast_torch_tensor(fn):
|
||||
def inner(model, *args, **kwargs):
|
||||
device = kwargs.pop('_device', next(model.parameters()).device)
|
||||
cast_device = kwargs.pop('_cast_device', True)
|
||||
cast_deepspeed_precision = kwargs.pop('_cast_deepspeed_precision', True)
|
||||
|
||||
kwargs_keys = kwargs.keys()
|
||||
all_args = (*args, *kwargs.values())
|
||||
@@ -80,6 +86,21 @@ def cast_torch_tensor(fn):
|
||||
if cast_device:
|
||||
all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
|
||||
|
||||
if cast_deepspeed_precision:
|
||||
try:
|
||||
accelerator = model.accelerator
|
||||
if accelerator is not None and accelerator.distributed_type == DistributedType.DEEPSPEED:
|
||||
cast_type_map = {
|
||||
"fp16": torch.half,
|
||||
"bf16": torch.bfloat16,
|
||||
"no": torch.float
|
||||
}
|
||||
precision_type = cast_type_map[accelerator.mixed_precision]
|
||||
all_args = tuple(map(lambda t: t.to(precision_type) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
|
||||
except AttributeError:
|
||||
# Then this model doesn't have an accelerator
|
||||
pass
|
||||
|
||||
args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
|
||||
kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))
|
||||
|
||||
@@ -153,37 +174,58 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
diffusion_prior,
|
||||
accelerator = None,
|
||||
use_ema = True,
|
||||
lr = 3e-4,
|
||||
wd = 1e-2,
|
||||
eps = 1e-6,
|
||||
max_grad_norm = None,
|
||||
amp = False,
|
||||
group_wd_params = True,
|
||||
device = None,
|
||||
accelerator = None,
|
||||
warmup_steps = None,
|
||||
cosine_decay_max_steps = None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(diffusion_prior, DiffusionPrior)
|
||||
assert not exists(accelerator) or isinstance(accelerator, Accelerator)
|
||||
assert exists(accelerator) or exists(device), "You must supply some method of obtaining a device."
|
||||
|
||||
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
|
||||
accelerator_kwargs, kwargs = groupby_prefix_and_trim('accelerator_', kwargs)
|
||||
|
||||
if not exists(accelerator):
|
||||
accelerator = Accelerator(**accelerator_kwargs)
|
||||
|
||||
# assign some helpful member vars
|
||||
|
||||
self.accelerator = accelerator
|
||||
self.device = accelerator.device if exists(accelerator) else device
|
||||
self.text_conditioned = diffusion_prior.condition_on_text_encodings
|
||||
|
||||
# setting the device
|
||||
|
||||
self.device = accelerator.device
|
||||
diffusion_prior.to(self.device)
|
||||
|
||||
# save model
|
||||
|
||||
self.diffusion_prior = diffusion_prior
|
||||
|
||||
# optimizer and mixed precision stuff
|
||||
# mixed precision checks
|
||||
|
||||
self.amp = amp
|
||||
if (
|
||||
exists(self.accelerator)
|
||||
and self.accelerator.distributed_type == DistributedType.DEEPSPEED
|
||||
and self.diffusion_prior.clip is not None
|
||||
):
|
||||
# Then we need to make sure clip is using the correct precision or else deepspeed will error
|
||||
cast_type_map = {
|
||||
"fp16": torch.half,
|
||||
"bf16": torch.bfloat16,
|
||||
"no": torch.float
|
||||
}
|
||||
precision_type = cast_type_map[accelerator.mixed_precision]
|
||||
assert precision_type == torch.float, "DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip"
|
||||
self.diffusion_prior.clip.to(precision_type)
|
||||
|
||||
self.scaler = GradScaler(enabled = amp)
|
||||
# optimizer stuff
|
||||
|
||||
self.optim_kwargs = dict(lr=lr, wd=wd, eps=eps, group_wd_params=group_wd_params)
|
||||
|
||||
@@ -193,16 +235,23 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if exists(cosine_decay_max_steps):
|
||||
self.scheduler = CosineAnnealingLR(optimizer, T_max = cosine_decay_max_steps)
|
||||
else:
|
||||
self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
|
||||
|
||||
self.warmup_scheduler = warmup.LinearWarmup(self.optimizer, warmup_period = warmup_steps) if exists(warmup_steps) else None
|
||||
|
||||
# distribute the model if using HFA
|
||||
if exists(self.accelerator):
|
||||
self.diffusion_prior, self.optimizer = self.accelerator.prepare(self.diffusion_prior, self.optimizer)
|
||||
|
||||
self.diffusion_prior, self.optimizer, self.scheduler = self.accelerator.prepare(self.diffusion_prior, self.optimizer, self.scheduler)
|
||||
|
||||
# exponential moving average stuff
|
||||
|
||||
self.use_ema = use_ema
|
||||
|
||||
if self.use_ema:
|
||||
self.ema_diffusion_prior = EMA(self.unwrap_model(self.diffusion_prior), **ema_kwargs)
|
||||
self.ema_diffusion_prior = EMA(self.accelerator.unwrap_model(self.diffusion_prior), **ema_kwargs)
|
||||
|
||||
# gradient clipping if needed
|
||||
|
||||
@@ -210,66 +259,27 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
|
||||
# track steps internally
|
||||
|
||||
self.register_buffer('step', torch.tensor([0]))
|
||||
|
||||
# accelerator wrappers
|
||||
|
||||
def print(self, msg):
|
||||
if exists(self.accelerator):
|
||||
self.accelerator.print(msg)
|
||||
else:
|
||||
print(msg)
|
||||
|
||||
def unwrap_model(self, model):
|
||||
if exists(self.accelerator):
|
||||
return self.accelerator.unwrap_model(model)
|
||||
else:
|
||||
return model
|
||||
|
||||
def wait_for_everyone(self):
|
||||
if exists(self.accelerator):
|
||||
self.accelerator.wait_for_everyone()
|
||||
|
||||
def is_main_process(self):
|
||||
if exists(self.accelerator):
|
||||
return self.accelerator.is_main_process
|
||||
else:
|
||||
return True
|
||||
|
||||
def clip_grad_norm_(self, *args):
|
||||
if exists(self.accelerator):
|
||||
return self.accelerator.clip_grad_norm_(*args)
|
||||
else:
|
||||
return torch.nn.utils.clip_grad_norm_(*args)
|
||||
|
||||
def backprop(self, x):
|
||||
if exists(self.accelerator):
|
||||
self.accelerator.backward(x)
|
||||
else:
|
||||
try:
|
||||
x.backward()
|
||||
except Exception as e:
|
||||
self.print(f"Caught error in backprop call: {e}")
|
||||
self.register_buffer('step', torch.tensor([0], device = self.device))
|
||||
|
||||
# utility
|
||||
|
||||
def save(self, path, overwrite = True, **kwargs):
|
||||
# ensure we sync gradients before continuing
|
||||
self.wait_for_everyone()
|
||||
|
||||
# only save on the main process
|
||||
if self.is_main_process():
|
||||
self.print(f"Saving checkpoint at step: {self.step.item()}")
|
||||
if self.accelerator.is_main_process:
|
||||
print(f"Saving checkpoint at step: {self.step.item()}")
|
||||
path = Path(path)
|
||||
assert not (path.exists() and not overwrite)
|
||||
path.parent.mkdir(parents = True, exist_ok = True)
|
||||
|
||||
# FIXME: LambdaLR can't be saved due to pickling issues
|
||||
save_obj = dict(
|
||||
scaler = self.scaler.state_dict(),
|
||||
optimizer = self.optimizer.state_dict(),
|
||||
model = self.unwrap_model(self.diffusion_prior).state_dict(), # unwrap the model from distribution if applicable
|
||||
scheduler = self.scheduler.state_dict(),
|
||||
warmup_scheduler = self.warmup_scheduler,
|
||||
model = self.accelerator.unwrap_model(self.diffusion_prior).state_dict(),
|
||||
version = version.parse(__version__),
|
||||
step = self.step.item(),
|
||||
step = self.step,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@@ -282,14 +292,14 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
|
||||
torch.save(save_obj, str(path))
|
||||
|
||||
def load(self, path, overwrite_lr = True, strict = True):
|
||||
def load(self, path_or_state, overwrite_lr = True, strict = True):
|
||||
"""
|
||||
Load a checkpoint of a diffusion prior trainer.
|
||||
|
||||
Will load the entire trainer, including the optimizer and EMA.
|
||||
|
||||
Params:
|
||||
- path (str): a path to the DiffusionPriorTrainer checkpoint file
|
||||
- path_or_state (str | torch): a path to the DiffusionPriorTrainer checkpoint file
|
||||
- overwrite_lr (bool): wether or not to overwrite the stored LR with the LR specified in the new trainer
|
||||
- strict (bool): kwarg for `torch.nn.Module.load_state_dict`, will force an exact checkpoint match
|
||||
|
||||
@@ -298,56 +308,59 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
"""
|
||||
|
||||
# all processes need to load checkpoint. no restriction here
|
||||
path = Path(path)
|
||||
assert path.exists()
|
||||
if isinstance(path_or_state, str):
|
||||
path = Path(path_or_state)
|
||||
assert path.exists()
|
||||
loaded_obj = torch.load(str(path), map_location=self.device)
|
||||
|
||||
loaded_obj = torch.load(str(path), map_location=self.device)
|
||||
elif isinstance(path_or_state, dict):
|
||||
loaded_obj = path_or_state
|
||||
|
||||
if version.parse(__version__) != loaded_obj['version']:
|
||||
print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {__version__}')
|
||||
|
||||
# unwrap the model when loading from checkpoint
|
||||
self.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
|
||||
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
|
||||
self.accelerator.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
|
||||
self.step.copy_(torch.ones_like(self.step, device=self.device) * loaded_obj['step'].to(self.device))
|
||||
|
||||
self.scaler.load_state_dict(loaded_obj['scaler'])
|
||||
self.optimizer.load_state_dict(loaded_obj['optimizer'])
|
||||
self.scheduler.load_state_dict(loaded_obj['scheduler'])
|
||||
|
||||
# set warmupstep
|
||||
if exists(self.warmup_scheduler):
|
||||
self.warmup_scheduler.last_step = self.step.item()
|
||||
|
||||
# ensure new lr is used if different from old one
|
||||
if overwrite_lr:
|
||||
new_lr = self.optim_kwargs["lr"]
|
||||
|
||||
self.print(f"Overriding LR to be {new_lr}")
|
||||
|
||||
for group in self.optimizer.param_groups:
|
||||
group["lr"] = new_lr
|
||||
group["lr"] = new_lr if group["lr"] > 0.0 else 0.0
|
||||
|
||||
if self.use_ema:
|
||||
assert 'ema' in loaded_obj
|
||||
self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
|
||||
# below not be necessary, but I had a suspicion that this wasn't being loaded correctly
|
||||
# below might not be necessary, but I had a suspicion that this wasn't being loaded correctly
|
||||
self.ema_diffusion_prior.ema_model.load_state_dict(loaded_obj["ema_model"])
|
||||
|
||||
# sync and inform
|
||||
self.wait_for_everyone()
|
||||
self.print(f"Loaded model")
|
||||
|
||||
return loaded_obj
|
||||
|
||||
# model functionality
|
||||
|
||||
def update(self):
|
||||
# only continue with updates until all ranks finish
|
||||
self.wait_for_everyone()
|
||||
|
||||
if exists(self.max_grad_norm):
|
||||
self.scaler.unscale_(self.optimizer)
|
||||
# utilize HFA clipping where applicable
|
||||
self.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
|
||||
|
||||
self.scaler.step(self.optimizer)
|
||||
self.scaler.update()
|
||||
self.accelerator.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
|
||||
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
# accelerator will ocassionally skip optimizer steps in a "dynamic loss scaling strategy"
|
||||
if not self.accelerator.optimizer_step_was_skipped:
|
||||
sched_context = self.warmup_scheduler.dampening if exists(self.warmup_scheduler) else nullcontext
|
||||
with sched_context():
|
||||
self.scheduler.step()
|
||||
|
||||
if self.use_ema:
|
||||
self.ema_diffusion_prior.update()
|
||||
|
||||
@@ -376,7 +389,7 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
@cast_torch_tensor
|
||||
@prior_sample_in_chunks
|
||||
def embed_text(self, *args, **kwargs):
|
||||
return self.unwrap_model(self.diffusion_prior).clip.embed_text(*args, **kwargs)
|
||||
return self.accelerator.unwrap_model(self.diffusion_prior).clip.embed_text(*args, **kwargs)
|
||||
|
||||
@cast_torch_tensor
|
||||
def forward(
|
||||
@@ -388,16 +401,14 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
total_loss = 0.
|
||||
|
||||
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
|
||||
with autocast(enabled = self.amp):
|
||||
with self.accelerator.autocast():
|
||||
loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
|
||||
loss = loss * chunk_size_frac
|
||||
|
||||
total_loss += loss.item()
|
||||
|
||||
# backprop with accelerate if applicable
|
||||
|
||||
if self.training:
|
||||
self.backprop(self.scaler.scale(loss))
|
||||
self.accelerator.backward(loss)
|
||||
|
||||
return total_loss
|
||||
|
||||
@@ -424,10 +435,13 @@ class DecoderTrainer(nn.Module):
|
||||
self,
|
||||
decoder,
|
||||
accelerator = None,
|
||||
dataloaders = None,
|
||||
use_ema = True,
|
||||
lr = 1e-4,
|
||||
wd = 1e-2,
|
||||
eps = 1e-8,
|
||||
warmup_steps = None,
|
||||
cosine_decay_max_steps = None,
|
||||
max_grad_norm = 0.5,
|
||||
amp = False,
|
||||
group_wd_params = True,
|
||||
@@ -449,23 +463,40 @@ class DecoderTrainer(nn.Module):
|
||||
# be able to finely customize learning rate, weight decay
|
||||
# per unet
|
||||
|
||||
lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))
|
||||
lr, wd, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps, cosine_decay_max_steps))
|
||||
|
||||
assert all([unet_lr < 1e-3 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
|
||||
assert all([unet_lr <= 1e-2 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
|
||||
|
||||
optimizers = []
|
||||
schedulers = []
|
||||
warmup_schedulers = []
|
||||
|
||||
for unet, unet_lr, unet_wd, unet_eps in zip(decoder.unets, lr, wd, eps):
|
||||
optimizer = get_optimizer(
|
||||
unet.parameters(),
|
||||
lr = unet_lr,
|
||||
wd = unet_wd,
|
||||
eps = unet_eps,
|
||||
group_wd_params = group_wd_params,
|
||||
**kwargs
|
||||
)
|
||||
for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps in zip(decoder.unets, lr, wd, eps, warmup_steps, cosine_decay_max_steps):
|
||||
if isinstance(unet, nn.Identity):
|
||||
optimizers.append(None)
|
||||
schedulers.append(None)
|
||||
warmup_schedulers.append(None)
|
||||
else:
|
||||
optimizer = get_optimizer(
|
||||
unet.parameters(),
|
||||
lr = unet_lr,
|
||||
wd = unet_wd,
|
||||
eps = unet_eps,
|
||||
group_wd_params = group_wd_params,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
optimizers.append(optimizer)
|
||||
optimizers.append(optimizer)
|
||||
|
||||
if exists(unet_cosine_decay_max_steps):
|
||||
scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps)
|
||||
else:
|
||||
scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
|
||||
|
||||
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
|
||||
warmup_schedulers.append(warmup_scheduler)
|
||||
|
||||
schedulers.append(scheduler)
|
||||
|
||||
if self.use_ema:
|
||||
self.ema_unets.append(EMA(unet, **ema_kwargs))
|
||||
@@ -474,15 +505,58 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
self.max_grad_norm = max_grad_norm
|
||||
|
||||
self.register_buffer('step', torch.tensor([0.]))
|
||||
self.register_buffer('steps', torch.tensor([0] * self.num_unets))
|
||||
|
||||
if self.accelerator.distributed_type == DistributedType.DEEPSPEED and decoder.clip is not None:
|
||||
# Then we need to make sure clip is using the correct precision or else deepspeed will error
|
||||
cast_type_map = {
|
||||
"fp16": torch.half,
|
||||
"bf16": torch.bfloat16,
|
||||
"no": torch.float
|
||||
}
|
||||
precision_type = cast_type_map[accelerator.mixed_precision]
|
||||
assert precision_type == torch.float, "DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip"
|
||||
clip = decoder.clip
|
||||
clip.to(precision_type)
|
||||
|
||||
decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
|
||||
|
||||
self.decoder = decoder
|
||||
|
||||
# prepare dataloaders
|
||||
|
||||
train_loader = val_loader = None
|
||||
if exists(dataloaders):
|
||||
train_loader, val_loader = self.accelerator.prepare(dataloaders["train"], dataloaders["val"])
|
||||
|
||||
self.train_loader = train_loader
|
||||
self.val_loader = val_loader
|
||||
|
||||
# store optimizers
|
||||
|
||||
for opt_ind, optimizer in zip(range(len(optimizers)), optimizers):
|
||||
setattr(self, f'optim{opt_ind}', optimizer)
|
||||
|
||||
# store schedulers
|
||||
|
||||
for sched_ind, scheduler in zip(range(len(schedulers)), schedulers):
|
||||
setattr(self, f'sched{sched_ind}', scheduler)
|
||||
|
||||
# store warmup schedulers
|
||||
|
||||
self.warmup_schedulers = warmup_schedulers
|
||||
|
||||
def validate_and_return_unet_number(self, unet_number = None):
|
||||
if self.num_unets == 1:
|
||||
unet_number = default(unet_number, 1)
|
||||
|
||||
assert exists(unet_number) and 1 <= unet_number <= self.num_unets
|
||||
return unet_number
|
||||
|
||||
def num_steps_taken(self, unet_number = None):
|
||||
unet_number = self.validate_and_return_unet_number(unet_number)
|
||||
return self.steps[unet_number - 1].item()
|
||||
|
||||
def save(self, path, overwrite = True, **kwargs):
|
||||
path = Path(path)
|
||||
assert not (path.exists() and not overwrite)
|
||||
@@ -491,14 +565,21 @@ class DecoderTrainer(nn.Module):
|
||||
save_obj = dict(
|
||||
model = self.accelerator.unwrap_model(self.decoder).state_dict(),
|
||||
version = __version__,
|
||||
step = self.step.item(),
|
||||
steps = self.steps.cpu(),
|
||||
**kwargs
|
||||
)
|
||||
|
||||
for ind in range(0, self.num_unets):
|
||||
optimizer_key = f'optim{ind}'
|
||||
scheduler_key = f'sched{ind}'
|
||||
|
||||
optimizer = getattr(self, optimizer_key)
|
||||
save_obj = {**save_obj, optimizer_key: self.accelerator.unwrap_model(optimizer).state_dict()}
|
||||
scheduler = getattr(self, scheduler_key)
|
||||
|
||||
optimizer_state_dict = optimizer.state_dict() if exists(optimizer) else None
|
||||
scheduler_state_dict = scheduler.state_dict() if exists(scheduler) else None
|
||||
|
||||
save_obj = {**save_obj, optimizer_key: optimizer_state_dict, scheduler_key: scheduler_state_dict}
|
||||
|
||||
if self.use_ema:
|
||||
save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
|
||||
@@ -510,16 +591,29 @@ class DecoderTrainer(nn.Module):
|
||||
self.accelerator.print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}')
|
||||
|
||||
self.accelerator.unwrap_model(self.decoder).load_state_dict(loaded_obj['model'], strict = strict)
|
||||
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
|
||||
self.steps.copy_(loaded_obj['steps'])
|
||||
|
||||
if only_model:
|
||||
return loaded_obj
|
||||
|
||||
for ind in range(0, self.num_unets):
|
||||
for ind, last_step in zip(range(0, self.num_unets), self.steps.tolist()):
|
||||
|
||||
optimizer_key = f'optim{ind}'
|
||||
optimizer = getattr(self, optimizer_key)
|
||||
|
||||
self.accelerator.unwrap_model(optimizer).load_state_dict(loaded_obj[optimizer_key])
|
||||
scheduler_key = f'sched{ind}'
|
||||
scheduler = getattr(self, scheduler_key)
|
||||
|
||||
warmup_scheduler = self.warmup_schedulers[ind]
|
||||
|
||||
if exists(optimizer):
|
||||
optimizer.load_state_dict(loaded_obj[optimizer_key])
|
||||
|
||||
if exists(scheduler):
|
||||
scheduler.load_state_dict(loaded_obj[scheduler_key])
|
||||
|
||||
if exists(warmup_scheduler):
|
||||
warmup_scheduler.last_step = last_step
|
||||
|
||||
if self.use_ema:
|
||||
assert 'ema' in loaded_obj
|
||||
@@ -539,25 +633,36 @@ class DecoderTrainer(nn.Module):
|
||||
def unets(self):
|
||||
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
|
||||
|
||||
def update(self, unet_number = None):
|
||||
if self.num_unets == 1:
|
||||
unet_number = default(unet_number, 1)
|
||||
def increment_step(self, unet_number):
|
||||
assert 1 <= unet_number <= self.num_unets
|
||||
|
||||
assert exists(unet_number) and 1 <= unet_number <= self.num_unets
|
||||
unet_index_tensor = torch.tensor(unet_number - 1, device = self.steps.device)
|
||||
self.steps += F.one_hot(unet_index_tensor, num_classes = len(self.steps))
|
||||
|
||||
def update(self, unet_number = None):
|
||||
unet_number = self.validate_and_return_unet_number(unet_number)
|
||||
index = unet_number - 1
|
||||
|
||||
optimizer = getattr(self, f'optim{index}')
|
||||
scheduler = getattr(self, f'sched{index}')
|
||||
|
||||
if exists(self.max_grad_norm):
|
||||
self.accelerator.clip_grad_norm_(self.decoder.parameters(), self.max_grad_norm) # Automatically unscales gradients
|
||||
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
warmup_scheduler = self.warmup_schedulers[index]
|
||||
scheduler_context = warmup_scheduler.dampening if exists(warmup_scheduler) else nullcontext
|
||||
|
||||
with scheduler_context():
|
||||
scheduler.step()
|
||||
|
||||
if self.use_ema:
|
||||
ema_unet = self.ema_unets[index]
|
||||
ema_unet.update()
|
||||
|
||||
self.step += 1
|
||||
self.increment_step(unet_number)
|
||||
|
||||
@torch.no_grad()
|
||||
@cast_torch_tensor
|
||||
@@ -565,8 +670,14 @@ class DecoderTrainer(nn.Module):
|
||||
def sample(self, *args, **kwargs):
|
||||
distributed = self.accelerator.num_processes > 1
|
||||
base_decoder = self.accelerator.unwrap_model(self.decoder)
|
||||
|
||||
was_training = base_decoder.training
|
||||
base_decoder.eval()
|
||||
|
||||
if kwargs.pop('use_non_ema', False) or not self.use_ema:
|
||||
return base_decoder.sample(*args, **kwargs, distributed = distributed)
|
||||
out = base_decoder.sample(*args, **kwargs, distributed = distributed)
|
||||
base_decoder.train(was_training)
|
||||
return out
|
||||
|
||||
trainable_unets = self.accelerator.unwrap_model(self.decoder).unets
|
||||
base_decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
|
||||
@@ -579,6 +690,7 @@ class DecoderTrainer(nn.Module):
|
||||
for ema in self.ema_unets:
|
||||
ema.restore_ema_model_device()
|
||||
|
||||
base_decoder.train(was_training)
|
||||
return output
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -599,22 +711,32 @@ class DecoderTrainer(nn.Module):
|
||||
*args,
|
||||
unet_number = None,
|
||||
max_batch_size = None,
|
||||
return_lowres_cond_image=False,
|
||||
**kwargs
|
||||
):
|
||||
if self.num_unets == 1:
|
||||
unet_number = default(unet_number, 1)
|
||||
unet_number = self.validate_and_return_unet_number(unet_number)
|
||||
|
||||
total_loss = 0.
|
||||
|
||||
cond_images = []
|
||||
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
|
||||
# with autocast(enabled = self.amp):
|
||||
with self.accelerator.autocast():
|
||||
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
|
||||
loss_obj = self.decoder(*chunked_args, unet_number = unet_number, return_lowres_cond_image=return_lowres_cond_image, **chunked_kwargs)
|
||||
# loss_obj may be a tuple with loss and cond_image
|
||||
if return_lowres_cond_image:
|
||||
loss, cond_image = loss_obj
|
||||
else:
|
||||
loss = loss_obj
|
||||
cond_image = None
|
||||
loss = loss * chunk_size_frac
|
||||
if cond_image is not None:
|
||||
cond_images.append(cond_image)
|
||||
|
||||
total_loss += loss.item()
|
||||
|
||||
if self.training:
|
||||
self.accelerator.backward(loss)
|
||||
|
||||
return total_loss
|
||||
if return_lowres_cond_image:
|
||||
return total_loss, torch.stack(cond_images)
|
||||
else:
|
||||
return total_loss
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = '0.16.0'
|
||||
__version__ = '1.10.6'
|
||||
|
||||
183
prior.md
Normal file
183
prior.md
Normal file
@@ -0,0 +1,183 @@
|
||||
# Diffusion Prior
|
||||
This readme serves as an introduction to the diffusion prior.
|
||||
|
||||
## Intro
|
||||
|
||||
A properly trained prior will allow you to translate between two embedding spaces. If you know *a priori* that two embeddings are connected some way—then ability the translate between them could extremely helpful.
|
||||
|
||||
### Motivation
|
||||
|
||||
Before we dive into the model, let’s look at a quick example of where the model may be helpful.
|
||||
|
||||
For demonstration purposes we will imagine that we wish to generate images from text using CLIP and a Decoder.
|
||||
|
||||
> [CLIP](https://openai.com/blog/clip/) is a contrastive model that learns to maximize the cosine similarity between a given image and caption, however, there is no guarantee that these embeddings are in the same space. While the embeddings generated are ***close*** the image and text embeddings occupy two disjoint sets.
|
||||
|
||||
```python
|
||||
# Load Models
|
||||
clip_model = clip.load("ViT-L/14")
|
||||
decoder = Decoder(checkpoint="best.pth") # A decoder trained on CLIP Image embeddings
|
||||
|
||||
# Retrieve prompt from user and encode with CLIP
|
||||
prompt = "A corgi wearing sunglasses"
|
||||
tokenized_text = tokenize(prompt)
|
||||
text_embedding = clip_model.encode_text(tokenized_text)
|
||||
|
||||
# Now, pass the text embedding to the decoder
|
||||
predicted_image = decoder.sample(text_embedding)
|
||||
```
|
||||
|
||||
> **Question**: *Can you spot the issue here?*
|
||||
>
|
||||
> **Answer**: *We’re trying to generate an image from a text embedding!*
|
||||
|
||||
Unfortunately, we run into the issue previously mentioned--the image embeddings and the text embeddings are not interchangeable! Now let's look at a better solution
|
||||
|
||||
```python
|
||||
# Load Models
|
||||
prior= Prior(checkpoint="prior.pth") # A decoder trained to go from: text-> clip text emb -> clip img emb
|
||||
decoder = Decoder(checkpoint="decoder.pth") # A decoder trained on CLIP Image embeddings
|
||||
|
||||
# Retrieve prompt from user and encode with a prior
|
||||
prompt = "A corgi wearing sunglasses"
|
||||
tokenized_text = tokenize(prompt)
|
||||
text_embedding = prior.sample(tokenized_text) # <-- now we get an embedding in the same space as images!
|
||||
|
||||
# Now, pass the predicted image embedding to the decoder
|
||||
predicted_image = decoder.sample(text_embedding)
|
||||
```
|
||||
|
||||
With the prior we are able to successfully generate embeddings *within* CLIP's image space! For this reason, the decoder will perform much better as it receives input that is much closer to its training data.
|
||||
|
||||
> **You may be asking yourself the following question:**
|
||||
>
|
||||
> *"Why don't you just train the decoder on clip text embeddings instead of image embeddings?"*
|
||||
>
|
||||
> OpenAI covers this topic in their [DALLE-2 paper](https://arxiv.org/abs/2204.06125). The TL;DR is *"it doesn't work as well as decoders trained on image embeddings"*...also...its just an example :smile:
|
||||
|
||||
## Usage
|
||||
|
||||
To utilize a pre-trained prior, it’s quite simple.
|
||||
|
||||
### Loading Checkpoints
|
||||
```python
|
||||
import torch
|
||||
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
|
||||
from dalle2_pytorch.trainer import DiffusionPriorTrainer
|
||||
|
||||
def load_diffusion_model(dprior_path):
|
||||
|
||||
prior_network = DiffusionPriorNetwork(
|
||||
dim=768,
|
||||
depth=24,
|
||||
dim_head=64,
|
||||
heads=32,
|
||||
normformer=True,
|
||||
attn_dropout=5e-2,
|
||||
ff_dropout=5e-2,
|
||||
num_time_embeds=1,
|
||||
num_image_embeds=1,
|
||||
num_text_embeds=1,
|
||||
num_timesteps=1000,
|
||||
ff_mult=4
|
||||
)
|
||||
|
||||
diffusion_prior = DiffusionPrior(
|
||||
net=prior_network,
|
||||
clip=OpenAIClipAdapter("ViT-L/14"),
|
||||
image_embed_dim=768,
|
||||
timesteps=1000,
|
||||
cond_drop_prob=0.1,
|
||||
loss_type="l2",
|
||||
condition_on_text_encodings=True,
|
||||
|
||||
)
|
||||
|
||||
trainer = DiffusionPriorTrainer(
|
||||
diffusion_prior=diffusion_prior,
|
||||
lr=1.1e-4,
|
||||
wd=6.02e-2,
|
||||
max_grad_norm=0.5,
|
||||
amp=False,
|
||||
group_wd_params=True,
|
||||
use_ema=True,
|
||||
device=device,
|
||||
accelerator=None,
|
||||
)
|
||||
|
||||
trainer.load(dprior_path)
|
||||
|
||||
return trainer
|
||||
```
|
||||
|
||||
Here we instantiate a model matches the configuration it was trained with, and then load the weights (*just like any other PyTorch model!*)
|
||||
|
||||
### Sampling
|
||||
Once we have a pre-trained model, generating embeddings is quite simple!
|
||||
```python
|
||||
# tokenize the text
|
||||
tokenized_text = clip.tokenize("<your amazing prompt>")
|
||||
# predict an embedding
|
||||
predicted_embedding = prior.sample(tokenized_text, n_samples_per_batch=2, cond_scale=1.0)
|
||||
```
|
||||
|
||||
The resulting tensor returned from `.sample()` is of the same shape as your training data along the non-batch dimension(s). For example, a prior trained on `ViT-L/14` embeddings will predict an embedding of shape (1, 768).
|
||||
|
||||
> For CLIP priors, this is quite handy as it means that you can use prior.sample(tokenizer_text) as a drop in replacement for clip.encode_text().
|
||||
|
||||
**Some things to note:**
|
||||
* It is possible to specify the number of embeddings to sample from (the default suggested by OpenAI is `n=2`). Put simply, the idea here is that you avoid getting unlucky with a bad embedding generation by creating two; and selecting the one with the higher cosine similarity with the prompt.
|
||||
* You may specify a higher conditioning scale than the default (`1.0`). It is unclear whether OpenAI uses a higher value for the prior specifically, or only on the decoder. Local testing has shown poor results with anything higher than `1.0` but *ymmv*.
|
||||
|
||||
---
|
||||
|
||||
## Training
|
||||
|
||||
### Overview
|
||||
|
||||
Training the prior is a relatively straightforward process thanks to the Trainer base class. The major step that is required of you is preparing a dataset in the format that EmbeddingReader expects. Having pre-computed embeddings massively increases training efficiency and is generally recommended as you will likely benefit from having them on hand for other tasks as well. Once you have a dataset, you are ready to move onto configuration
|
||||
|
||||
## Dataset
|
||||
|
||||
To train the prior, it is highly recommended to use precomputed embeddings for the images. To obtain these for a custom dataset, you can leverage [img2datset](https://github.com/rom1504/img2dataset) to pull images from a list of URLs and [clip_retrieval](https://github.com/rom1504/clip-retrieval#clip-inference) for generating the actual embeddings that can be used in the prior's dataloader.
|
||||
|
||||
## Configuration
|
||||
|
||||
The configuration file allows for you to easily track and reproduce experiments. It is a simple JSON file that will specify the architecture, dataset, and training parameters. For more information and specifics please see the configuration README.
|
||||
|
||||
## Distributed Training
|
||||
|
||||
If you would like to train in a distributed manner we have opted to leverage huggingface’ new Accelerate library. HFA makes it extremely simple to distribute work across multiple GPU’s and nodes. All that is required of you is to follow the simple CLI configuration tool [more information here](https://huggingface.co/docs/accelerate/accelerator).
|
||||
|
||||
## Evaluation
|
||||
|
||||
There are a variety of metrics available to you when training the prior. You can read a brief description of each in the table below:
|
||||
| Metric | Description | Comments |
|
||||
| ----------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| Online Model Validation | The validation loss associated with your online model. | Ideally validation loss will be as low as possible. Using L2 loss, values as low as `0.1` and lower are possible after around 1 Billion samples seen. |
|
||||
| EMA Validation | This metric measures the validation loss associated with your EMA model. | This will likely lag behind your "online" model's validation loss, but should outperform in the long-term. |
|
||||
| Baseline Similarity | Baseline similarity refers to the similarity between your dataset's prompts and associated image embeddings. This will serve as a guide for your prior's performance in cosine similarity. | Generally `0.3` is considered a good cosine similarity for caption similarity. |
|
||||
| Similarity With Original Image | This metric will measure the cosine similarity between your prior's predicted image embedding and the actual image that the caption was associated with. This is useful for determining wether your prior is generating images with the right contents. | Values around `0.75`+ are obtainable. This metric should improve rapidly in the early stages of training and plateau with diminishing increases over time. If it takes hundreds of millions of samples to reach above `0.5`/`0.6` similarity--then you likely are suffering from some kind of training error or inefficiency (i.e. not using EMA) |
|
||||
| Difference From Baseline Similarity | Sometimes its useful to visualize a metric in another light. This metric will show you how your prior's predicted image embeddings match up with the baseline similarity measured in your dataset. | This value should float around `0.0` with some room for variation. After a billion samples seen, values are within `0.01`+/- of `0.0`. If this climbs to high, (~>`0.02`) then this may be a sign that your model is overfitting somehow. |
|
||||
| Similarity With Text | This metric is your bread and butter cosine similarity between the predicted image embedding and the original caption given to the prior. Monitoring this metric will be on of your main focuses and is probably the second most important behind your loss. | As mentioned, this value should be close to baseline similarity. We have observed early rapid increase with diminishing returns as the prior learns to generate valid image embeddings. If this value increases too far beyond the baseline similarity--it could be an indication that your model is overfitting. |
|
||||
| Similarity With Unrelated Caption | This metric will attempt to exposed an overfit prior by feeding it arbitrary prompts (from your dataset) and then measure the similarity of this predicted embedding with some other image. | Early on we found that a poorly trained/modeled prior could effectively fool CLIP into believing that the cosine similarity between two images were high (when in fact the caption and image were completely unrelated). With this in mind--a low value is ideal, anything below `0.1` is probably safe. |
|
||||
|
||||
## Launching the script
|
||||
|
||||
Now that you’ve done all the prep it’s time for the easy part! 🚀
|
||||
|
||||
To actually launch the script, you will either use `accelerate launch train_diffusion_prior.py --config_path <path to your config>` to launch with distributed training & huggingface accelerate or `python train_diffusion_prior.py` if you would like to train on your gpu/cpu without huggingface accelerate.
|
||||
|
||||
## Checkpointing
|
||||
|
||||
Checkpoints will be saved to the directory specified in your configuration file.
|
||||
|
||||
Additionally, a final checkpoint is saved before running the test split. This file will be saved to the same directory and titled “latest.pth”. This is to avoid problems where your `save_every` configuration does not overlap with the number of steps required to do a complete pass through the data.
|
||||
|
||||
## Things To Keep In Mind
|
||||
|
||||
The prior has not been trained for tasks other than the traditional CLIP embedding translation…at least yet.
|
||||
|
||||
As we finalize the replication of unCLIP, there will almost assuredly be experiments attempting to apply the prior network to other tasks.
|
||||
|
||||
With that in mind, you are more or less a pioneer in embedding-translation if you are reading this and attempting something you don’t see documentation for!
|
||||
3
setup.py
3
setup.py
@@ -26,7 +26,7 @@ setup(
|
||||
install_requires=[
|
||||
'accelerate',
|
||||
'click',
|
||||
'clip-anytorch',
|
||||
'clip-anytorch>=2.4.0',
|
||||
'coca-pytorch>=0.0.5',
|
||||
'ema-pytorch>=0.0.7',
|
||||
'einops>=0.4',
|
||||
@@ -37,6 +37,7 @@ setup(
|
||||
'packaging',
|
||||
'pillow',
|
||||
'pydantic',
|
||||
'pytorch-warmup',
|
||||
'resize-right>=0.0.2',
|
||||
'rotary-embedding-torch',
|
||||
'torch>=1.10',
|
||||
|
||||
BIN
test_data/0.tar
Normal file
BIN
test_data/0.tar
Normal file
Binary file not shown.
BIN
test_data/1.tar
Normal file
BIN
test_data/1.tar
Normal file
Binary file not shown.
BIN
test_data/2.tar
Normal file
BIN
test_data/2.tar
Normal file
Binary file not shown.
BIN
test_data/3.tar
Normal file
BIN
test_data/3.tar
Normal file
Binary file not shown.
BIN
test_data/4.tar
Normal file
BIN
test_data/4.tar
Normal file
Binary file not shown.
BIN
test_data/5.tar
Normal file
BIN
test_data/5.tar
Normal file
Binary file not shown.
BIN
test_data/6.tar
Normal file
BIN
test_data/6.tar
Normal file
Binary file not shown.
BIN
test_data/7.tar
Normal file
BIN
test_data/7.tar
Normal file
Binary file not shown.
BIN
test_data/8.tar
Normal file
BIN
test_data/8.tar
Normal file
Binary file not shown.
BIN
test_data/9.tar
Normal file
BIN
test_data/9.tar
Normal file
Binary file not shown.
150
train_decoder.py
150
train_decoder.py
@@ -1,5 +1,6 @@
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from datetime import timedelta
|
||||
|
||||
from dalle2_pytorch.trainer import DecoderTrainer
|
||||
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
|
||||
@@ -11,11 +12,12 @@ from clip import tokenize
|
||||
|
||||
import torchvision
|
||||
import torch
|
||||
from torch import nn
|
||||
from torchmetrics.image.fid import FrechetInceptionDistance
|
||||
from torchmetrics.image.inception import InceptionScore
|
||||
from torchmetrics.image.kid import KernelInceptionDistance
|
||||
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
|
||||
from accelerate import Accelerator, DistributedDataParallelKwargs
|
||||
from accelerate import Accelerator, DistributedDataParallelKwargs, InitProcessGroupKwargs
|
||||
from accelerate.utils import dataclasses as accelerate_dataclasses
|
||||
import webdataset as wds
|
||||
import click
|
||||
@@ -132,7 +134,7 @@ def get_example_data(dataloader, device, n=5):
|
||||
break
|
||||
return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))
|
||||
|
||||
def generate_samples(trainer, example_data, condition_on_text_encodings=False, text_prepend=""):
|
||||
def generate_samples(trainer, example_data, clip=None, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend="", match_image_size=True):
|
||||
"""
|
||||
Takes example data and generates images from the embeddings
|
||||
Returns three lists: real images, generated images, and captions
|
||||
@@ -142,7 +144,9 @@ def generate_samples(trainer, example_data, condition_on_text_encodings=False, t
|
||||
if img_embeddings[0] is None:
|
||||
# Generate image embeddings from clip
|
||||
imgs_tensor = torch.stack(real_images)
|
||||
img_embeddings, *_ = trainer.embed_image(imgs_tensor)
|
||||
assert clip is not None, "clip is None, but img_embeddings is None"
|
||||
imgs_tensor.to(device=device)
|
||||
img_embeddings, img_encoding = clip.embed_image(imgs_tensor)
|
||||
sample_params["image_embed"] = img_embeddings
|
||||
else:
|
||||
# Then we are using precomputed image embeddings
|
||||
@@ -151,34 +155,38 @@ def generate_samples(trainer, example_data, condition_on_text_encodings=False, t
|
||||
if condition_on_text_encodings:
|
||||
if text_embeddings[0] is None:
|
||||
# Generate text embeddings from text
|
||||
tokenized_texts = tokenize(txts, truncate=True)
|
||||
sample_params["text"] = tokenized_texts
|
||||
assert clip is not None, "clip is None, but text_embeddings is None"
|
||||
tokenized_texts = tokenize(txts, truncate=True).to(device=device)
|
||||
text_embed, text_encodings = clip.embed_text(tokenized_texts)
|
||||
sample_params["text_encodings"] = text_encodings
|
||||
else:
|
||||
# Then we are using precomputed text embeddings
|
||||
text_embeddings = torch.stack(text_embeddings)
|
||||
sample_params["text_encodings"] = text_embeddings
|
||||
samples = trainer.sample(**sample_params)
|
||||
sample_params["start_at_unet_number"] = start_unet
|
||||
sample_params["stop_at_unet_number"] = end_unet
|
||||
if start_unet > 1:
|
||||
# If we are only training upsamplers
|
||||
sample_params["image"] = torch.stack(real_images)
|
||||
if device is not None:
|
||||
sample_params["_device"] = device
|
||||
samples = trainer.sample(**sample_params, _cast_deepspeed_precision=False) # At sampling time we don't want to cast to FP16
|
||||
generated_images = list(samples)
|
||||
captions = [text_prepend + txt for txt in txts]
|
||||
if match_image_size:
|
||||
generated_image_size = generated_images[0].shape[-1]
|
||||
real_images = [resize_image_to(image, generated_image_size, clamp_range=(0, 1)) for image in real_images]
|
||||
return real_images, generated_images, captions
|
||||
|
||||
def generate_grid_samples(trainer, examples, condition_on_text_encodings=False, text_prepend=""):
|
||||
def generate_grid_samples(trainer, examples, clip=None, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend=""):
|
||||
"""
|
||||
Generates samples and uses torchvision to put them in a side by side grid for easy viewing
|
||||
"""
|
||||
real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings, text_prepend)
|
||||
|
||||
real_image_size = real_images[0].shape[-1]
|
||||
generated_image_size = generated_images[0].shape[-1]
|
||||
|
||||
# training images may be larger than the generated one
|
||||
if real_image_size > generated_image_size:
|
||||
real_images = [resize_image_to(image, generated_image_size) for image in real_images]
|
||||
|
||||
real_images, generated_images, captions = generate_samples(trainer, examples, clip, start_unet, end_unet, condition_on_text_encodings, cond_scale, device, text_prepend)
|
||||
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
|
||||
return grid_images, captions
|
||||
|
||||
def evaluate_trainer(trainer, dataloader, device, condition_on_text_encodings=False, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
|
||||
def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, clip=None, condition_on_text_encodings=False, cond_scale=1.0, inference_device=None, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
|
||||
"""
|
||||
Computes evaluation metrics for the decoder
|
||||
"""
|
||||
@@ -188,7 +196,7 @@ def evaluate_trainer(trainer, dataloader, device, condition_on_text_encodings=Fa
|
||||
if len(examples) == 0:
|
||||
print("No data to evaluate. Check that your dataloader has shards.")
|
||||
return metrics
|
||||
real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings)
|
||||
real_images, generated_images, captions = generate_samples(trainer, examples, clip, start_unet, end_unet, condition_on_text_encodings, cond_scale, inference_device)
|
||||
real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
|
||||
generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
|
||||
# Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
|
||||
@@ -221,8 +229,8 @@ def evaluate_trainer(trainer, dataloader, device, condition_on_text_encodings=Fa
|
||||
metrics["KID_std"] = kid_std.item()
|
||||
if exists(LPIPS):
|
||||
# Convert from [0, 1] to [-1, 1]
|
||||
renorm_real_images = real_images.mul(2).sub(1)
|
||||
renorm_generated_images = generated_images.mul(2).sub(1)
|
||||
renorm_real_images = real_images.mul(2).sub(1).clamp(-1,1)
|
||||
renorm_generated_images = generated_images.mul(2).sub(1).clamp(-1,1)
|
||||
lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS, dist_sync_fn=null_sync)
|
||||
lpips.to(device=device)
|
||||
lpips.update(renorm_real_images, renorm_generated_images)
|
||||
@@ -261,14 +269,17 @@ def train(
|
||||
accelerator: Accelerator,
|
||||
tracker: Tracker,
|
||||
inference_device,
|
||||
clip=None,
|
||||
evaluate_config=None,
|
||||
epoch_samples = None, # If the training dataset is resampling, we have to manually stop an epoch
|
||||
validation_samples = None,
|
||||
save_immediately=False,
|
||||
epochs = 20,
|
||||
n_sample_images = 5,
|
||||
save_every_n_samples = 100000,
|
||||
unet_training_mask=None,
|
||||
condition_on_text_encodings=False,
|
||||
cond_scale=1.0,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
@@ -276,9 +287,25 @@ def train(
|
||||
"""
|
||||
is_master = accelerator.process_index == 0
|
||||
|
||||
if not exists(unet_training_mask):
|
||||
# Then the unet mask should be true for all unets in the decoder
|
||||
unet_training_mask = [True] * len(decoder.unets)
|
||||
assert len(unet_training_mask) == len(decoder.unets), f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"
|
||||
trainable_unet_numbers = [i+1 for i, trainable in enumerate(unet_training_mask) if trainable]
|
||||
first_trainable_unet = trainable_unet_numbers[0]
|
||||
last_trainable_unet = trainable_unet_numbers[-1]
|
||||
def move_unets(unet_training_mask):
|
||||
for i in range(len(decoder.unets)):
|
||||
if not unet_training_mask[i]:
|
||||
# Replace the unet from the module list with a nn.Identity(). This training script never uses unets that aren't being trained so this is fine.
|
||||
decoder.unets[i] = nn.Identity().to(inference_device)
|
||||
# Remove non-trainable unets
|
||||
move_unets(unet_training_mask)
|
||||
|
||||
trainer = DecoderTrainer(
|
||||
decoder=decoder,
|
||||
accelerator=accelerator,
|
||||
dataloaders=dataloaders,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@@ -289,9 +316,9 @@ def train(
|
||||
sample = 0
|
||||
samples_seen = 0
|
||||
val_sample = 0
|
||||
step = lambda: int(trainer.step.item())
|
||||
step = lambda: int(trainer.num_steps_taken(unet_number=first_trainable_unet))
|
||||
|
||||
if tracker.loader is not None:
|
||||
if tracker.can_recall:
|
||||
start_epoch, validation_losses, next_task, recalled_sample, samples_seen = recall_trainer(tracker, trainer)
|
||||
if next_task == 'train':
|
||||
sample = recalled_sample
|
||||
@@ -301,11 +328,6 @@ def train(
|
||||
accelerator.print(f"Starting training from task {next_task} at sample {sample} and validation sample {val_sample}")
|
||||
trainer.to(device=inference_device)
|
||||
|
||||
if not exists(unet_training_mask):
|
||||
# Then the unet mask should be true for all unets in the decoder
|
||||
unet_training_mask = [True] * trainer.num_unets
|
||||
assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"
|
||||
|
||||
accelerator.print(print_ribbon("Generating Example Data", repeat=40))
|
||||
accelerator.print("This can take a while to load the shard lists...")
|
||||
if is_master:
|
||||
@@ -354,15 +376,20 @@ def train(
|
||||
forward_params['image_embed'] = img_emb
|
||||
else:
|
||||
# Forward pass automatically generates embedding
|
||||
pass
|
||||
assert clip is not None
|
||||
img_embed, img_encoding = clip.embed_image(img)
|
||||
forward_params['image_embed'] = img_embed
|
||||
if condition_on_text_encodings:
|
||||
if has_text_embedding:
|
||||
forward_params['text_encodings'] = text_emb
|
||||
else:
|
||||
# Then we need to pass the text instead
|
||||
tokenized_texts = tokenize(txt, truncate=True)
|
||||
forward_params['text'] = tokenized_texts
|
||||
loss = trainer.forward(img, **forward_params, unet_number=unet)
|
||||
assert clip is not None
|
||||
tokenized_texts = tokenize(txt, truncate=True).to(inference_device)
|
||||
assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
|
||||
text_embed, text_encodings = clip.embed_text(tokenized_texts)
|
||||
forward_params['text_encodings'] = text_encodings
|
||||
loss = trainer.forward(img, **forward_params, unet_number=unet, _device=inference_device)
|
||||
trainer.update(unet_number=unet)
|
||||
unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet-1] = loss
|
||||
|
||||
@@ -375,10 +402,10 @@ def train(
|
||||
unet_all_losses = accelerator.gather(unet_losses_tensor)
|
||||
mask = unet_all_losses != 0
|
||||
unet_average_loss = (unet_all_losses * mask).sum(dim=0) / mask.sum(dim=0)
|
||||
loss_map = { f"Unet {index} Training Loss": loss.item() for index, loss in enumerate(unet_average_loss) if loss != 0 }
|
||||
loss_map = { f"Unet {index} Training Loss": loss.item() for index, loss in enumerate(unet_average_loss) if unet_training_mask[index] }
|
||||
|
||||
# gather decay rate on each UNet
|
||||
ema_decay_list = {f"Unet {index} EMA Decay": ema_unet.get_current_decay() for index, ema_unet in enumerate(trainer.ema_unets)}
|
||||
ema_decay_list = {f"Unet {index} EMA Decay": ema_unet.get_current_decay() for index, ema_unet in enumerate(trainer.ema_unets) if unet_training_mask[index]}
|
||||
|
||||
log_data = {
|
||||
"Epoch": epoch,
|
||||
@@ -393,7 +420,7 @@ def train(
|
||||
if is_master:
|
||||
tracker.log(log_data, step=step())
|
||||
|
||||
if is_master and last_snapshot + save_every_n_samples < sample: # This will miss by some amount every time, but it's not a big deal... I hope
|
||||
if is_master and (last_snapshot + save_every_n_samples < sample or (save_immediately and i == 0)): # This will miss by some amount every time, but it's not a big deal... I hope
|
||||
# It is difficult to gather this kind of info on the accelerator, so we have to do it on the master
|
||||
print("Saving snapshot")
|
||||
last_snapshot = sample
|
||||
@@ -401,7 +428,7 @@ def train(
|
||||
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen)
|
||||
if exists(n_sample_images) and n_sample_images > 0:
|
||||
trainer.eval()
|
||||
train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
|
||||
train_images, train_captions = generate_grid_samples(trainer, train_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
|
||||
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
|
||||
|
||||
if epoch_samples is not None and sample >= epoch_samples:
|
||||
@@ -419,7 +446,7 @@ def train(
|
||||
timer = Timer()
|
||||
accelerator.wait_for_everyone()
|
||||
i = 0
|
||||
for i, (img, emb, txt) in enumerate(dataloaders["val"]):
|
||||
for i, (img, emb, txt) in enumerate(dataloaders['val']): # Use the accelerate prepared loader
|
||||
val_sample_length_tensor[0] = len(img)
|
||||
all_samples = accelerator.gather(val_sample_length_tensor)
|
||||
total_samples = all_samples.sum().item()
|
||||
@@ -444,15 +471,20 @@ def train(
|
||||
forward_params['image_embed'] = img_emb.float()
|
||||
else:
|
||||
# Forward pass automatically generates embedding
|
||||
pass
|
||||
assert clip is not None
|
||||
img_embed, img_encoding = clip.embed_image(img)
|
||||
forward_params['image_embed'] = img_embed
|
||||
if condition_on_text_encodings:
|
||||
if has_text_embedding:
|
||||
forward_params['text_encodings'] = text_emb.float()
|
||||
else:
|
||||
# Then we need to pass the text instead
|
||||
tokenized_texts = tokenize(txt, truncate=True)
|
||||
forward_params['text'] = tokenized_texts
|
||||
loss = trainer.forward(img.float(), **forward_params, unet_number=unet)
|
||||
assert clip is not None
|
||||
tokenized_texts = tokenize(txt, truncate=True).to(device=inference_device)
|
||||
assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
|
||||
text_embed, text_encodings = clip.embed_text(tokenized_texts)
|
||||
forward_params['text_encodings'] = text_encodings
|
||||
loss = trainer.forward(img.float(), **forward_params, unet_number=unet, _device=inference_device)
|
||||
average_val_loss_tensor[0, unet-1] += loss
|
||||
|
||||
if i % VALID_CALC_LOSS_EVERY_ITERS == 0:
|
||||
@@ -479,7 +511,7 @@ def train(
|
||||
if next_task == 'eval':
|
||||
if exists(evaluate_config):
|
||||
accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
|
||||
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings)
|
||||
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, first_trainable_unet, last_trainable_unet, clip=clip, inference_device=inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings, cond_scale=cond_scale)
|
||||
if is_master:
|
||||
tracker.log(evaluation, step=step())
|
||||
next_task = 'sample'
|
||||
@@ -490,15 +522,15 @@ def train(
|
||||
# Generate examples and save the model if we are the master
|
||||
# Generate sample images
|
||||
print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
|
||||
test_images, test_captions = generate_grid_samples(trainer, test_example_data, condition_on_text_encodings, "Test: ")
|
||||
train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
|
||||
test_images, test_captions = generate_grid_samples(trainer, test_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Test: ")
|
||||
train_images, train_captions = generate_grid_samples(trainer, train_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
|
||||
tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step())
|
||||
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
|
||||
|
||||
print(print_ribbon(f"Starting Saving {epoch}", repeat=40))
|
||||
is_best = False
|
||||
if all_average_val_losses is not None:
|
||||
average_loss = all_average_val_losses.mean(dim=0).item()
|
||||
average_loss = all_average_val_losses.mean(dim=0).sum() / sum(unet_training_mask)
|
||||
if len(validation_losses) == 0 or average_loss < min(validation_losses):
|
||||
is_best = True
|
||||
validation_losses.append(average_loss)
|
||||
@@ -513,8 +545,10 @@ def create_tracker(accelerator: Accelerator, config: TrainDecoderConfig, config_
|
||||
"NumProcesses": accelerator.num_processes,
|
||||
"MixedPrecision": accelerator.mixed_precision
|
||||
}
|
||||
accelerator.wait_for_everyone() # If nodes arrive at this point at different times they might try to autoresume the current run which makes no sense and will cause errors
|
||||
tracker: Tracker = tracker_config.create(config, accelerator_config, dummy_mode=dummy)
|
||||
tracker.save_config(config_path, config_name='decoder_config.json')
|
||||
tracker.add_save_metadata(state_dict_key='config', metadata=config.dict())
|
||||
return tracker
|
||||
|
||||
def initialize_training(config: TrainDecoderConfig, config_path):
|
||||
@@ -523,7 +557,18 @@ def initialize_training(config: TrainDecoderConfig, config_path):
|
||||
|
||||
# Set up accelerator for configurable distributed training
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
|
||||
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
|
||||
init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=60*60))
|
||||
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs, init_kwargs])
|
||||
|
||||
if accelerator.num_processes > 1:
|
||||
# We are using distributed training and want to immediately ensure all can connect
|
||||
accelerator.print("Waiting for all processes to connect...")
|
||||
accelerator.wait_for_everyone()
|
||||
accelerator.print("All processes online and connected")
|
||||
|
||||
# If we are in deepspeed fp16 mode, we must ensure learned variance is off
|
||||
if accelerator.mixed_precision == "fp16" and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED and config.decoder.learned_variance:
|
||||
raise ValueError("DeepSpeed fp16 mode does not support learned variance")
|
||||
|
||||
# Set up data
|
||||
all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
|
||||
@@ -544,9 +589,14 @@ def initialize_training(config: TrainDecoderConfig, config_path):
|
||||
seed = config.seed,
|
||||
)
|
||||
|
||||
# If clip is in the model, we need to remove it for compatibility with deepspeed
|
||||
clip = None
|
||||
if config.decoder.clip is not None:
|
||||
clip = config.decoder.clip.create() # Of course we keep it to use it during training, just not in the decoder as that causes issues
|
||||
config.decoder.clip = None
|
||||
# Create the decoder model and print basic info
|
||||
decoder = config.decoder.create()
|
||||
num_parameters = sum(p.numel() for p in decoder.parameters())
|
||||
get_num_parameters = lambda model, only_training=False: sum(p.numel() for p in model.parameters() if (p.requires_grad or not only_training))
|
||||
|
||||
# Create and initialize the tracker if we are the master
|
||||
tracker = create_tracker(accelerator, config, config_path, dummy = rank!=0)
|
||||
@@ -555,7 +605,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
|
||||
has_text_embeddings = config.data.text_embeddings_url is not None
|
||||
conditioning_on_text = any([unet.cond_on_text_encodings for unet in config.decoder.unets])
|
||||
|
||||
has_clip_model = config.decoder.clip is not None
|
||||
has_clip_model = clip is not None
|
||||
data_source_string = ""
|
||||
|
||||
if has_img_embeddings:
|
||||
@@ -575,8 +625,12 @@ def initialize_training(config: TrainDecoderConfig, config_path):
|
||||
accelerator.print(print_ribbon("Loaded Config", repeat=40))
|
||||
accelerator.print(f"Running training with {accelerator.num_processes} processes and {accelerator.distributed_type} distributed training")
|
||||
accelerator.print(f"Training using {data_source_string}. {'conditioned on text' if conditioning_on_text else 'not conditioned on text'}")
|
||||
accelerator.print(f"Number of parameters: {num_parameters}")
|
||||
accelerator.print(f"Number of parameters: {get_num_parameters(decoder)} total; {get_num_parameters(decoder, only_training=True)} training")
|
||||
for i, unet in enumerate(decoder.unets):
|
||||
accelerator.print(f"Unet {i} has {get_num_parameters(unet)} total; {get_num_parameters(unet, only_training=True)} training")
|
||||
|
||||
train(dataloaders, decoder, accelerator,
|
||||
clip=clip,
|
||||
tracker=tracker,
|
||||
inference_device=accelerator.device,
|
||||
evaluate_config=config.evaluate,
|
||||
|
||||
@@ -1,31 +1,23 @@
|
||||
# TODO: add start, num_data_points, eval_every and group to config
|
||||
# TODO: switch back to repo's wandb
|
||||
|
||||
START = 0
|
||||
NUM_DATA_POINTS = 250e6
|
||||
EVAL_EVERY = 1000
|
||||
GROUP = "distributed"
|
||||
|
||||
import os
|
||||
import click
|
||||
import wandb
|
||||
|
||||
import torch
|
||||
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
import numpy as np
|
||||
|
||||
from typing import List
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import set_seed
|
||||
from torch.utils.data import DataLoader
|
||||
from embedding_reader import EmbeddingReader
|
||||
from accelerate.utils import dataclasses as accelerate_dataclasses
|
||||
|
||||
from dalle2_pytorch.dataloaders import get_reader, make_splits
|
||||
from dalle2_pytorch.utils import Timer
|
||||
from dalle2_pytorch.trackers import Tracker
|
||||
from dalle2_pytorch import DiffusionPriorTrainer
|
||||
from dalle2_pytorch.dataloaders import get_reader, make_splits
|
||||
from dalle2_pytorch.train_configs import (
|
||||
DiffusionPriorConfig,
|
||||
DiffusionPriorTrainConfig,
|
||||
TrainDiffusionPriorConfig,
|
||||
)
|
||||
from dalle2_pytorch.trackers import BaseTracker, WandbTracker
|
||||
from dalle2_pytorch import DiffusionPriorTrainer
|
||||
|
||||
|
||||
# helpers
|
||||
@@ -38,8 +30,19 @@ def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def all_between(values: list, lower_bound, upper_bound):
|
||||
for value in values:
|
||||
if value < lower_bound or value > upper_bound:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def make_model(
|
||||
prior_config, train_config, device: str = None, accelerator: Accelerator = None
|
||||
prior_config: DiffusionPriorConfig,
|
||||
train_config: DiffusionPriorTrainConfig,
|
||||
device: str = None,
|
||||
accelerator: Accelerator = None,
|
||||
):
|
||||
# create model from config
|
||||
diffusion_prior = prior_config.create()
|
||||
@@ -54,71 +57,214 @@ def make_model(
|
||||
use_ema=train_config.use_ema,
|
||||
device=device,
|
||||
accelerator=accelerator,
|
||||
warmup_steps=train_config.warmup_steps,
|
||||
)
|
||||
|
||||
return trainer
|
||||
|
||||
|
||||
def create_tracker(
|
||||
accelerator: Accelerator,
|
||||
config: TrainDiffusionPriorConfig,
|
||||
config_path: str,
|
||||
dummy: bool = False,
|
||||
) -> Tracker:
|
||||
tracker_config = config.tracker
|
||||
|
||||
accelerator_config = {
|
||||
"Distributed": accelerator.distributed_type
|
||||
!= accelerate_dataclasses.DistributedType.NO,
|
||||
"DistributedType": accelerator.distributed_type,
|
||||
"NumProcesses": accelerator.num_processes,
|
||||
"MixedPrecision": accelerator.mixed_precision,
|
||||
}
|
||||
|
||||
tracker: Tracker = tracker_config.create(
|
||||
config, accelerator_config, dummy_mode=dummy
|
||||
)
|
||||
|
||||
tracker.save_config(config_path, config_name="prior_config.json")
|
||||
|
||||
return tracker
|
||||
|
||||
|
||||
def pad_gather_reduce(trainer: DiffusionPriorTrainer, x, method="mean"):
|
||||
"""
|
||||
pad a value or tensor across all processes and gather
|
||||
|
||||
params:
|
||||
- trainer: a trainer that carries an accelerator object
|
||||
- x: a number or torch tensor to reduce
|
||||
- method: "mean", "sum", "max", "min"
|
||||
|
||||
return:
|
||||
- the average tensor after maskin out 0's
|
||||
- None if the gather resulted in an empty tensor
|
||||
"""
|
||||
|
||||
assert method in [
|
||||
"mean",
|
||||
"sum",
|
||||
"max",
|
||||
"min",
|
||||
], "This function has limited capabilities [sum, mean, max, min]"
|
||||
assert type(x) is not None, "Cannot reduce a None type object"
|
||||
|
||||
# wait for everyone to arrive here before gathering
|
||||
|
||||
if type(x) is not torch.Tensor:
|
||||
x = torch.tensor([x])
|
||||
|
||||
# verify that the tensor is on the proper device
|
||||
x = x.to(trainer.device)
|
||||
|
||||
# pad across processes
|
||||
padded_x = trainer.accelerator.pad_across_processes(x, dim=0)
|
||||
|
||||
# gather across all procesess
|
||||
gathered_x = trainer.accelerator.gather(padded_x)
|
||||
|
||||
# mask out zeros
|
||||
masked_x = gathered_x[gathered_x != 0]
|
||||
|
||||
# if the tensor is empty, warn and return None
|
||||
if len(masked_x) == 0:
|
||||
click.secho(
|
||||
f"The call to this method resulted in an empty tensor after masking out zeros. The gathered tensor was this: {gathered_x} and the original value passed was: {x}.",
|
||||
fg="red",
|
||||
)
|
||||
return None
|
||||
|
||||
if method == "mean":
|
||||
return torch.mean(masked_x)
|
||||
elif method == "sum":
|
||||
return torch.sum(masked_x)
|
||||
elif method == "max":
|
||||
return torch.max(masked_x)
|
||||
elif method == "min":
|
||||
return torch.min(masked_x)
|
||||
|
||||
|
||||
def save_trainer(
|
||||
tracker: Tracker,
|
||||
trainer: DiffusionPriorTrainer,
|
||||
is_latest: bool,
|
||||
is_best: bool,
|
||||
epoch: int,
|
||||
samples_seen: int,
|
||||
best_validation_loss: float,
|
||||
):
|
||||
"""
|
||||
Logs the model with an appropriate method depending on the tracker
|
||||
"""
|
||||
trainer.accelerator.wait_for_everyone()
|
||||
|
||||
if trainer.accelerator.is_main_process:
|
||||
click.secho(
|
||||
f"RANK:{trainer.accelerator.process_index} | Saving Model | Best={is_best} | Latest={is_latest}",
|
||||
fg="magenta",
|
||||
)
|
||||
|
||||
tracker.save(
|
||||
trainer=trainer,
|
||||
is_best=is_best,
|
||||
is_latest=is_latest,
|
||||
epoch=int(epoch),
|
||||
samples_seen=int(samples_seen),
|
||||
best_validation_loss=best_validation_loss,
|
||||
)
|
||||
|
||||
|
||||
def recall_trainer(tracker: Tracker, trainer: DiffusionPriorTrainer):
|
||||
"""
|
||||
Loads the model with an appropriate method depending on the tracker
|
||||
"""
|
||||
|
||||
if trainer.accelerator.is_main_process:
|
||||
click.secho(f"Loading model from {type(tracker.loader).__name__}", fg="yellow")
|
||||
|
||||
state_dict = tracker.recall()
|
||||
|
||||
trainer.load(state_dict, strict=True)
|
||||
|
||||
return (
|
||||
int(state_dict.get("epoch", 0)),
|
||||
state_dict.get("best_validation_loss", 0),
|
||||
int(state_dict.get("samples_seen", 0)),
|
||||
)
|
||||
|
||||
|
||||
# eval functions
|
||||
|
||||
|
||||
def eval_model(
|
||||
def report_validation_loss(
|
||||
trainer: DiffusionPriorTrainer,
|
||||
dataloader: DataLoader,
|
||||
text_conditioned: bool,
|
||||
use_ema: bool,
|
||||
tracker: Tracker,
|
||||
split: str,
|
||||
tracker_folder: str,
|
||||
loss_type: str,
|
||||
tracker_context: str,
|
||||
tracker: BaseTracker = None,
|
||||
use_ema: bool = True,
|
||||
):
|
||||
trainer.eval()
|
||||
if trainer.is_main_process():
|
||||
click.secho(f"Measuring performance on {tracker_context}", fg="green", blink=True)
|
||||
"""
|
||||
Compute the validation loss on a given subset of data.
|
||||
"""
|
||||
|
||||
with torch.no_grad():
|
||||
total_loss = 0.0
|
||||
total_samples = 0.0
|
||||
if trainer.accelerator.is_main_process:
|
||||
click.secho(
|
||||
f"Measuring performance on {use_ema}-{split} split",
|
||||
fg="green",
|
||||
blink=True,
|
||||
)
|
||||
|
||||
for image_embeddings, text_data in dataloader:
|
||||
image_embeddings = image_embeddings.to(trainer.device)
|
||||
text_data = text_data.to(trainer.device)
|
||||
total_loss = torch.zeros(1, dtype=torch.float, device=trainer.device)
|
||||
|
||||
batches = image_embeddings.shape[0]
|
||||
for image_embeddings, text_data in dataloader:
|
||||
image_embeddings = image_embeddings.to(trainer.device)
|
||||
text_data = text_data.to(trainer.device)
|
||||
|
||||
input_args = dict(image_embed=image_embeddings)
|
||||
input_args = dict(image_embed=image_embeddings)
|
||||
|
||||
if text_conditioned:
|
||||
input_args = dict(**input_args, text=text_data)
|
||||
else:
|
||||
input_args = dict(**input_args, text_embed=text_data)
|
||||
if text_conditioned:
|
||||
input_args = dict(**input_args, text=text_data)
|
||||
else:
|
||||
input_args = dict(**input_args, text_embed=text_data)
|
||||
|
||||
if use_ema:
|
||||
loss = trainer.ema_diffusion_prior(**input_args)
|
||||
else:
|
||||
loss = trainer(**input_args)
|
||||
if use_ema:
|
||||
loss = trainer.ema_diffusion_prior(**input_args)
|
||||
else:
|
||||
loss = trainer(**input_args)
|
||||
|
||||
total_loss += loss * batches
|
||||
total_samples += batches
|
||||
total_loss += loss
|
||||
|
||||
avg_loss = total_loss / total_samples
|
||||
# compute the average loss across all processes
|
||||
|
||||
stats = {f"{tracker_context}-{loss_type}": avg_loss}
|
||||
trainer.print(stats)
|
||||
avg_loss = pad_gather_reduce(trainer, total_loss, method="mean")
|
||||
stats = {f"{tracker_folder}/{loss_type}-loss": avg_loss}
|
||||
|
||||
if exists(tracker):
|
||||
tracker.log(stats, step=trainer.step.item() + 1)
|
||||
# print and log results on main process
|
||||
tracker.log(stats, step=trainer.step.item() + 1)
|
||||
|
||||
return avg_loss
|
||||
|
||||
|
||||
def report_cosine_sims(
|
||||
trainer: DiffusionPriorTrainer,
|
||||
dataloader: DataLoader,
|
||||
text_conditioned: bool,
|
||||
tracker: BaseTracker,
|
||||
tracker_context: str = "validation",
|
||||
tracker: Tracker,
|
||||
split: str,
|
||||
timesteps: int,
|
||||
tracker_folder: str,
|
||||
):
|
||||
trainer.eval()
|
||||
if trainer.is_main_process():
|
||||
click.secho("Measuring Cosine-Similarity", fg="green", blink=True)
|
||||
if trainer.accelerator.is_main_process:
|
||||
click.secho(
|
||||
f"Measuring Cosine-Similarity on {split} split with {timesteps} timesteps",
|
||||
fg="green",
|
||||
blink=True,
|
||||
)
|
||||
|
||||
for test_image_embeddings, text_data in dataloader:
|
||||
test_image_embeddings = test_image_embeddings.to(trainer.device)
|
||||
@@ -126,10 +272,8 @@ def report_cosine_sims(
|
||||
|
||||
# we are text conditioned, we produce an embedding from the tokenized text
|
||||
if text_conditioned:
|
||||
text_embedding, text_encodings, text_mask = trainer.embed_text(text_data)
|
||||
text_cond = dict(
|
||||
text_embed=text_embedding, text_encodings=text_encodings, mask=text_mask
|
||||
)
|
||||
text_embedding, text_encodings = trainer.embed_text(text_data)
|
||||
text_cond = dict(text_embed=text_embedding, text_encodings=text_encodings)
|
||||
else:
|
||||
text_embedding = text_data
|
||||
text_cond = dict(text_embed=text_embedding)
|
||||
@@ -146,15 +290,11 @@ def report_cosine_sims(
|
||||
|
||||
if text_conditioned:
|
||||
text_encodings_shuffled = text_encodings[rolled_idx]
|
||||
text_mask_shuffled = text_mask[rolled_idx]
|
||||
else:
|
||||
text_encodings_shuffled = None
|
||||
text_mask_shuffled = None
|
||||
|
||||
text_cond_shuffled = dict(
|
||||
text_embed=text_embed_shuffled,
|
||||
text_encodings=text_encodings_shuffled,
|
||||
mask=text_mask_shuffled,
|
||||
text_embed=text_embed_shuffled, text_encodings=text_encodings_shuffled
|
||||
)
|
||||
|
||||
# prepare the text embedding
|
||||
@@ -167,7 +307,9 @@ def report_cosine_sims(
|
||||
|
||||
# predict on the unshuffled text embeddings
|
||||
predicted_image_embeddings = trainer.p_sample_loop(
|
||||
test_image_embeddings.shape, text_cond
|
||||
test_image_embeddings.shape,
|
||||
text_cond,
|
||||
timesteps=timesteps,
|
||||
)
|
||||
|
||||
predicted_image_embeddings = (
|
||||
@@ -177,7 +319,9 @@ def report_cosine_sims(
|
||||
|
||||
# predict on the shuffled embeddings
|
||||
predicted_unrelated_embeddings = trainer.p_sample_loop(
|
||||
test_image_embeddings.shape, text_cond_shuffled
|
||||
test_image_embeddings.shape,
|
||||
text_cond_shuffled,
|
||||
timesteps=timesteps,
|
||||
)
|
||||
|
||||
predicted_unrelated_embeddings = (
|
||||
@@ -186,32 +330,97 @@ def report_cosine_sims(
|
||||
)
|
||||
|
||||
# calculate similarities
|
||||
original_similarity = cos(text_embed, test_image_embeddings).cpu().numpy()
|
||||
predicted_similarity = cos(text_embed, predicted_image_embeddings).cpu().numpy()
|
||||
unrelated_similarity = (
|
||||
cos(text_embed, predicted_unrelated_embeddings).cpu().numpy()
|
||||
orig_sim = pad_gather_reduce(
|
||||
trainer, cos(text_embed, test_image_embeddings), method="mean"
|
||||
)
|
||||
predicted_img_similarity = (
|
||||
cos(test_image_embeddings, predicted_image_embeddings).cpu().numpy()
|
||||
pred_sim = pad_gather_reduce(
|
||||
trainer, cos(text_embed, predicted_image_embeddings), method="mean"
|
||||
)
|
||||
unrel_sim = pad_gather_reduce(
|
||||
trainer, cos(text_embed, predicted_unrelated_embeddings), method="mean"
|
||||
)
|
||||
pred_img_sim = pad_gather_reduce(
|
||||
trainer,
|
||||
cos(test_image_embeddings, predicted_image_embeddings),
|
||||
method="mean",
|
||||
)
|
||||
|
||||
stats = {
|
||||
f"{tracker_context}/baseline similarity": np.mean(original_similarity),
|
||||
f"{tracker_context}/similarity with text": np.mean(predicted_similarity),
|
||||
f"{tracker_context}/similarity with original image": np.mean(
|
||||
predicted_img_similarity
|
||||
),
|
||||
f"{tracker_context}/similarity with unrelated caption": np.mean(unrelated_similarity),
|
||||
f"{tracker_context}/difference from baseline similarity": np.mean(
|
||||
predicted_similarity - original_similarity
|
||||
),
|
||||
f"{tracker_folder}/baseline similarity [steps={timesteps}]": orig_sim,
|
||||
f"{tracker_folder}/similarity with text [steps={timesteps}]": pred_sim,
|
||||
f"{tracker_folder}/similarity with original image [steps={timesteps}]": pred_img_sim,
|
||||
f"{tracker_folder}/similarity with unrelated caption [steps={timesteps}]": unrel_sim,
|
||||
f"{tracker_folder}/difference from baseline similarity [steps={timesteps}]": pred_sim
|
||||
- orig_sim,
|
||||
}
|
||||
|
||||
for k, v in stats.items():
|
||||
trainer.print(f"{tracker_context}/{k}: {v}")
|
||||
tracker.log(stats, step=trainer.step.item() + 1)
|
||||
|
||||
if exists(tracker):
|
||||
tracker.log(stats, step=trainer.step.item() + 1)
|
||||
|
||||
def eval_model(
|
||||
trainer: DiffusionPriorTrainer,
|
||||
dataloader: DataLoader,
|
||||
text_conditioned: bool,
|
||||
split: str,
|
||||
tracker: Tracker,
|
||||
use_ema: bool,
|
||||
report_cosine: bool,
|
||||
report_loss: bool,
|
||||
timesteps: List[int],
|
||||
loss_type: str = None,
|
||||
):
|
||||
"""
|
||||
Run evaluation on a model and track metrics
|
||||
|
||||
returns: loss if requested
|
||||
"""
|
||||
trainer.eval()
|
||||
|
||||
use_ema = "ema" if use_ema else "online"
|
||||
tracker_folder = f"metrics/{use_ema}-{split}"
|
||||
|
||||
# detemine if valid timesteps are passed
|
||||
|
||||
min_timesteps = trainer.accelerator.unwrap_model(
|
||||
trainer.diffusion_prior
|
||||
).sample_timesteps
|
||||
max_timesteps = trainer.accelerator.unwrap_model(
|
||||
trainer.diffusion_prior
|
||||
).noise_scheduler.num_timesteps
|
||||
|
||||
assert all_between(
|
||||
timesteps, lower_bound=min_timesteps, upper_bound=max_timesteps
|
||||
), f"all timesteps values must be between {min_timesteps} and {max_timesteps}: got {timesteps}"
|
||||
|
||||
# measure cosine metrics across various eta and timesteps
|
||||
|
||||
if report_cosine:
|
||||
for timestep in timesteps:
|
||||
report_cosine_sims(
|
||||
trainer,
|
||||
dataloader=dataloader,
|
||||
text_conditioned=text_conditioned,
|
||||
tracker=tracker,
|
||||
split=split,
|
||||
timesteps=timestep,
|
||||
tracker_folder=tracker_folder,
|
||||
)
|
||||
|
||||
# measure loss on a seperate split of data
|
||||
|
||||
if report_loss:
|
||||
loss = report_validation_loss(
|
||||
trainer=trainer,
|
||||
dataloader=dataloader,
|
||||
text_conditioned=text_conditioned,
|
||||
use_ema=use_ema,
|
||||
tracker=tracker,
|
||||
split=split,
|
||||
tracker_folder=tracker_folder,
|
||||
loss_type=loss_type,
|
||||
)
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
# training script
|
||||
@@ -219,182 +428,327 @@ def report_cosine_sims(
|
||||
|
||||
def train(
|
||||
trainer: DiffusionPriorTrainer,
|
||||
tracker: Tracker,
|
||||
train_loader: DataLoader,
|
||||
eval_loader: DataLoader,
|
||||
test_loader: DataLoader,
|
||||
config: DiffusionPriorTrainConfig,
|
||||
):
|
||||
# distributed tracking with wandb
|
||||
if trainer.accelerator.num_processes > 1:
|
||||
os.environ["WANDB_START_METHOD"] = "thread"
|
||||
# init timers
|
||||
save_timer = Timer() # when to save
|
||||
samples_timer = Timer() # samples/sec
|
||||
validation_profiler = Timer() # how long is validation taking
|
||||
validation_countdown = Timer() # when to perform evalutation
|
||||
|
||||
tracker = wandb.init(
|
||||
name=f"RANK:{trainer.device}",
|
||||
entity=config.tracker.wandb_entity,
|
||||
project=config.tracker.wandb_project,
|
||||
config=config.dict(),
|
||||
group=GROUP,
|
||||
)
|
||||
# keep track of best validation loss
|
||||
|
||||
# sync after tracker init
|
||||
trainer.wait_for_everyone()
|
||||
|
||||
# init a timer
|
||||
timer = Timer()
|
||||
best_validation_loss = config.train.best_validation_loss
|
||||
samples_seen = config.train.num_samples_seen
|
||||
|
||||
# do training
|
||||
for img, txt in train_loader:
|
||||
trainer.train()
|
||||
current_step = trainer.step.item() + 1
|
||||
|
||||
# place data on device
|
||||
img = img.to(trainer.device)
|
||||
txt = txt.to(trainer.device)
|
||||
start_epoch = config.train.current_epoch
|
||||
|
||||
# pass to model
|
||||
loss = trainer(text=txt, image_embed=img)
|
||||
for epoch in range(start_epoch, config.train.epochs):
|
||||
# if we finished out an old epoch, reset the distribution to be a full epoch
|
||||
tracker.log({"tracking/epoch": epoch}, step=trainer.step.item())
|
||||
|
||||
# display & log loss (will only print from main process)
|
||||
trainer.print(f"Step {current_step}: Loss {loss}")
|
||||
if train_loader.dataset.get_start() > 0 and epoch == start_epoch+1:
|
||||
if trainer.accelerator.is_main_process:
|
||||
click.secho(f"Finished resumed epoch...resetting dataloader.")
|
||||
train_loader.dataset.set_start(0)
|
||||
|
||||
# perform backprop & apply EMA updates
|
||||
trainer.update()
|
||||
for img, txt in train_loader:
|
||||
# setup things every step
|
||||
|
||||
# track samples/sec/rank
|
||||
samples_per_sec = img.shape[0] / timer.elapsed()
|
||||
trainer.train()
|
||||
current_step = trainer.step.item()
|
||||
samples_timer.reset()
|
||||
|
||||
# samples seen
|
||||
samples_seen = (
|
||||
config.data.batch_size * trainer.accelerator.num_processes * current_step
|
||||
)
|
||||
# place data on device
|
||||
|
||||
# ema decay
|
||||
ema_decay = trainer.ema_diffusion_prior.get_current_decay()
|
||||
img = img.to(trainer.device)
|
||||
txt = txt.to(trainer.device)
|
||||
|
||||
# Log on all processes for debugging
|
||||
tracker.log(
|
||||
{
|
||||
"tracking/samples-sec": samples_per_sec,
|
||||
"tracking/samples-seen": samples_seen,
|
||||
"tracking/ema-decay": ema_decay,
|
||||
"metrics/training-loss": loss,
|
||||
},
|
||||
step=current_step,
|
||||
)
|
||||
# pass to model
|
||||
|
||||
# Metric Tracking & Checkpointing (outside of timer's scope)
|
||||
if current_step % EVAL_EVERY == 0:
|
||||
eval_model(
|
||||
trainer=trainer,
|
||||
dataloader=eval_loader,
|
||||
text_conditioned=config.prior.condition_on_text_encodings,
|
||||
loss_type=config.prior.loss_type,
|
||||
tracker_context="metrics/online-model-validation",
|
||||
tracker=tracker,
|
||||
use_ema=False,
|
||||
loss = trainer(text=txt, image_embed=img)
|
||||
|
||||
# perform backprop & apply EMA updates
|
||||
|
||||
trainer.update()
|
||||
|
||||
# gather info about training step
|
||||
|
||||
all_loss = pad_gather_reduce(trainer, loss, method="mean")
|
||||
num_samples = pad_gather_reduce(trainer, len(txt), method="sum")
|
||||
samples_per_sec = num_samples / samples_timer.elapsed()
|
||||
samples_seen += num_samples
|
||||
ema_decay = trainer.ema_diffusion_prior.get_current_decay()
|
||||
|
||||
# log
|
||||
|
||||
tracker.log(
|
||||
{
|
||||
"tracking/samples-sec": samples_per_sec,
|
||||
"tracking/samples-seen": samples_seen,
|
||||
"tracking/ema-decay": ema_decay,
|
||||
f"tracking/training-{config.prior.loss_type}": all_loss,
|
||||
},
|
||||
step=current_step,
|
||||
)
|
||||
|
||||
eval_model(
|
||||
trainer=trainer,
|
||||
dataloader=eval_loader,
|
||||
text_conditioned=config.prior.condition_on_text_encodings,
|
||||
loss_type=config.prior.loss_type,
|
||||
tracker_context="metrics/ema-model-validation",
|
||||
tracker=tracker,
|
||||
use_ema=True,
|
||||
# Metric Tracking @ Timed Intervals
|
||||
|
||||
eval_delta = pad_gather_reduce(
|
||||
trainer, validation_countdown.elapsed(), method="min"
|
||||
)
|
||||
|
||||
report_cosine_sims(
|
||||
trainer=trainer,
|
||||
dataloader=eval_loader,
|
||||
text_conditioned=config.prior.condition_on_text_encodings,
|
||||
tracker=tracker,
|
||||
tracker_context="metrics",
|
||||
)
|
||||
if eval_delta != None and eval_delta > config.data.eval_every_seconds:
|
||||
# begin timing how long this takes
|
||||
|
||||
if current_step % config.train.save_every == 0:
|
||||
trainer.save(f"{config.tracker.data_path}/chkpt_step_{current_step}.pth")
|
||||
validation_profiler.reset()
|
||||
|
||||
# reset timer for next round
|
||||
timer.reset()
|
||||
# package kwargs for evaluation
|
||||
|
||||
eval_kwargs = {
|
||||
"trainer": trainer,
|
||||
"tracker": tracker,
|
||||
"text_conditioned": config.prior.condition_on_text_encodings,
|
||||
"timesteps": config.train.eval_timesteps,
|
||||
}
|
||||
|
||||
# ONLINE MODEL : COSINE : LOSS : VALIDATION SPLIT
|
||||
|
||||
eval_model(
|
||||
dataloader=eval_loader,
|
||||
loss_type=config.prior.loss_type,
|
||||
split="validation",
|
||||
use_ema=False,
|
||||
report_cosine=False,
|
||||
report_loss=True,
|
||||
**eval_kwargs,
|
||||
)
|
||||
|
||||
# EMA MODEL : COSINE : LOSS : VALIDATION DATA
|
||||
|
||||
ema_val_loss = eval_model(
|
||||
dataloader=eval_loader,
|
||||
loss_type=config.prior.loss_type,
|
||||
split="validation",
|
||||
use_ema=True,
|
||||
report_cosine=True,
|
||||
report_loss=True,
|
||||
**eval_kwargs,
|
||||
)
|
||||
|
||||
tracker.log(
|
||||
{
|
||||
"tracking/validation length (minutes)": validation_profiler.elapsed()
|
||||
/ 60
|
||||
}
|
||||
)
|
||||
|
||||
# check if the ema validation is the lowest seen yet
|
||||
|
||||
if ema_val_loss < best_validation_loss:
|
||||
best_validation_loss = ema_val_loss
|
||||
|
||||
# go save the model as best
|
||||
|
||||
save_trainer(
|
||||
trainer=trainer,
|
||||
tracker=tracker,
|
||||
is_best=True,
|
||||
is_latest=False,
|
||||
samples_seen=samples_seen,
|
||||
epoch=epoch,
|
||||
best_validation_loss=best_validation_loss,
|
||||
)
|
||||
|
||||
# reset timer for validaiton
|
||||
|
||||
validation_countdown.reset()
|
||||
|
||||
elif eval_delta is None:
|
||||
click.secho(
|
||||
f"Error occured reading the eval time on rank: {trainer.device}",
|
||||
fg="yellow",
|
||||
)
|
||||
|
||||
# save as latest model on schedule
|
||||
|
||||
save_delta = pad_gather_reduce(trainer, save_timer.elapsed(), method="min")
|
||||
|
||||
if save_delta != None and save_delta >= config.train.save_every_seconds:
|
||||
save_trainer(
|
||||
trainer=trainer,
|
||||
tracker=tracker,
|
||||
is_best=False,
|
||||
is_latest=True,
|
||||
samples_seen=samples_seen,
|
||||
epoch=epoch,
|
||||
best_validation_loss=best_validation_loss,
|
||||
)
|
||||
|
||||
save_timer.reset()
|
||||
|
||||
elif save_delta is None:
|
||||
click.secho(
|
||||
f"Error occured reading the save time on rank: {trainer.device}",
|
||||
fg="yellow",
|
||||
)
|
||||
|
||||
# evaluate on test data
|
||||
|
||||
eval_model(
|
||||
if trainer.accelerator.is_main_process:
|
||||
click.secho(f"Starting Test", fg="red")
|
||||
|
||||
# save one last time as latest before beginning validation
|
||||
|
||||
save_trainer(
|
||||
tracker=tracker,
|
||||
trainer=trainer,
|
||||
is_best=False,
|
||||
is_latest=True,
|
||||
samples_seen=samples_seen,
|
||||
epoch=epoch,
|
||||
best_validation_loss=best_validation_loss,
|
||||
)
|
||||
|
||||
test_loss = eval_model(
|
||||
trainer=trainer,
|
||||
dataloader=test_loader,
|
||||
text_conditioned=config.prior.condition_on_text_encodings,
|
||||
loss_type=config.prior.loss_type,
|
||||
tracker_context="test",
|
||||
split="test",
|
||||
tracker=tracker,
|
||||
use_ema=True,
|
||||
report_cosine=False,
|
||||
report_loss=True,
|
||||
timesteps=config.train.eval_timesteps,
|
||||
loss_type=config.prior.loss_type,
|
||||
)
|
||||
|
||||
report_cosine_sims(
|
||||
trainer,
|
||||
test_loader,
|
||||
config.prior.condition_on_text_encodings,
|
||||
tracker,
|
||||
tracker_context="test",
|
||||
)
|
||||
if test_loss < best_validation_loss:
|
||||
best_validation_loss = test_loss
|
||||
|
||||
# go save the model as best
|
||||
|
||||
save_trainer(
|
||||
trainer=trainer,
|
||||
tracker=tracker,
|
||||
is_best=True,
|
||||
is_latest=False,
|
||||
samples_seen=samples_seen,
|
||||
epoch=epoch,
|
||||
best_validation_loss=test_loss,
|
||||
)
|
||||
|
||||
|
||||
def initialize_training(config, accelerator=None):
|
||||
def initialize_training(config_file, accelerator):
|
||||
"""
|
||||
Parse the configuration file, and prepare everything necessary for training
|
||||
"""
|
||||
# load the configuration file
|
||||
if accelerator.is_main_process:
|
||||
click.secho(f"Loading configuration from {config_file}", fg="green")
|
||||
|
||||
config = TrainDiffusionPriorConfig.from_json_path(config_file)
|
||||
|
||||
# seed
|
||||
|
||||
set_seed(config.train.random_seed)
|
||||
|
||||
# get a device
|
||||
|
||||
if accelerator:
|
||||
device = accelerator.device
|
||||
click.secho(f"Accelerating on: {device}", fg="yellow")
|
||||
else:
|
||||
if torch.cuda.is_available():
|
||||
click.secho("GPU detected, defaulting to cuda:0", fg="yellow")
|
||||
device = "cuda:0"
|
||||
else:
|
||||
click.secho("No GPU detected...using cpu", fg="yellow")
|
||||
device = "cpu"
|
||||
device = accelerator.device
|
||||
|
||||
# make the trainer (will automatically distribute if possible & configured)
|
||||
|
||||
trainer = make_model(config.prior, config.train, device, accelerator).to(device)
|
||||
trainer: DiffusionPriorTrainer = make_model(
|
||||
config.prior, config.train, device, accelerator
|
||||
).to(device)
|
||||
|
||||
# create a tracker
|
||||
|
||||
tracker = create_tracker(
|
||||
accelerator, config, config_file, dummy=accelerator.process_index != 0
|
||||
)
|
||||
|
||||
# reload from chcekpoint
|
||||
|
||||
if config.load.resume == True:
|
||||
click.secho(f"Loading checkpoint: {config.load.source}", fg="cyan")
|
||||
trainer.load(config.load.source)
|
||||
if tracker.can_recall:
|
||||
current_epoch, best_validation_loss, samples_seen = recall_trainer(
|
||||
tracker=tracker, trainer=trainer
|
||||
)
|
||||
|
||||
# display best values
|
||||
if trainer.accelerator.is_main_process:
|
||||
click.secho(f"Current Epoch: {current_epoch} | Best Val Loss: {best_validation_loss} | Samples Seen: {samples_seen}", fg="yellow")
|
||||
|
||||
# update config to reflect recalled values
|
||||
config.train.num_samples_seen = samples_seen
|
||||
config.train.current_epoch = current_epoch
|
||||
config.train.best_validation_loss = best_validation_loss
|
||||
|
||||
# fetch and prepare data
|
||||
|
||||
if trainer.is_main_process():
|
||||
click.secho("Grabbing data from source", fg="blue", blink=True)
|
||||
if trainer.accelerator.is_main_process:
|
||||
click.secho("Grabbing data...", fg="blue", blink=True)
|
||||
|
||||
trainer.accelerator.wait_for_everyone()
|
||||
img_reader = get_reader(
|
||||
text_conditioned=trainer.text_conditioned,
|
||||
img_url=config.data.image_url,
|
||||
meta_url=config.data.meta_url,
|
||||
)
|
||||
|
||||
# calculate start point within epoch
|
||||
|
||||
trainer.accelerator.wait_for_everyone()
|
||||
|
||||
train_loader, eval_loader, test_loader = make_splits(
|
||||
text_conditioned=trainer.text_conditioned,
|
||||
batch_size=config.data.batch_size,
|
||||
num_data_points=NUM_DATA_POINTS,
|
||||
num_data_points=config.data.num_data_points,
|
||||
train_split=config.data.splits.train,
|
||||
eval_split=config.data.splits.val,
|
||||
image_reader=img_reader,
|
||||
rank=accelerator.state.process_index if exists(accelerator) else 0,
|
||||
world_size=accelerator.state.num_processes if exists(accelerator) else 1,
|
||||
start=START,
|
||||
rank=accelerator.state.process_index,
|
||||
world_size=accelerator.state.num_processes,
|
||||
start=0,
|
||||
)
|
||||
|
||||
# wait for everyone to load data before continuing
|
||||
trainer.wait_for_everyone()
|
||||
# update the start point to finish out the epoch on a resumed run
|
||||
|
||||
if tracker.can_recall:
|
||||
samples_seen = config.train.num_samples_seen
|
||||
length = (
|
||||
config.data.num_data_points
|
||||
if samples_seen <= img_reader.count
|
||||
else img_reader.count
|
||||
)
|
||||
scaled_samples = length * config.train.current_epoch
|
||||
start_point = (
|
||||
scaled_samples - samples_seen if scaled_samples > samples_seen else samples_seen
|
||||
)
|
||||
|
||||
if trainer.accelerator.is_main_process:
|
||||
click.secho(f"Resuming at sample: {start_point}", fg="yellow")
|
||||
|
||||
train_loader.dataset.set_start(start_point)
|
||||
|
||||
# start training
|
||||
|
||||
if trainer.accelerator.is_main_process:
|
||||
click.secho(
|
||||
f"Beginning Prior Training : Distributed={accelerator.state.distributed_type != accelerate_dataclasses.DistributedType.NO}",
|
||||
fg="yellow",
|
||||
)
|
||||
|
||||
train(
|
||||
trainer=trainer,
|
||||
tracker=tracker,
|
||||
train_loader=train_loader,
|
||||
eval_loader=eval_loader,
|
||||
test_loader=test_loader,
|
||||
@@ -403,23 +757,13 @@ def initialize_training(config, accelerator=None):
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option("--hfa", default=True)
|
||||
@click.option("--config_path", default="configs/prior.json")
|
||||
def main(hfa, config_path):
|
||||
# start HFA if requested
|
||||
if hfa:
|
||||
accelerator = Accelerator()
|
||||
else:
|
||||
accelerator = None
|
||||
@click.option("--config_file", default="configs/train_prior_config.example.json")
|
||||
def main(config_file):
|
||||
# start HFA
|
||||
accelerator = Accelerator()
|
||||
|
||||
# load the configuration file on main process
|
||||
if not exists(accelerator) or accelerator.is_main_process:
|
||||
click.secho(f"Loading configuration from {config_path}", fg="green")
|
||||
|
||||
config = TrainDiffusionPriorConfig.from_json_path(config_path)
|
||||
|
||||
# send config to get processed
|
||||
initialize_training(config, accelerator)
|
||||
# setup training
|
||||
initialize_training(config_file, accelerator)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user