Mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2026-02-14 21:34:19 +01:00)
Compare commits
107 Commits
`b9a908ff75`, `e1fe3089df`, `6d477d7654`, `531fe4b62f`, `ec5a77fc55`, `fac63c61bc`, `3d23ba4aa5`, `282c35930f`, `27b0f7ca0d`, `7b0edf9e42`,
`a922a539de`, `8f2466f1cd`, `908ab83799`, `46a2558d53`, `86109646e3`, `6a11b9678b`, `b90364695d`, `868c001199`, `032e83b0e0`, `2e85e736f3`,
`f5760bdb92`, `c453f468b1`, `98f0c17759`, `a5b9fd6ca8`, `4b994601ae`, `fddf66e91e`, `c8422ffd5d`, `2aadc23c7c`, `c098f57e09`, `0021535c26`,
`56883910fb`, `893f270012`, `f545ce18f4`, `fc7abf624d`, `67f0740777`, `138079ca83`, `f5a906f5d3`, `0215237fc6`, `461b91c5c1`, `58892135d9`,
`e37072a48c`, `41ca896413`, `fe19b508ca`, `6651eafa93`, `e6bb75e5ab`, `b4c3e5b854`, `b7f9607258`, `2219348a6e`, `9eea9b9862`, `5d958713c0`,
`0f31980362`, `bee5bf3815`, `350a3d6045`, `1a81670718`, `934c9728dc`, `ce4b0107c1`, `64c2f9c4eb`, `22cc613278`, `83517849e5`, `708809ed6c`,
`9cc475f6e7`, `ffd342e9d0`, `f8bfd3493a`, `9025345e29`, `8cc278447e`, `38cd62010c`, `1cc288af39`, `a851168633`, `1ffeecd0ca`, `3df899f7a4`,
`09534119a1`, `6f8b90d4d7`, `b588286288`, `b693e0be03`, `a0bed30a84`, `387c5bf774`, `a13d2d89c5`, `44d4b1bba9`, `f12a7589c5`, `b8af2210df`,
`f4fe6c570d`, `645e207441`, `00743b3a0b`, `01589aff6a`, `7ecfd76cc0`, `6161b61c55`, `1ed0f9d80b`, `f326a95e26`, `d7a0a2ce4b`, `f23fab7ef7`,
`857b9fbf1e`, `8864fd0aa7`, `72bf159331`, `e5e47cfecb`, `fa533962bd`, `276abf337b`, `ae42d03006`, `4d346e98d9`, `2b1fd1ad2e`, `82a2ef37d9`,
`5c397c9d66`, `0f4edff214`, `501a8c7c46`, `4e49373fc5`, `49de72040c`, `271a376eaf`, `e527002472`
README.md (134 changed lines)
@@ -12,7 +12,7 @@ This model is SOTA for text-to-image for now.
 
 Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication with the <a href="https://laion.ai/">LAION</a> community | <a href="https://www.youtube.com/watch?v=AIOE1l1W0Tw">Yannic Interview</a>
 
-There was enough interest for a <a href="https://github.com/lucidrains/dalle2-jax">Jax version</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
+As of 5/23/22, it is no longer SOTA. SOTA will be <a href="https://github.com/lucidrains/imagen-pytorch">here</a>. Jax versions as well as text-to-video project will be shifted towards the Imagen architecture, as it is way simpler.
 
 ## Status
@@ -20,10 +20,36 @@ There was enough interest for a <a href="https://github.com/lucidrains/dalle2-ja
 
 - Decoder is now verified working for unconditional generation on my experimental setup for Oxford flowers. 2 researchers have also confirmed Decoder is working for them.
 
-<img src="./samples/oxford.png" width="600px" />
+<img src="./samples/oxford.png" width="450px" />
 
 *ongoing at 21k steps*
 
+- <a href="https://twitter.com/Buntworthy/status/1529475416775434240?t=0GEge3Kr9I36cjcUVCQUTg">Justin Pinkney</a> successfully trained the diffusion prior in the repository for his CLIP to Stylegan2 text-to-image application
+
+- <a href="https://github.com/rom1504">Romain</a> has scaled up training to 800 GPUs with the available scripts without any issues
+
+## Pre-Trained Models
+
+- LAION is training prior models. Checkpoints are available on <a href="https://huggingface.co/zenglishuci/conditioned-prior">🤗huggingface</a> and the training statistics are available on <a href="https://wandb.ai/nousr_laion/conditioned-prior/reports/LAION-DALLE2-PyTorch-Prior--VmlldzoyMDI2OTIx">🐝WANDB</a>.
+- Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/jkrtg0so?workspace=user-veldrovive">In-progress test run</a> 🚧
+- Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/3d5rytsa?workspace=">Another test run with sparse attention</a>
+- DALL-E 2 🚧 - <a href="https://github.com/LAION-AI/dalle2-laion">DALL-E 2 Laion repository</a>
+
+## Appreciation
+
+This library would not have gotten to this working state without the help of
+
+- <a href="https://github.com/nousr">Zion</a> for the distributed training code for the diffusion prior
+- <a href="https://github.com/Veldrovive">Aidan</a> for the distributed training code for the decoder as well as the dataloaders
+- <a href="https://github.com/krish240574">Kumar</a> for working on the initial diffusion training script
+- <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management
+- <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying critical bugs
+- <a href="https://github.com/crowsonkb">Katherine</a> for her advice
+- <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship
+- <a href="https://huggingface.co">🤗 Huggingface</a> and in particular <a href="https://github.com/sgugger">Sylvain</a> for the <a href="https://github.com/huggingface/accelerate">Accelerate</a> library
+
+... and many others. Thank you! 🙏
 
 ## Install
 
 ```bash
@@ -344,7 +370,8 @@ unet1 = Unet(
     image_embed_dim = 512,
     cond_dim = 128,
     channels = 3,
-    dim_mults=(1, 2, 4, 8)
+    dim_mults=(1, 2, 4, 8),
+    cond_on_text_encodings = True # set to True for any unets that need to be conditioned on text encodings
 ).cuda()
 
 unet2 = Unet(
@@ -361,8 +388,7 @@ decoder = Decoder(
     clip = clip,
     timesteps = 100,
     image_cond_drop_prob = 0.1,
-    text_cond_drop_prob = 0.5,
-    condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
+    text_cond_drop_prob = 0.5
 ).cuda()
 
 for unet_number in (1, 2):
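The `image_cond_drop_prob` / `text_cond_drop_prob` arguments in the hunk above implement classifier-free guidance by randomly replacing the conditioning with a null value during training. A minimal sketch of that dropout (a hypothetical helper, not the repository's code):

```python
import random

def maybe_drop_condition(cond, drop_prob, rng, null_cond=None):
    """With probability drop_prob, replace the conditioning with a null value,
    so the model also learns an unconditional score for classifier-free guidance."""
    return null_cond if rng.random() < drop_prob else cond

# Rough check: with drop_prob = 0.5, about half the batches train unconditionally.
rng = random.Random(0)
dropped = sum(
    maybe_drop_condition("text_emb", 0.5, rng) is None for _ in range(10_000)
)
```

At sampling time the two predictions (conditioned and null-conditioned) are then blended, which is why both code paths must be trained.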
@@ -936,7 +962,7 @@ from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embed
 
 # Create a dataloader directly.
 dataloader = create_image_embedding_dataloader(
-    tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses braket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
+    tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses bracket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
     embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
     num_workers=4,
     batch_size=32,
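The `{0000..9999}` range in `tar_url` above is shell-style brace expansion, which webdataset-style loaders typically perform via the `braceexpand` package. A pure-Python sketch of what that expansion does for a single numeric range:

```python
import re

def expand_brace_range(pattern):
    """Expand one {start..end} numeric range, preserving zero padding."""
    match = re.search(r"\{(\d+)\.\.(\d+)\}", pattern)
    if match is None:
        return [pattern]
    start, end = match.group(1), match.group(2)
    width = len(start)  # "0000" pads every index to four digits
    return [
        pattern[:match.start()] + str(i).zfill(width) + pattern[match.end():]
        for i in range(int(start), int(end) + 1)
    ]

urls = expand_brace_range("/data/webdataset/{0000..0002}.tar")
# urls == ["/data/webdataset/0000.tar", "/data/webdataset/0001.tar", "/data/webdataset/0002.tar"]
```

So `{0000..9999}.tar` names 10,000 shards without listing them individually.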
@@ -993,33 +1019,6 @@ The most significant parameters for the script are as follows:
 
 - `clip`, default = `None` # Signals the prior to use pre-computed embeddings
 
-#### Loading and Saving the DiffusionPrior model
-
-Two methods are provided, load_diffusion_model and save_diffusion_model, the names being self-explanatory.
-
-```python
-from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model
-```
-
-##### Loading
-
-load_diffusion_model(dprior_path, device)
-dprior_path : path to saved model(.pth)
-device : the cuda device you're running on
-
-##### Saving
-
-save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim)
-save_path : path to save at
-model : object of Diffusion_Prior
-optimizer : optimizer object - see train_diffusion_prior.py for how to create one.
-e.g: optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
-scaler : a GradScaler object.
-e.g: scaler = GradScaler(enabled=amp)
-config : config object created in train_diffusion_prior.py - see file for example.
-image_embed_dim - the dimension of the image_embedding
-e.g: 768
-
 ## CLI (wip)
 
 ```bash
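The removed `load_diffusion_model` / `save_diffusion_model` helpers bundled model, optimizer, scaler, and config state into one file; that responsibility moved to the trainer classes. The underlying pattern, sketched with the stdlib only (names and fields hypothetical, not the repository's implementation):

```python
import io
import pickle

CHECKPOINT_VERSION = 1  # bumped on breaking changes so stale checkpoints are rejected

def save_checkpoint(fileobj, model_state, optimizer_state, step):
    """Bundle everything needed to resume training into a single object."""
    pickle.dump({
        "version": CHECKPOINT_VERSION,
        "model": model_state,
        "optimizer": optimizer_state,
        "step": step,
    }, fileobj)

def load_checkpoint(fileobj):
    """Restore a checkpoint, refusing incompatible versions."""
    ckpt = pickle.load(fileobj)
    if ckpt["version"] != CHECKPOINT_VERSION:
        raise ValueError(f"incompatible checkpoint version: {ckpt['version']}")
    return ckpt

buf = io.BytesIO()
save_checkpoint(buf, {"w": [0.1, 0.2]}, {"lr": 1e-4}, step=21000)
buf.seek(0)
restored = load_checkpoint(buf)
```

Saving optimizer and scaler state alongside the weights is what makes a resumed run bitwise-continue rather than restart its schedule.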
@@ -1034,18 +1033,6 @@ Once built, images will be saved to the same directory the command is invoked
 
 <a href="https://github.com/lucidrains/stylegan2-pytorch">template</a>
 
-## Appreciation
-
-This library would not have gotten to this working state without the help of
-
-- <a href="https://github.com/nousr">Zion</a> and <a href="https://github.com/krish240574">Kumar</a> for the diffusion training script
-- <a href="https://github.com/Veldrovive">Aidan</a> for the decoder training script and dataloaders
-- <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management
-- <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying of critical bugs
-- <a href="https://github.com/crowsonkb">Katherine</a> for her advice
-
-... and many others. Thank you! 🙏
-
 ## Todo
 
 - [x] finish off gaussian diffusion class for latent embedding - allow for prediction of epsilon
@@ -1077,22 +1064,17 @@ This library would not have gotten to this working state without the help of
 - [x] cross embed layers for downsampling, as an option
 - [x] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
 - [x] use pydantic for config drive training
+- [x] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
+- [x] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
+- [x] allow for creation of diffusion prior model off pydantic config classes - consider the same for tracker configs
+- [x] bring in skip-layer excitations (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training (doesnt work well)
+- [x] test out grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697 (keeping, seems to be fine)
+- [x] allow for unet to be able to condition non-cross attention style as well
 - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
-- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
-- [ ] train on a toy task, offer in colab
-- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
-- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
+- [ ] speed up inference, read up on papers (ddim or diffusion-gan, etc)
 - [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
-- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
 - [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
-- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
-- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
-- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
-- [ ] decoder needs one day worth of refactor for tech debt
-- [ ] allow for unet to be able to condition non-cross attention style as well
-- [ ] for all model classes with hyperparameters that changes the network architecture, make it requirement that they must expose a config property, and write a simple function that asserts that it restores the object correctly
-- [ ] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
-- [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89
+- [ ] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
 
 ## Citations
 
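The exponential-moving-average item checked off in the Todo list above comes down to keeping a shadow copy of the weights, nudged toward the live parameters after every optimizer step and saved alongside them. A minimal sketch (a hypothetical helper, not the repository's EMA implementation):

```python
def ema_update(ema_params, params, decay=0.999):
    """Move the shadow parameters a small step toward the current parameters."""
    return [decay * e + (1.0 - decay) * p for e, p in zip(ema_params, params)]

# Toy example with decay = 0.5 so the convergence is visible in three steps.
ema = [0.0, 0.0]
for _ in range(3):
    ema = ema_update(ema, [1.0, 2.0], decay=0.5)
# ema == [0.875, 1.75]
```

Because the shadow weights are what you sample from, they (and the step counter that warms the decay up) must live in the checkpoint too, which is what the Todo item refers to.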
@@ -1132,14 +1114,6 @@ This library would not have gotten to this working state without the help of
 }
 ```
 
-```bibtex
-@inproceedings{Tu2022MaxViTMV,
-    title = {MaxViT: Multi-Axis Vision Transformer},
-    author = {Zhe-Wei Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
-    year = {2022}
-}
-```
-
 ```bibtex
 @article{Yu2021VectorquantizedIM,
     title = {Vector-quantized Image Modeling with Improved VQGAN},
@@ -1190,4 +1164,32 @@ This library would not have gotten to this working state without the help of
 }
 ```
 
+```bibtex
+@misc{Saharia2022,
+    title = {Imagen: unprecedented photorealism × deep level of language understanding},
+    author = {Chitwan Saharia*, William Chan*, Saurabh Saxena†, Lala Li†, Jay Whang†, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S. Sara Mahdavi, Rapha Gontijo Lopes, Tim Salimans, Jonathan Ho†, David Fleet†, Mohammad Norouzi*},
+    year = {2022}
+}
+```
+
+```bibtex
+@article{Choi2022PerceptionPT,
+    title = {Perception Prioritized Training of Diffusion Models},
+    author = {Jooyoung Choi and Jungbeom Lee and Chaehun Shin and Sungwon Kim and Hyunwoo J. Kim and Sung-Hoon Yoon},
+    journal = {ArXiv},
+    year = {2022},
+    volume = {abs/2204.00227}
+}
+```
+
+```bibtex
+@article{Saharia2021PaletteID,
+    title = {Palette: Image-to-Image Diffusion Models},
+    author = {Chitwan Saharia and William Chan and Huiwen Chang and Chris A. Lee and Jonathan Ho and Tim Salimans and David J. Fleet and Mohammad Norouzi},
+    journal = {ArXiv},
+    year = {2021},
+    volume = {abs/2111.05826}
+}
+```
+
 *Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
@@ -6,9 +6,10 @@ For more complex configuration, we provide the option of using a configuration f
 
 The decoder trainer has 7 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.example.json).
 
-**<ins>Unets</ins>:**
+**<ins>Unet</ins>:**
 
+This is a single unet config, which belongs as an array nested under the decoder config as a list of `unets`
-Each member of this array defines a single unet that will be added to the decoder.
 
 | Option | Required | Default | Description |
 | ------ | -------- | ------- | ----------- |
 | `dim` | Yes | N/A | The starting channels of the unet. |
@@ -22,6 +23,7 @@ Any parameter from the `Unet` constructor can also be given here.
 Defines the configuration options for the decoder model. The unets defined above will automatically be inserted.
 | Option | Required | Default | Description |
 | ------ | -------- | ------- | ----------- |
+| `unets` | Yes | N/A | A list of unets, using the configuration above |
 | `image_sizes` | Yes | N/A | The resolution of the image after each upsampling step. The length of this array should be the number of unets defined. |
 | `image_size` | Yes | N/A | Not used. Can be any number. |
 | `timesteps` | No | `1000` | The number of diffusion timesteps used for generation. |
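Per the table above, `unets` is a list and `image_sizes` must have one resolution per unet. A small sketch of validating that invariant before training starts (a hypothetical helper, not part of the repository):

```python
def validate_decoder_config(config):
    """Check the structural invariants described in the decoder config table."""
    unets = config.get("unets")
    if not isinstance(unets, list) or not unets:
        raise ValueError("`unets` must be a non-empty list of unet configs")
    if any("dim" not in unet for unet in unets):
        raise ValueError("every unet config requires `dim`")
    image_sizes = config.get("image_sizes", [])
    if len(image_sizes) != len(unets):
        raise ValueError("`image_sizes` needs one resolution per unet")
    return True

ok = validate_decoder_config({
    "unets": [{"dim": 128}],
    "image_sizes": [64],
    "image_size": 64,  # present but unused, per the table
})
```

Catching a mismatched `image_sizes` length here is much cheaper than discovering it mid-run on a multi-GPU job.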
@@ -81,7 +83,7 @@ Defines which evaluation metrics will be used to test the model.
 Each metric can be enabled by setting its configuration. The configuration keys for each metric are defined by the torchmetrics constructors which will be linked.
 | Option | Required | Default | Description |
 | ------ | -------- | ------- | ----------- |
-| `n_evalation_samples` | No | `1000` | The number of samples to generate to test the model. |
+| `n_evaluation_samples` | No | `1000` | The number of samples to generate to test the model. |
 | `FID` | No | `None` | Setting to an object enables the [Frechet Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/frechet_inception_distance.html) metric.
 | `IS` | No | `None` | Setting to an object enables the [Inception Score](https://torchmetrics.readthedocs.io/en/stable/image/inception_score.html) metric.
 | `KID` | No | `None` | Setting to an object enables the [Kernel Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/kernel_inception_distance.html) metric. |
@@ -89,21 +91,83 @@ Each metric can be enabled by setting its configuration. The configuration keys
 
 **<ins>Tracker</ins>:**
 
-Selects which tracker to use and configures it.
+Selects how the experiment will be tracked.
 | Option | Required | Default | Description |
 | ------ | -------- | ------- | ----------- |
-| `tracker_type` | No | `console` | Which tracker to use. Currently accepts `console` or `wandb`. |
-| `data_path` | No | `./models` | Where the tracker will store local data. |
-| `verbose` | No | `False` | Enables console logging for non-console trackers. |
+| `data_path` | No | `./.tracker-data` | The path to the folder where temporary tracker data will be saved. |
+| `overwrite_data_path` | No | `False` | If true, the data path will be overwritten. Otherwise, you need to delete it yourself. |
+| `log` | Yes | N/A | Logging configuration. |
+| `load` | No | `None` | Checkpoint loading configuration. |
+| `save` | Yes | N/A | Checkpoint/Model saving configuration. |
+
+Tracking is split up into three sections:
+* Log: Where to save run metadata and image output. Options are `console` or `wandb`.
+* Load: Where to load a checkpoint from. Options are `local`, `url`, or `wandb`.
+* Save: Where to save a checkpoint to. Options are `local`, `huggingface`, or `wandb`.
 
-Other configuration options are required for the specific trackers. To see which are required, reference the initializer parameters of each [tracker](../dalle2_pytorch/trackers.py).
+**Logging:**
 
-**<ins>Load</ins>:**
+If using `console` there is no further configuration than setting `log_type` to `console`.
 
-Selects where to load a pretrained model from.
 | Option | Required | Default | Description |
 | ------ | -------- | ------- | ----------- |
-| `source` | No | `None` | Supports `file` or `wandb`. |
-| `resume` | No | `False` | If the tracker support resuming the run, resume it. |
+| `log_type` | Yes | N/A | Must be `console`. |
 
-Other configuration options are required for loading from a specific source. To see which are required, reference the load methods at the top of the [tracker file](../dalle2_pytorch/trackers.py).
+If using `wandb`
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `log_type` | Yes | N/A | Must be `wandb`. |
+| `wandb_entity` | Yes | N/A | The wandb entity to log to. |
+| `wandb_project` | Yes | N/A | The wandb project to save the run to. |
+| `wandb_run_name` | No | `None` | The wandb run name. |
+| `wandb_run_id` | No | `None` | The wandb run id. Used if resuming an old run. |
+| `wandb_resume` | No | `False` | Whether to resume an old run. |
+
+**Loading:**
+
+If using `local`
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `load_from` | Yes | N/A | Must be `local`. |
+| `file_path` | Yes | N/A | The path to the checkpoint file. |
+
+If using `url`
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `load_from` | Yes | N/A | Must be `url`. |
+| `url` | Yes | N/A | The url of the checkpoint file. |
+
+If using `wandb`
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `load_from` | Yes | N/A | Must be `wandb`. |
+| `wandb_run_path` | No | `None` | The wandb run path. If `None`, uses the run that is being resumed. |
+| `wandb_file_path` | Yes | N/A | The path to the checkpoint file in the W&B file system. |
+
+**Saving:**
+
+Unlike `log` and `load`, `save` may be an array of options so that you can save to different locations in a run.
+
+All save locations have these configuration options
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `save_to` | Yes | N/A | Must be `local`, `huggingface`, or `wandb`. |
+| `save_latest_to` | No | `latest.pth` | Sets the relative path to save the latest model to. |
+| `save_best_to` | No | `best.pth` | Sets the relative path to save the best model to every time the model has a lower validation loss than all previous models. |
+| `save_type` | No | `'checkpoint'` | The type of save. `'checkpoint'` saves a checkpoint, `'model'` saves a model without any fluff (Saves with ema if ema is enabled). |
+
+If using `local`
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `save_to` | Yes | N/A | Must be `local`. |
+
+If using `huggingface`
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `save_to` | Yes | N/A | Must be `huggingface`. |
+| `huggingface_repo` | Yes | N/A | The huggingface repository to save to. |
+| `huggingface_base_path` | Yes | N/A | The base path that checkpoints will be saved under. |
+| `token_path` | No | `None` | If logging in with the huggingface cli is not possible, point to a token file instead. |
+
+If using `wandb`
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `save_to` | Yes | N/A | Must be `wandb`. |
+| `wandb_run_path` | No | `None` | The wandb run path. If `None`, uses the current run. You will almost always want this to be `None`. |
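The `save_latest_to` / `save_best_to` behaviour described above amounts to: always overwrite the latest file, and overwrite the best file only on a new minimum validation loss. A sketch of that decision logic (hypothetical, stdlib only):

```python
def plan_saves(history, new_loss, save_latest_to="latest.pth", save_best_to="best.pth"):
    """Return which relative paths should be (over)written for this validation loss."""
    targets = [save_latest_to]                  # latest is refreshed on every save
    if not history or new_loss < min(history):  # strictly better than all previous losses
        targets.append(save_best_to)
    return targets

losses, plans = [], []
for loss in [0.9, 0.7, 0.8, 0.6]:
    plans.append(plan_saves(losses, loss))
    losses.append(loss)
# plans == [['latest.pth', 'best.pth'], ['latest.pth', 'best.pth'],
#           ['latest.pth'], ['latest.pth', 'best.pth']]
```

Because `save` may be an array, the same decision would be applied once per configured destination (local, huggingface, wandb).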
@@ -1,21 +1,21 @@
 {
-    "unets": [
-        {
-            "dim": 128,
-            "image_embed_dim": 768,
-            "cond_dim": 64,
-            "channels": 3,
-            "dim_mults": [1, 2, 4, 8],
-            "attn_dim_head": 32,
-            "attn_heads": 16
-        }
-    ],
     "decoder": {
+        "unets": [
+            {
+                "dim": 128,
+                "image_embed_dim": 768,
+                "cond_dim": 64,
+                "channels": 3,
+                "dim_mults": [1, 2, 4, 8],
+                "attn_dim_head": 32,
+                "attn_heads": 16
+            }
+        ],
         "image_sizes": [64],
         "channels": 3,
         "timesteps": 1000,
         "loss_type": "l2",
-        "beta_schedule": "cosine",
+        "beta_schedule": ["cosine"],
         "learned_variance": true
     },
     "data": {
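The change above nests `unets` inside `decoder` and turns `beta_schedule` into a list, presumably one schedule per unet. A quick structural check of a trimmed-down version of the new layout:

```python
import json

config = json.loads("""
{
    "decoder": {
        "unets": [
            {"dim": 128, "image_embed_dim": 768, "cond_dim": 64,
             "channels": 3, "dim_mults": [1, 2, 4, 8]}
        ],
        "image_sizes": [64],
        "beta_schedule": ["cosine"]
    }
}
""")

decoder = config["decoder"]
# per-unet arrays should all line up: one resolution and one schedule per unet
aligned = (
    len(decoder["unets"]) == len(decoder["image_sizes"]) == len(decoder["beta_schedule"])
)
```

With a cascading decoder (two or more unets), each of these arrays would simply grow in lockstep.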
@@ -80,20 +80,32 @@
         }
     },
     "tracker": {
-        "tracker_type": "console",
-        "data_path": "./models",
-        "wandb_entity": "",
-        "wandb_project": "",
-        "verbose": false
-    },
-    "load": {
-        "source": null,
-        "run_path": "",
-        "file_path": "",
-        "resume": false
+        "overwrite_data_path": true,
+        "log": {
+            "log_type": "wandb",
+            "wandb_entity": "your_wandb",
+            "wandb_project": "your_project",
+            "verbose": true
+        },
+        "load": {
+            "load_from": null
+        },
+        "save": [{
+            "save_to": "wandb"
+        }, {
+            "save_to": "huggingface",
+            "huggingface_repo": "Veldrovive/test_model",
+            "save_all": true,
+            "save_latest": true,
+            "save_best": true,
+            "save_type": "model"
+        }]
     }
 }
configs/train_prior_config.example.json (new file, 70 lines)
@@ -0,0 +1,70 @@
+{
+  "prior": {
+    "clip": {
+      "make": "x-clip",
+      "model": "ViT-L/14",
+      "base_model_kwargs": {
+        "dim_text": 768,
+        "dim_image": 768,
+        "dim_latent": 768
+      }
+    },
+    "net": {
+      "dim": 768,
+      "depth": 12,
+      "num_timesteps": 1000,
+      "num_time_embeds": 1,
+      "num_image_embeds": 1,
+      "num_text_embeds": 1,
+      "dim_head": 64,
+      "heads": 12,
+      "ff_mult": 4,
+      "norm_out": true,
+      "attn_dropout": 0.0,
+      "ff_dropout": 0.0,
+      "final_proj": true,
+      "normformer": true,
+      "rotary_emb": true
+    },
+    "image_embed_dim": 768,
+    "image_size": 224,
+    "image_channels": 3,
+    "timesteps": 1000,
+    "cond_drop_prob": 0.1,
+    "loss_type": "l2",
+    "predict_x_start": true,
+    "beta_schedule": "cosine",
+    "condition_on_text_encodings": true
+  },
+  "data": {
+    "image_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/",
+    "text_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/",
+    "meta_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/laion2B-en-metadata/",
+    "batch_size": 256,
+    "splits": {
+      "train": 0.9,
+      "val": 1e-7,
+      "test": 0.0999999
+    }
+  },
+  "train": {
+    "epochs": 1,
+    "lr": 1.1e-4,
+    "wd": 6.02e-2,
+    "max_grad_norm": 0.5,
+    "use_ema": true,
+    "amp": false,
+    "save_every": 10000
+  },
+  "load": {
+    "source": null,
+    "resume": false
+  },
+  "tracker": {
+    "tracker_type": "wandb",
+    "data_path": "./prior_checkpoints",
+    "wandb_entity": "laion",
+    "wandb_project": "diffusion-prior",
+    "verbose": true
+  }
+}
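For orientation, the new example config is ordinary JSON; a minimal sketch (field names taken from the file above, the loading code itself is illustrative) of reading a few training fields and checking that the data splits cover the dataset:

```python
import json

# an abridged copy of configs/train_prior_config.example.json from the diff above
CONFIG = json.loads("""
{
    "prior": {"net": {"dim": 768, "depth": 12}, "timesteps": 1000},
    "data": {"batch_size": 256, "splits": {"train": 0.9, "val": 1e-7, "test": 0.0999999}},
    "train": {"epochs": 1, "lr": 1.1e-4, "wd": 6.02e-2}
}
""")

splits = CONFIG["data"]["splits"]
# the split fractions should sum (approximately) to 1.0
assert abs(sum(splits.values()) - 1.0) < 1e-6

print(CONFIG["prior"]["net"]["dim"], CONFIG["train"]["lr"])
```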
@@ -1,3 +1,4 @@
+from dalle2_pytorch.version import __version__
 from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
 from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
 from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
File diff suppressed because it is too large
@@ -15,7 +15,7 @@ from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embed
 
 # Create a dataloader directly.
 dataloader = create_image_embedding_dataloader(
-    tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses braket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
+    tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses bracket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
     embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
     num_workers=4,
     batch_size=32,
@@ -39,3 +39,37 @@ dataset = ImageEmbeddingDataset(
|
|||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Diffusion Prior: Prior Embedding Dataset
|
||||||
|
When training the prior it is much more efficient to work with pre-computed embeddings. The `PriorEmbeddingDataset` class enables you to leverage the same script (with minimal modification) for both embedding-only and text-conditioned prior training. This saves you from having to worry about a lot of the boilerplate code.
|
||||||
|
|
||||||
|
To utilize the `PriorEmbeddingDataset`, all you need to do is make a single call to `get_reader()` which will create `EmbeddingReader` object(s) for you. Afterwards, you can utilize `make_splits()` to cleanly create DataLoader objects from for your training run.
|
||||||
|
|
||||||
|
If you are training in a distributed manner, `make_splits()` accepts `rank` and `world_size` arguments to properly distribute to each process. The defaults for these values are `rank=0` and `world_size=1`, so single-process training can safely ignore these parameters.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
```python
|
||||||
|
from dalle2_pytorch.dataloaders import get_reader, make_splits
|
||||||
|
|
||||||
|
# grab embeddings from some specified location
|
||||||
|
IMG_URL = "data/img_emb/"
|
||||||
|
META_URL = "data/meta/"
|
||||||
|
|
||||||
|
reader = get_reader(text_conditioned=True, img_url=IMG_URL, meta_url=META_URL)
|
||||||
|
|
||||||
|
# some config for training
|
||||||
|
TRAIN_ARGS = {
|
||||||
|
"world_size": 3,
|
||||||
|
"text_conditioned": True,
|
||||||
|
"start": 0,
|
||||||
|
"num_data_points": 10000,
|
||||||
|
"batch_size": 2,
|
||||||
|
"train_split": 0.5,
|
||||||
|
"eval_split": 0.25,
|
||||||
|
"image_reader": reader,
|
||||||
|
}
|
||||||
|
|
||||||
|
# specifying a rank will handle allocation internally
|
||||||
|
rank0_train, rank0_eval, rank0_test = make_splits(rank=0, **TRAIN_ARGS)
|
||||||
|
rank1_train, rank1_eval, rank1_test = make_splits(rank=1, **TRAIN_ARGS)
|
||||||
|
rank2_train, rank2_eval, rank2_test = make_splits(rank=2, **TRAIN_ARGS)
|
||||||
|
```
|
||||||
|
|||||||
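The rank/world_size behaviour described above can be pictured with a small standalone sketch of the split arithmetic (mirroring, not importing, the repo's splitting logic; the numbers follow the TRAIN_ARGS example):

```python
from math import ceil

def split_points(num_data_points, train_split, eval_split):
    # test size is inferred from what train and eval leave over
    train_stop = int(train_split * num_data_points)
    eval_stop = train_stop + int(eval_split * num_data_points)
    return (0, train_stop), (train_stop, eval_stop), (eval_stop, num_data_points)

def distribute(start, stop, rank, world_size):
    # each rank gets a contiguous ceil(n / world_size)-sized slice
    per_rank = ceil((stop - start) / world_size)
    rank_start = start + rank * per_rank
    return rank_start, min(rank_start + per_rank, stop)

train, evl, test = split_points(10000, train_split=0.5, eval_split=0.25)
print(train, evl, test)  # (0, 5000) (5000, 7500) (7500, 10000)
print(distribute(*train, rank=1, world_size=3))  # (1667, 3334)
```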
@@ -1,2 +1,2 @@
 from dalle2_pytorch.dataloaders.decoder_loader import ImageEmbeddingDataset, create_image_embedding_dataloader
-from dalle2_pytorch.dataloaders.embedding_wrapper import make_splits
+from dalle2_pytorch.dataloaders.prior_loader import make_splits, get_reader, PriorEmbeddingDataset
@@ -21,7 +21,7 @@ def get_example_file(fs, path, file_format):
     """
     return fs.glob(os.path.join(path, f"*.{file_format}"))[0]
 
-def embedding_inserter(samples, embeddings_url, index_width, handler=wds.handlers.reraise_exception):
+def embedding_inserter(samples, embeddings_url, index_width, sample_key='npy', handler=wds.handlers.reraise_exception):
     """Given a datum of {"__key__": str, "__url__": str, ...} adds the corresponding embedding and yields"""
     previous_tar_url = None
     current_embeddings = None
@@ -56,7 +56,7 @@ def embedding_inserter(samples, embeddings_url, index_width, handler=wds.handler
             # We need to check if this sample is nonzero. If it is, this embedding is not valid and we should continue to the next loop
             if torch.count_nonzero(embedding) == 0:
                 raise RuntimeError(f"Webdataset had a sample, but no embedding was found. ImgShard: {key[:-index_width]} - Index: {key[-index_width:]}")
-            sample["npy"] = embedding
+            sample[sample_key] = embedding
             yield sample
         except Exception as exn:  # From wds implementation
             if handler(exn):
@@ -84,18 +84,20 @@ def unassociated_shard_skipper(tarfiles, embeddings_url, handler=wds.handlers.re
             continue
         else:
             break
 
 skip_unassociated_shards = wds.filters.pipelinefilter(unassociated_shard_skipper)
 
-def verify_keys(samples, handler=wds.handlers.reraise_exception):
+def join_embeddings(samples, handler=wds.handlers.reraise_exception):
     """
-    Requires that both the image and embedding are present in the sample
-    This is important to do as a user may forget they do not have embeddings in their webdataset and neglect to add them using the embedding_folder_url parameter.
+    Takes the img_emb and text_emb keys and turns them into one key "emb": { "text": text_emb, "img": img_emb }
+    either or both of text_emb and img_emb may not be in the sample so we only add the ones that exist
     """
     for sample in samples:
         try:
-            assert "jpg" in sample, f"Sample {sample['__key__']} missing image"
-            assert "npy" in sample, f"Sample {sample['__key__']} missing embedding. Did you set embedding_folder_url?"
+            sample['emb'] = {}
+            if 'text_emb' in sample:
+                sample['emb']['text'] = sample['text_emb']
+            if 'img_emb' in sample:
+                sample['emb']['img'] = sample['img_emb']
             yield sample
         except Exception as exn:  # From wds implementation
             if handler(exn):
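The renamed `join_embeddings` stage above merges whichever embedding keys exist into one `emb` dict; the same behaviour in a self-contained sketch (plain generator over dicts, no webdataset pipeline):

```python
def join_embeddings(samples):
    # merge img_emb / text_emb (when present) under a single "emb" key
    for sample in samples:
        sample["emb"] = {}
        if "text_emb" in sample:
            sample["emb"]["text"] = sample["text_emb"]
        if "img_emb" in sample:
            sample["emb"]["img"] = sample["img_emb"]
        yield sample

batch = [
    {"__key__": "000", "img_emb": [0.1, 0.2]},                # image embedding only
    {"__key__": "001", "img_emb": [0.3], "text_emb": [0.4]},  # both embeddings
]
joined = list(join_embeddings(batch))
print(joined[0]["emb"])          # {'img': [0.1, 0.2]}
print(sorted(joined[1]["emb"]))  # ['img', 'text']
```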
@@ -103,6 +105,23 @@ def verify_keys(samples, handler=wds.handlers.reraise_exception):
         else:
             break
 
+def verify_keys(samples, required_keys, handler=wds.handlers.reraise_exception):
+    """
+    Requires that both the image and embedding are present in the sample
+    This is important to do as a user may forget they do not have embeddings in their webdataset and neglect to add them using the embedding_folder_url parameter.
+    """
+    for sample in samples:
+        try:
+            for key in required_keys:
+                assert key in sample, f"Sample {sample['__key__']} missing {key}. Has keys {sample.keys()}"
+            yield sample
+        except Exception as exn:  # From wds implementation
+            if handler(exn):
+                continue
+            else:
+                break
+key_verifier = wds.filters.pipelinefilter(verify_keys)
 
 class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
     """
     A fluid interface wrapper for DataPipeline that returns image embedding pairs
@@ -112,7 +131,8 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
     def __init__(
         self,
         urls,
-        embedding_folder_url=None,
+        img_embedding_folder_url=None,
+        text_embedding_folder_url=None,
         index_width=None,
         img_preproc=None,
         extra_keys=[],
@@ -136,7 +156,12 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
 
         """
         super().__init__()
-        keys = ["jpg", "npy"] + extra_keys
+        keys = ["jpg", "emb"] + extra_keys
+        # if img_embedding_folder_url is not None:
+        #     keys.append("img_emb")
+        # if text_embedding_folder_url is not None:
+        #     keys.append("text_emb")
+        # keys.extend(extra_keys)
         self.key_map = {key: i for i, key in enumerate(keys)}
         self.resampling = resample
         self.img_preproc = img_preproc
@@ -145,7 +170,7 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
             # Then this has an s3 link for the webdataset and we need extra packages
             if shutil.which("s3cmd") is None:
                 raise RuntimeError("s3cmd is required for s3 webdataset")
-        if "s3:" in embedding_folder_url:
+        if (img_embedding_folder_url is not None and "s3:" in img_embedding_folder_url) or (text_embedding_folder_url is not None and "s3:" in text_embedding_folder_url):
             # Then the embeddings are being loaded from s3 and fsspec requires s3fs
             try:
                 import s3fs
@@ -160,20 +185,24 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
         if shuffle_shards:
             self.append(wds.filters.shuffle(1000))
 
-        if embedding_folder_url is not None:
+        if img_embedding_folder_url is not None:
             # There may be webdataset shards that do not have a embedding shard associated with it. If we do not skip these, they would cause issues.
-            self.append(skip_unassociated_shards(embeddings_url=embedding_folder_url, handler=handler))
-        self.append(wds.split_by_node)
-        self.append(wds.split_by_worker)
+            self.append(skip_unassociated_shards(embeddings_url=img_embedding_folder_url, handler=handler))
+        if text_embedding_folder_url is not None:
+            self.append(skip_unassociated_shards(embeddings_url=text_embedding_folder_url, handler=handler))
 
         self.append(wds.tarfile_to_samples(handler=handler))
         self.append(wds.decode("pilrgb", handler=handler))
-        if embedding_folder_url is not None:
-            # Then we are loading embeddings for a remote source
+        if img_embedding_folder_url is not None:
+            # Then we are loading image embeddings for a remote source
             assert index_width is not None, "Reading embeddings separately requires index width length to be given"
-            self.append(insert_embedding(embeddings_url=embedding_folder_url, index_width=index_width, handler=handler))
-        self.append(verify_keys)
+            self.append(insert_embedding(embeddings_url=img_embedding_folder_url, index_width=index_width, sample_key='img_emb', handler=handler))
+        if text_embedding_folder_url is not None:
+            # Then we are loading image embeddings for a remote source
+            assert index_width is not None, "Reading embeddings separately requires index width length to be given"
+            self.append(insert_embedding(embeddings_url=text_embedding_folder_url, index_width=index_width, sample_key='text_emb', handler=handler))
+        self.append(join_embeddings)
+        self.append(key_verifier(required_keys=keys, handler=handler))
         # Apply preprocessing
         self.append(wds.map(self.preproc))
         self.append(wds.to_tuple(*keys))
@@ -188,7 +217,8 @@ def create_image_embedding_dataloader(
     tar_url,
     num_workers,
     batch_size,
-    embeddings_url=None,
+    img_embeddings_url=None,
+    text_embeddings_url=None,
     index_width=None,
     shuffle_num = None,
     shuffle_shards = True,
@@ -214,7 +244,8 @@ def create_image_embedding_dataloader(
     """
     ds = ImageEmbeddingDataset(
         tar_url,
-        embeddings_url,
+        img_embedding_folder_url=img_embeddings_url,
+        text_embedding_folder_url=text_embeddings_url,
         index_width=index_width,
         shuffle_shards=shuffle_shards,
         resample=resample_shards,
@@ -1,180 +0,0 @@
-from torch.utils.data import IterableDataset
-from torch import from_numpy
-from clip import tokenize
-from embedding_reader import EmbeddingReader
-
-
-class PriorEmbeddingLoader(IterableDataset):
-    def __init__(
-        self,
-        text_conditioned: bool,
-        batch_size: int,
-        start: int,
-        stop: int,
-        image_reader,
-        text_reader: EmbeddingReader = None,
-        device: str = "cpu",
-    ) -> None:
-        super(PriorEmbeddingLoader).__init__()
-
-        self.text_conditioned = text_conditioned
-
-        if not self.text_conditioned:
-            self.text_reader = text_reader
-
-        self.image_reader = image_reader
-        self.batch_size = batch_size
-        self.start = start
-        self.stop = stop
-        self.device = device
-
-    def __iter__(self):
-        self.n = 0
-        loader_args = dict(
-            batch_size=self.batch_size,
-            start=self.start,
-            end=self.stop,
-            show_progress=False,
-        )
-        if self.text_conditioned:
-            self.loader = self.image_reader(**loader_args)
-        else:
-            self.loader = zip(
-                self.image_reader(**loader_args), self.text_reader(**loader_args)
-            )
-        return self
-
-    def __next__(self):
-        try:
-            return self.get_sample()
-        except StopIteration:
-            raise StopIteration
-
-    def get_sample(self):
-        """
-        pre-proocess data from either reader into a common format
-        """
-        self.n += 1
-
-        if self.text_conditioned:
-            image_embedding, caption = next(self.loader)
-
-            image_embedding = from_numpy(image_embedding).to(self.device)
-            tokenized_caption = tokenize(
-                caption["caption"].to_list(), truncate=True
-            ).to(self.device)
-
-            return image_embedding, tokenized_caption
-
-        else:
-            (image_embedding, _), (text_embedding, _) = next(self.loader)
-
-            image_embedding = from_numpy(image_embedding).to(self.device)
-            text_embedding = from_numpy(text_embedding).to(self.device)
-
-            return image_embedding, text_embedding
-
-
-def make_splits(
-    text_conditioned: bool,
-    batch_size: int,
-    num_data_points: int,
-    train_split: float,
-    eval_split: float,
-    device: str,
-    img_url: str,
-    meta_url: str = None,
-    txt_url: str = None,
-):
-
-    assert img_url is not None, "Must supply some image embeddings"
-
-    if text_conditioned:
-        assert meta_url is not None, "Must supply metadata url if text-conditioning"
-        image_reader = EmbeddingReader(
-            embeddings_folder=img_url,
-            file_format="parquet_npy",
-            meta_columns=["caption"],
-            metadata_folder=meta_url,
-        )
-
-        # compute split points
-        if num_data_points > image_reader.count:
-            print("Specified point count is larger than the number of points available...defaulting to max length of reader.")
-            num_data_points = image_reader.count
-
-        train_set_size = int(train_split * num_data_points)
-        eval_set_size = int(eval_split * num_data_points)
-        eval_stop = int(train_set_size + eval_set_size)
-
-        train_loader = PriorEmbeddingLoader(
-            text_conditioned=text_conditioned,
-            image_reader=image_reader,
-            batch_size=batch_size,
-            start=0,
-            stop=train_set_size,
-            device=device,
-        )
-        eval_loader = PriorEmbeddingLoader(
-            text_conditioned=text_conditioned,
-            image_reader=image_reader,
-            batch_size=batch_size,
-            start=train_set_size,
-            stop=eval_stop,
-            device=device,
-        )
-        test_loader = PriorEmbeddingLoader(
-            text_conditioned=text_conditioned,
-            image_reader=image_reader,
-            batch_size=batch_size,
-            start=eval_stop,
-            stop=int(num_data_points),
-            device=device,
-        )
-
-    else:
-        assert (
-            txt_url is not None
-        ), "Must supply text embedding url if not text-conditioning"
-
-        image_reader = EmbeddingReader(img_url, file_format="npy")
-        text_reader = EmbeddingReader(txt_url, file_format="npy")
-
-        # compute split points
-        if num_data_points > image_reader.count:
-            print("Specified point count is larger than the number of points available...defaulting to max length of reader.")
-            num_data_points = image_reader.count
-
-        train_set_size = int(train_split * num_data_points)
-        eval_set_size = int(eval_split * num_data_points)
-        eval_stop = int(train_set_size + eval_set_size)
-
-        train_loader = PriorEmbeddingLoader(
-            text_conditioned=text_conditioned,
-            image_reader=image_reader,
-            text_reader=text_reader,
-            batch_size=batch_size,
-            start=0,
-            stop=train_set_size,
-            device=device,
-        )
-        eval_loader = PriorEmbeddingLoader(
-            text_conditioned=text_conditioned,
-            image_reader=image_reader,
-            text_reader=text_reader,
-            batch_size=batch_size,
-            start=train_set_size,
-            stop=eval_stop,
-            device=device,
-        )
-        test_loader = PriorEmbeddingLoader(
-            text_conditioned=text_conditioned,
-            image_reader=image_reader,
-            text_reader=text_reader,
-            batch_size=batch_size,
-            start=eval_stop,
-            stop=int(num_data_points),
-            device=device,
-        )
-
-    return train_loader, eval_loader, test_loader
dalle2_pytorch/dataloaders/prior_loader.py (new file, 273 lines)
@@ -0,0 +1,273 @@
|
|||||||
|
from math import ceil
|
||||||
|
from clip import tokenize
|
||||||
|
from embedding_reader import EmbeddingReader
|
||||||
|
from torch import from_numpy
|
||||||
|
from torch.utils.data import IterableDataset, DataLoader
|
||||||
|
|
||||||
|
|
||||||
|
class PriorEmbeddingDataset(IterableDataset):
|
||||||
|
"""
|
||||||
|
PriorEmbeddingDataset is a wrapper of EmbeddingReader.
|
||||||
|
|
||||||
|
It enables one to simplify the logic necessary to yield samples from
|
||||||
|
the different EmbeddingReader configurations available.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
text_conditioned: bool,
|
||||||
|
batch_size: int,
|
||||||
|
start: int,
|
||||||
|
stop: int,
|
||||||
|
image_reader,
|
||||||
|
text_reader: EmbeddingReader = None,
|
||||||
|
) -> None:
|
||||||
|
super(PriorEmbeddingDataset).__init__()
|
||||||
|
|
||||||
|
self.text_conditioned = text_conditioned
|
||||||
|
|
||||||
|
if not self.text_conditioned:
|
||||||
|
self.text_reader = text_reader
|
||||||
|
|
||||||
|
self.image_reader = image_reader
|
||||||
|
self.start = start
|
||||||
|
self.stop = stop
|
||||||
|
self.batch_size = batch_size
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.stop - self.start
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
# D.R.Y loader args
|
||||||
|
loader_args = dict(
|
||||||
|
batch_size=self.batch_size,
|
||||||
|
start=self.start,
|
||||||
|
end=self.stop,
|
||||||
|
show_progress=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# if the data requested is text conditioned, only load images
|
||||||
|
if self.text_conditioned:
|
||||||
|
self.loader = self.image_reader(**loader_args)
|
||||||
|
# otherwise, include text embeddings and bypass metadata
|
||||||
|
else:
|
||||||
|
self.loader = zip(
|
||||||
|
self.image_reader(**loader_args), self.text_reader(**loader_args)
|
||||||
|
)
|
||||||
|
|
||||||
|
# return the data loader in its formatted state
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
try:
|
||||||
|
return self.get_sample()
|
||||||
|
except StopIteration:
|
||||||
|
raise StopIteration
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"<PriorEmbeddingDataset: start: {self.start}, stop: {self.stop}, len: {self.__len__()}>"
|
||||||
|
|
||||||
|
def get_sample(self):
|
||||||
|
"""
|
||||||
|
pre-proocess data from either reader into a common format
|
||||||
|
"""
|
||||||
|
if self.text_conditioned:
|
||||||
|
image_embedding, caption = next(self.loader)
|
||||||
|
|
||||||
|
image_embedding = from_numpy(image_embedding)
|
||||||
|
tokenized_caption = tokenize(caption["caption"].to_list(), truncate=True)
|
||||||
|
|
||||||
|
return image_embedding, tokenized_caption
|
||||||
|
|
||||||
|
else:
|
||||||
|
(image_embedding, _), (text_embedding, _) = next(self.loader)
|
||||||
|
|
||||||
|
image_embedding = from_numpy(image_embedding)
|
||||||
|
text_embedding = from_numpy(text_embedding)
|
||||||
|
|
||||||
|
return image_embedding, text_embedding
|
||||||
|
|
||||||
|
|
||||||
|
# helper functions
|
||||||
|
|
||||||
|
|
||||||
|
def distribute_to_rank(start, stop, rank, world_size):
|
||||||
|
"""
|
||||||
|
Distribute data to each rank given the world size.
|
||||||
|
|
||||||
|
Return:
|
||||||
|
- New start and stop points for this rank.
|
||||||
|
"""
|
||||||
|
num_samples = int(stop - start)
|
||||||
|
|
||||||
|
per_rank = int(ceil((num_samples) / float(world_size)))
|
||||||
|
|
||||||
|
assert (
|
||||||
|
per_rank > 0
|
||||||
|
), f"Number of samples per rank must be larger than 0, (found: {per_rank})"
|
||||||
|
|
||||||
|
rank_start = start + rank * per_rank
|
||||||
|
|
||||||
|
rank_stop = min(rank_start + per_rank, stop)
|
||||||
|
|
||||||
|
new_length = rank_stop - rank_start
|
||||||
|
|
||||||
|
assert (
|
||||||
|
new_length > 0
|
||||||
|
), "Calculated start and stop points result in a length of zero for this rank."
|
||||||
|
|
||||||
|
return rank_start, rank_stop
|
||||||
|
|
||||||
|
|
||||||
|
def get_reader(
|
||||||
|
text_conditioned: bool, img_url: str, meta_url: str = None, txt_url: str = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create an EmbeddingReader object from the specified URLs
|
||||||
|
|
||||||
|
get_reader() will always expect a url to image embeddings.
|
||||||
|
|
||||||
|
If text-conditioned, it will also expect a meta_url for the captions.
|
||||||
|
Otherwise, it will need txt_url for the matching text embeddings.
|
||||||
|
|
||||||
|
Returns an image_reader object if text-conditioned.
|
||||||
|
Otherwise it returns both an image_reader and a text_reader
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert img_url is not None, "Must supply a image url"
|
||||||
|
|
||||||
|
if text_conditioned:
|
||||||
|
assert meta_url is not None, "Must supply meta url if text-conditioned"
|
||||||
|
|
||||||
|
image_reader = EmbeddingReader(
|
||||||
|
embeddings_folder=img_url,
|
||||||
|
file_format="parquet_npy",
|
||||||
|
# will assume the caption column exists and is the only one requested
|
||||||
|
meta_columns=["caption"],
|
||||||
|
metadata_folder=meta_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
return image_reader
|
||||||
|
|
||||||
|
# otherwise we will require text embeddings as well and return two readers
|
||||||
|
assert (
|
||||||
|
txt_url is not None
|
||||||
|
), "Must supply text embedding url if not text-conditioning"
|
||||||
|
|
||||||
|
image_reader = EmbeddingReader(img_url, file_format="npy")
|
||||||
|
text_reader = EmbeddingReader(txt_url, file_format="npy")
|
||||||
|
|
||||||
|
return image_reader, text_reader
|
||||||
|
|
||||||
|
|
||||||
|
def make_splits(
|
||||||
|
text_conditioned: bool,
|
||||||
|
batch_size: int,
|
||||||
|
num_data_points: int,
|
||||||
|
train_split: float,
|
||||||
|
eval_split: float,
|
||||||
|
image_reader: EmbeddingReader,
|
||||||
|
text_reader: EmbeddingReader = None,
|
||||||
|
start=0,
|
||||||
|
rank=0,
|
||||||
|
world_size=1,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Split an embedding reader object as needed.
|
||||||
|
|
||||||
|
NOTE: make_splits() will infer the test set size from your train and eval.
|
||||||
|
|
||||||
|
Input:
|
||||||
|
- text_conditioned: whether to prepare text-conditioned training data
|
||||||
|
- batch_size: the batch size for a single gpu
|
||||||
|
- num_data_points: the total number of data points you wish to train on
|
||||||
|
- train_split: the percentage of data you wish to train on
|
||||||
|
- eval_split: the percentage of data you wish to validate on
|
||||||
|
- image_reader: the image_reader you wish to split
|
||||||
|
- text_reader: the text_reader you want to split (if !text_conditioned)
|
||||||
|
- start: the starting point within your dataset
|
||||||
|
- rank: the rank of your worker
|
||||||
|
- world_size: the total world size of your distributed training run
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- PyTorch Dataloaders that yield tuples of (img, txt) data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert start < image_reader.count, "start position cannot exceed reader count."
|
||||||
|
|
||||||
|
# verify that the num_data_points does not exceed the max points
|
||||||
|
if num_data_points > (image_reader.count - start):
|
||||||
|
print(
|
||||||
|
"Specified count is larger than what's available...defaulting to reader's count."
|
||||||
|
)
|
||||||
|
num_data_points = image_reader.count
|
||||||
|
|
||||||
|
# compute split points
+    train_set_size = int(train_split * num_data_points)
+    eval_set_size = int(eval_split * num_data_points)
+    eval_start = train_set_size
+    eval_stop = int(eval_start + eval_set_size)
+
+    assert (
+        train_split + eval_split
+    ) < 1.0, "Specified train and eval split is too large to infer a test split."
+
+    # distribute to rank
+    rank_train_start, rank_train_stop = distribute_to_rank(
+        start, train_set_size, rank, world_size
+    )
+    rank_eval_start, rank_eval_stop = distribute_to_rank(
+        train_set_size, eval_stop, rank, world_size
+    )
+    rank_test_start, rank_test_stop = distribute_to_rank(
+        eval_stop, num_data_points, rank, world_size
+    )
+
+    # wrap up splits into a dict
+    train_split_args = dict(
+        start=rank_train_start, stop=rank_train_stop, batch_size=batch_size
+    )
+    eval_split_args = dict(
+        start=rank_eval_start, stop=rank_eval_stop, batch_size=batch_size
+    )
+    test_split_args = dict(
+        start=rank_test_start, stop=rank_test_stop, batch_size=batch_size
+    )
+
+    if text_conditioned:
+        # add the text-conditioned args to a unified dict
+        reader_args = dict(
+            text_conditioned=text_conditioned,
+            image_reader=image_reader,
+        )
+
+        train_split_args = dict(**reader_args, **train_split_args)
+        eval_split_args = dict(**reader_args, **eval_split_args)
+        test_split_args = dict(**reader_args, **test_split_args)
+
+        train = PriorEmbeddingDataset(**train_split_args)
+        val = PriorEmbeddingDataset(**eval_split_args)
+        test = PriorEmbeddingDataset(**test_split_args)
+
+    else:
+        # add the non-conditioned args to a unified dict
+        reader_args = dict(
+            text_conditioned=text_conditioned,
+            image_reader=image_reader,
+            text_reader=text_reader,
+        )
+
+        train_split_args = dict(**reader_args, **train_split_args)
+        eval_split_args = dict(**reader_args, **eval_split_args)
+        test_split_args = dict(**reader_args, **test_split_args)
+
+        train = PriorEmbeddingDataset(**train_split_args)
+        val = PriorEmbeddingDataset(**eval_split_args)
+        test = PriorEmbeddingDataset(**test_split_args)
+
+    # true batch size is specified in the PriorEmbeddingDataset
+    train_loader = DataLoader(train, batch_size=None)
+    eval_loader = DataLoader(val, batch_size=None)
+    test_loader = DataLoader(test, batch_size=None)
+
+    return train_loader, eval_loader, test_loader
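The split logic above hands each distributed rank a contiguous slice of every split via `distribute_to_rank`. That helper is not shown in this diff; the sketch below is a hypothetical equal-slice implementation (the name `distribute_to_rank_sketch` and the remainder-handling are assumptions) that illustrates the start/stop bookkeeping the calls rely on:

```python
# Hypothetical sketch of dividing a contiguous [start, stop) range across
# ranks; the real `distribute_to_rank` may round differently.
def distribute_to_rank_sketch(start, stop, rank, world_size):
    total = stop - start
    per_rank = total // world_size  # each rank gets an equal contiguous slice
    rank_start = start + rank * per_rank
    # the last rank absorbs any remainder so no data points are dropped
    rank_stop = stop if rank == world_size - 1 else rank_start + per_rank
    return rank_start, rank_stop
```

With 10 data points over 3 ranks this yields [0, 3), [3, 6), and [6, 10), so the union covers the whole range with no overlap.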
@@ -1,17 +1,20 @@
 from torch.optim import AdamW, Adam

 def separate_weight_decayable_params(params):
-    no_wd_params = set([param for param in params if param.ndim < 2])
-    wd_params = set(params) - no_wd_params
+    wd_params, no_wd_params = [], []
+    for param in params:
+        param_list = no_wd_params if param.ndim < 2 else wd_params
+        param_list.append(param)
     return wd_params, no_wd_params

 def get_optimizer(
     params,
     lr = 1e-4,
     wd = 1e-2,
-    betas = (0.9, 0.999),
+    betas = (0.9, 0.99),
     eps = 1e-8,
     filter_by_requires_grad = False,
+    group_wd_params = True,
     **kwargs
 ):
     if filter_by_requires_grad:
@@ -20,12 +23,12 @@ def get_optimizer(
     if wd == 0:
         return Adam(params, lr = lr, betas = betas, eps = eps)

-    params = set(params)
-    wd_params, no_wd_params = separate_weight_decayable_params(params)
+    if group_wd_params:
+        wd_params, no_wd_params = separate_weight_decayable_params(params)

-    param_groups = [
-        {'params': list(wd_params)},
-        {'params': list(no_wd_params), 'weight_decay': 0},
-    ]
+        params = [
+            {'params': wd_params},
+            {'params': no_wd_params, 'weight_decay': 0},
+        ]

-    return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas, eps = eps)
+    return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
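The rewritten `separate_weight_decayable_params` routes parameters by dimensionality, so biases and norm gains (ndim < 2) are excluded from weight decay while weight matrices keep it. A torch-free distillation of that rule, using `SimpleNamespace` as a minimal stand-in for `torch.nn.Parameter` (only the `ndim` attribute matters here):

```python
from types import SimpleNamespace

def separate_weight_decayable_params(params):
    wd_params, no_wd_params = [], []
    for param in params:
        # route by dimensionality: vectors and scalars skip weight decay
        param_list = no_wd_params if param.ndim < 2 else wd_params
        param_list.append(param)
    return wd_params, no_wd_params

params = [
    SimpleNamespace(name='weight', ndim=2),  # matrix -> decayed
    SimpleNamespace(name='bias', ndim=1),    # vector -> not decayed
]
wd, no_wd = separate_weight_decayable_params(params)
# these two lists become the AdamW param groups built in get_optimizer:
# [{'params': wd}, {'params': no_wd, 'weight_decay': 0}]
```

Returning plain lists (rather than the old sets) preserves parameter order, which AdamW's per-group option dicts accept directly.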
@@ -2,7 +2,6 @@
 # to give users a quick easy start to training DALL-E without doing BPE

 import torch
-import youtokentome as yttm

 import html
 import os
@@ -11,6 +10,8 @@ import regex as re
 from functools import lru_cache
 from pathlib import Path

+from dalle2_pytorch.utils import import_or_print_error
+
 # OpenAI simple tokenizer

 @lru_cache()
@@ -156,7 +157,9 @@ class YttmTokenizer:
         bpe_path = Path(bpe_path)
         assert bpe_path.exists(), f'BPE json path {str(bpe_path)} does not exist'

-        tokenizer = yttm.BPE(model = str(bpe_path))
+        self.yttm = import_or_print_error('youtokentome', 'you need to install youtokentome by `pip install youtokentome`')
+
+        tokenizer = self.yttm.BPE(model = str(bpe_path))
         self.tokenizer = tokenizer
         self.vocab_size = tokenizer.vocab_size()

@@ -167,7 +170,7 @@ class YttmTokenizer:
         return self.tokenizer.decode(tokens, ignore_ids = pad_tokens.union({0}))

     def encode(self, texts):
-        encoded = self.tokenizer.encode(texts, output_type = yttm.OutputType.ID)
+        encoded = self.tokenizer.encode(texts, output_type = self.yttm.OutputType.ID)
        return list(map(torch.tensor, encoded))

     def tokenize(self, texts, context_length = 256, truncate_text = False):
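The tokenizer change defers the `youtokentome` import to `YttmTokenizer`'s constructor via `import_or_print_error`, so the package is only required when the BPE tokenizer is actually used. The helper appears to match the definition previously inlined in the trackers file (now imported from `dalle2_pytorch.utils`); a runnable copy of that lazy-import pattern:

```python
import importlib

def import_or_print_error(pkg_name, err_str=None):
    # defer a heavy/optional dependency until first use
    try:
        return importlib.import_module(pkg_name)
    except ModuleNotFoundError:
        if err_str is not None:
            print(err_str)  # tell the user which `pip install` they need
        exit()

# resolving a stdlib module works exactly like a normal import
json_mod = import_or_print_error('json')
```

The trade-off is that a missing package surfaces at construction time rather than at import time, which is why the class also caches the module on `self.yttm` for later calls like `self.yttm.OutputType.ID`.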
@@ -1,10 +1,15 @@
+import urllib.request
 import os
 from pathlib import Path
-import importlib
+import shutil
 from itertools import zip_longest
+from typing import Optional, List, Union
+from pydantic import BaseModel

 import torch
-from torch import nn
+from dalle2_pytorch.utils import import_or_print_error
+from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer

 # constants

@@ -15,101 +20,494 @@ DEFAULT_DATA_PATH = './.tracker-data'
 def exists(val):
     return val is not None

-def import_or_print_error(pkg_name, err_str = None):
-    try:
-        return importlib.import_module(pkg_name)
-    except ModuleNotFoundError as e:
-        if exists(err_str):
-            print(err_str)
-        exit()
-
-# load state dict functions
-
-def load_wandb_state_dict(run_path, file_path, **kwargs):
+# load file functions
+
+def load_wandb_file(run_path, file_path, **kwargs):
     wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function')
     file_reference = wandb.restore(file_path, run_path=run_path)
-    return torch.load(file_reference.name)
+    return file_reference.name

-def load_local_state_dict(file_path, **kwargs):
-    return torch.load(file_path)
+def load_local_file(file_path, **kwargs):
+    return file_path

-# base class
-
-class BaseTracker(nn.Module):
-    def __init__(self, data_path = DEFAULT_DATA_PATH):
-        super().__init__()
+class BaseLogger:
+    """
+    An abstract class representing an object that can log data.
+    Parameters:
+        data_path (str): A file path for storing temporary data.
+        verbose (bool): Whether or not to always print logs to the console.
+    """
+    def __init__(self, data_path: str, verbose: bool = False, **kwargs):
         self.data_path = Path(data_path)
-        self.data_path.mkdir(parents = True, exist_ok = True)
+        self.verbose = verbose

-    def init(self, config, **kwargs):
-        raise NotImplementedError
-
-    def log(self, log, **kwargs):
-        raise NotImplementedError
-
-    def log_images(self, images, **kwargs):
-        raise NotImplementedError
-
-    def save_state_dict(self, state_dict, relative_path, **kwargs):
-        raise NotImplementedError
-
-    def recall_state_dict(self, recall_source, *args, **kwargs):
+    def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
         """
-        Loads a state dict from any source.
-        Since a user may wish to load a model from a different source than their own tracker (i.e. tracking using wandb but recalling from disk),
-        this should not be linked to any individual tracker.
+        Initializes the logger.
+        Errors if the logger is invalid.
         """
-        # TODO: Pull this into a dict or something similar so that we can add more sources without having a massive switch statement
-        if recall_source == 'wandb':
-            return load_wandb_state_dict(*args, **kwargs)
-        elif recall_source == 'local':
-            return load_local_state_dict(*args, **kwargs)
-        else:
-            raise ValueError('`recall_source` must be one of `wandb` or `local`')
+        raise NotImplementedError

-# basic stdout class
-
-class ConsoleTracker(BaseTracker):
-    def init(self, **config):
-        print(config)
+    def log(self, log, **kwargs) -> None:
+        raise NotImplementedError

-    def log(self, log, **kwargs):
+    def log_images(self, images, captions=[], image_section="images", **kwargs) -> None:
+        raise NotImplementedError
+
+    def log_file(self, file_path, **kwargs) -> None:
+        raise NotImplementedError
+
+    def log_error(self, error_string, **kwargs) -> None:
+        raise NotImplementedError
+
+class ConsoleLogger(BaseLogger):
+    def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
+        print("Logging to console")
+
+    def log(self, log, **kwargs) -> None:
         print(log)

-    def log_images(self, images, **kwargs): # noop for logging images
+    def log_images(self, images, captions=[], image_section="images", **kwargs) -> None:
         pass

-    def save_state_dict(self, state_dict, relative_path, **kwargs):
-        torch.save(state_dict, str(self.data_path / relative_path))
+    def log_file(self, file_path, **kwargs) -> None:
+        pass

-# basic wandb class
-
-class WandbTracker(BaseTracker):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb experiment tracker')
+    def log_error(self, error_string, **kwargs) -> None:
+        print(error_string)
+
+class WandbLogger(BaseLogger):
+    """
+    Logs to a wandb run.
+    Parameters:
+        data_path (str): A file path for storing temporary data.
+        wandb_entity (str): The wandb entity to log to.
+        wandb_project (str): The wandb project to log to.
+        wandb_run_id (str): The wandb run id to resume.
+        wandb_run_name (str): The wandb run name to use.
+        wandb_resume (bool): Whether to resume a wandb run.
+    """
+    def __init__(self,
+        data_path: str,
+        wandb_entity: str,
+        wandb_project: str,
+        wandb_run_id: Optional[str] = None,
+        wandb_run_name: Optional[str] = None,
+        wandb_resume: bool = False,
+        **kwargs
+    ):
+        super().__init__(data_path, **kwargs)
+        self.entity = wandb_entity
+        self.project = wandb_project
+        self.run_id = wandb_run_id
+        self.run_name = wandb_run_name
+        self.resume = wandb_resume
+
+    def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
+        assert self.entity is not None, "wandb_entity must be specified for wandb logger"
+        assert self.project is not None, "wandb_project must be specified for wandb logger"
+        self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb logger')
         os.environ["WANDB_SILENT"] = "true"
+        # Initializes the wandb run
+        init_object = {
+            "entity": self.entity,
+            "project": self.project,
+            "config": {**full_config.dict(), **extra_config}
+        }
+        if self.run_name is not None:
+            init_object['name'] = self.run_name
+        if self.resume:
+            assert self.run_id is not None, '`wandb_run_id` must be provided if `wandb_resume` is True'
+            if self.run_name is not None:
+                print("You are renaming a run. I hope that is what you intended.")
+            init_object['resume'] = 'must'
+            init_object['id'] = self.run_id

-    def init(self, **config):
-        self.wandb.init(**config)
+        self.wandb.init(**init_object)
+        print(f"Logging to wandb run {self.wandb.run.path}-{self.wandb.run.name}")

-    def log(self, log, verbose=False, **kwargs):
-        if verbose:
+    def log(self, log, **kwargs) -> None:
+        if self.verbose:
             print(log)
         self.wandb.log(log, **kwargs)

-    def log_images(self, images, captions=[], image_section="images", **kwargs):
+    def log_images(self, images, captions=[], image_section="images", **kwargs) -> None:
         """
         Takes a tensor of images and a list of captions and logs them to wandb.
         """
         wandb_images = [self.wandb.Image(image, caption=caption) for image, caption in zip_longest(images, captions)]
-        self.log({ image_section: wandb_images }, **kwargs)
+        self.wandb.log({ image_section: wandb_images }, **kwargs)

-    def save_state_dict(self, state_dict, relative_path, **kwargs):
+    def log_file(self, file_path, base_path: Optional[str] = None, **kwargs) -> None:
+        if base_path is None:
+            # Then we take the basepath as the parent of the file_path
+            base_path = Path(file_path).parent
+        self.wandb.save(str(file_path), base_path = str(base_path))
+
+    def log_error(self, error_string, step=None, **kwargs) -> None:
+        if self.verbose:
+            print(error_string)
+        self.wandb.log({"error": error_string, **kwargs}, step=step)
+
+logger_type_map = {
+    'console': ConsoleLogger,
+    'wandb': WandbLogger,
+}
+def create_logger(logger_type: str, data_path: str, **kwargs) -> BaseLogger:
+    if logger_type == 'custom':
+        raise NotImplementedError('Custom loggers are not supported yet. Please use a different logger type.')
+    try:
+        logger_class = logger_type_map[logger_type]
+    except KeyError:
+        raise ValueError(f'Unknown logger type: {logger_type}. Must be one of {list(logger_type_map.keys())}')
+    return logger_class(data_path, **kwargs)

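`create_logger`, and the parallel `create_loader`/`create_saver` factories further on, all share one registry shape: a dict mapping a type name to a class, with the `KeyError` translated into a helpful `ValueError`. A self-contained distillation of that pattern, using a stub `ConsoleLogger` in place of the real trackers classes:

```python
# Stub stand-in for the real ConsoleLogger, kept minimal for illustration.
class ConsoleLogger:
    def __init__(self, data_path, **kwargs):
        self.data_path = data_path

# registry: type name -> class
logger_type_map = {'console': ConsoleLogger}

def create_logger(logger_type: str, data_path: str, **kwargs):
    try:
        logger_class = logger_type_map[logger_type]
    except KeyError:
        # surface the valid options instead of a bare KeyError
        raise ValueError(
            f'Unknown logger type: {logger_type}. '
            f'Must be one of {list(logger_type_map.keys())}')
    return logger_class(data_path, **kwargs)

logger = create_logger('console', './.tracker-data')
```

Adding a new backend then only requires registering one entry in the dict, which is why the same shape is repeated three times in this file.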
+class BaseLoader:
+    """
+    An abstract class representing an object that can load a model checkpoint.
+    Parameters:
+        data_path (str): A file path for storing temporary data.
+    """
+    def __init__(self, data_path: str, **kwargs):
+        self.data_path = Path(data_path)
+
+    def init(self, logger: BaseLogger, **kwargs) -> None:
+        raise NotImplementedError
+
+    def recall(self) -> dict:
+        raise NotImplementedError
+
+class UrlLoader(BaseLoader):
+    """
+    A loader that downloads the file from a url and loads it
+    Parameters:
+        data_path (str): A file path for storing temporary data.
+        url (str): The url to download the file from.
+    """
+    def __init__(self, data_path: str, url: str, **kwargs):
+        super().__init__(data_path, **kwargs)
+        self.url = url
+
+    def init(self, logger: BaseLogger, **kwargs) -> None:
+        # Makes sure the file exists to be downloaded
+        pass  # TODO: Actually implement that
+
+    def recall(self) -> dict:
+        # Download the file
+        save_path = self.data_path / 'loaded_checkpoint.pth'
+        urllib.request.urlretrieve(self.url, str(save_path))
+        # Load the file
+        return torch.load(str(save_path), map_location='cpu')
+
+class LocalLoader(BaseLoader):
+    """
+    A loader that loads a file from a local path
+    Parameters:
+        data_path (str): A file path for storing temporary data.
+        file_path (str): The path to the file to load.
+    """
+    def __init__(self, data_path: str, file_path: str, **kwargs):
+        super().__init__(data_path, **kwargs)
+        self.file_path = Path(file_path)
+
+    def init(self, logger: BaseLogger, **kwargs) -> None:
+        # Makes sure the file exists to be loaded
+        if not self.file_path.exists():
+            raise FileNotFoundError(f'Model not found at {self.file_path}')
+
+    def recall(self) -> dict:
+        # Load the file
+        return torch.load(str(self.file_path), map_location='cpu')
+
+class WandbLoader(BaseLoader):
+    """
+    A loader that loads a model from an existing wandb run
+    """
+    def __init__(self, data_path: str, wandb_file_path: str, wandb_run_path: Optional[str] = None, **kwargs):
+        super().__init__(data_path, **kwargs)
+        self.run_path = wandb_run_path
+        self.file_path = wandb_file_path
+
+    def init(self, logger: BaseLogger, **kwargs) -> None:
+        self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function')
+        # Make sure the file can be downloaded
+        if self.wandb.run is not None and self.run_path is None:
+            self.run_path = self.wandb.run.path
+            assert self.run_path is not None, 'wandb run was not found to load from. If not using the wandb logger must specify the `wandb_run_path`.'
+        assert self.run_path is not None, '`wandb_run_path` must be provided for the wandb loader'
+        assert self.file_path is not None, '`wandb_file_path` must be provided for the wandb loader'
+        os.environ["WANDB_SILENT"] = "true"
+        pass  # TODO: Actually implement that
+
+    def recall(self) -> dict:
+        file_reference = self.wandb.restore(self.file_path, run_path=self.run_path)
+        return torch.load(file_reference.name, map_location='cpu')
+
+loader_type_map = {
+    'url': UrlLoader,
+    'local': LocalLoader,
+    'wandb': WandbLoader,
+}
+def create_loader(loader_type: str, data_path: str, **kwargs) -> BaseLoader:
+    if loader_type == 'custom':
+        raise NotImplementedError('Custom loaders are not supported yet. Please use a different loader type.')
+    try:
+        loader_class = loader_type_map[loader_type]
+    except KeyError:
+        raise ValueError(f'Unknown loader type: {loader_type}. Must be one of {list(loader_type_map.keys())}')
+    return loader_class(data_path, **kwargs)

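The loaders above split their work into an early `init` (verify the checkpoint is reachable, before any training begins) and a later `recall` (actually load it). The same two-phase contract can be shown without torch by letting `json` stand in for `torch.load`; `LocalJsonLoader` is a hypothetical name for this sketch, not a class from the repository:

```python
import json
from pathlib import Path

class LocalJsonLoader:
    def __init__(self, file_path: str):
        self.file_path = Path(file_path)

    def init(self):
        # fail fast: surface a missing checkpoint before any work is done
        if not self.file_path.exists():
            raise FileNotFoundError(f'Model not found at {self.file_path}')

    def recall(self) -> dict:
        # deferred, potentially expensive load (torch.load in the real code)
        return json.loads(self.file_path.read_text())
```

Validating in `init` rather than in `recall` is the point of the split: a bad `--load` path aborts the run immediately instead of hours in, when the checkpoint is first needed.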
+class BaseSaver:
+    def __init__(self,
+        data_path: str,
+        save_latest_to: Optional[Union[str, bool]] = 'latest.pth',
+        save_best_to: Optional[Union[str, bool]] = 'best.pth',
+        save_meta_to: str = './',
+        save_type: str = 'checkpoint',
+        **kwargs
+    ):
+        self.data_path = Path(data_path)
+        self.save_latest_to = save_latest_to
+        self.saving_latest = save_latest_to is not None and save_latest_to is not False
+        self.save_best_to = save_best_to
+        self.saving_best = save_best_to is not None and save_best_to is not False
+        self.save_meta_to = save_meta_to
+        self.save_type = save_type
+        assert save_type in ['checkpoint', 'model'], '`save_type` must be one of `checkpoint` or `model`'
+        assert self.save_meta_to is not None, '`save_meta_to` must be provided'
+        assert self.saving_latest or self.saving_best, '`save_latest_to` or `save_best_to` must be provided'
+
+    def init(self, logger: BaseLogger, **kwargs) -> None:
+        raise NotImplementedError
+
+    def save_file(self, local_path: Path, save_path: str, is_best=False, is_latest=False, **kwargs) -> None:
         """
-        Saves a state_dict to disk and uploads it
+        Save a general file under save_meta_to
         """
-        full_path = str(self.data_path / relative_path)
-        torch.save(state_dict, full_path)
-        self.wandb.save(full_path, base_path = str(self.data_path)) # Upload and keep relative to data_path
+        raise NotImplementedError
+
+class LocalSaver(BaseSaver):
+    def __init__(self,
+        data_path: str,
+        **kwargs
+    ):
+        super().__init__(data_path, **kwargs)
+
+    def init(self, logger: BaseLogger, **kwargs) -> None:
+        # Makes sure the directory exists to be saved to
+        print(f"Saving {self.save_type} locally")
+        if not self.data_path.exists():
+            self.data_path.mkdir(parents=True)
+
+    def save_file(self, local_path: str, save_path: str, **kwargs) -> None:
+        # Copy the file to save_path
+        save_path_file_name = Path(save_path).name
+        print(f"Saving {save_path_file_name} {self.save_type} to local path {save_path}")
+        shutil.copy(local_path, save_path)
+
+class WandbSaver(BaseSaver):
+    def __init__(self, data_path: str, wandb_run_path: Optional[str] = None, **kwargs):
+        super().__init__(data_path, **kwargs)
+        self.run_path = wandb_run_path
+
+    def init(self, logger: BaseLogger, **kwargs) -> None:
+        self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb logger')
+        os.environ["WANDB_SILENT"] = "true"
+        # Makes sure that the user can upload to this run
+        if self.run_path is not None:
+            entity, project, run_id = self.run_path.split("/")
+            self.run = self.wandb.init(entity=entity, project=project, id=run_id)
+        else:
+            assert self.wandb.run is not None, 'You must be using the wandb logger if you are saving to wandb and have not set `wandb_run_path`'
+            self.run = self.wandb.run
+        # TODO: Now actually check if upload is possible
+        print(f"Saving to wandb run {self.run.path}-{self.run.name}")
+
+    def save_file(self, local_path: Path, save_path: str, **kwargs) -> None:
+        # In order to log something in the correct place in wandb, we need to have the same file structure here
+        save_path_file_name = Path(save_path).name
+        print(f"Saving {save_path_file_name} {self.save_type} to wandb run {self.run.path}-{self.run.name}")
+        save_path = Path(self.data_path) / save_path
+        save_path.parent.mkdir(parents=True, exist_ok=True)
+        shutil.copy(local_path, save_path)
+        self.run.save(str(save_path), base_path = str(self.data_path), policy='now')
+
+class HuggingfaceSaver(BaseSaver):
+    def __init__(self, data_path: str, huggingface_repo: str, token_path: Optional[str] = None, **kwargs):
+        super().__init__(data_path, **kwargs)
+        self.huggingface_repo = huggingface_repo
+        self.token_path = token_path
+
+    def init(self, logger: BaseLogger, **kwargs):
+        # Makes sure this user can upload to the repo
+        self.hub = import_or_print_error('huggingface_hub', '`pip install huggingface_hub` to use the huggingface saver')
+        try:
+            identity = self.hub.whoami()  # Errors if not logged in
+            # Then we are logged in
+        except:
+            # We are not logged in. Use the token_path to set the token.
+            if not os.path.exists(self.token_path):
+                raise Exception("Not logged in to huggingface and no token_path specified. Please login with `huggingface-cli login` or if that does not work set the token_path.")
+            with open(self.token_path, "r") as f:
+                token = f.read().strip()
+            self.hub.HfApi.set_access_token(token)
+            identity = self.hub.whoami()
+        print(f"Saving to huggingface repo {self.huggingface_repo}")
+
+    def save_file(self, local_path: Path, save_path: str, **kwargs) -> None:
+        # Saving to huggingface is easy, we just need to upload the file with the correct name
+        save_path_file_name = Path(save_path).name
+        print(f"Saving {save_path_file_name} {self.save_type} to huggingface repo {self.huggingface_repo}")
+        self.hub.upload_file(
+            path_or_fileobj=str(local_path),
+            path_in_repo=str(save_path),
+            repo_id=self.huggingface_repo
+        )
+
+saver_type_map = {
+    'local': LocalSaver,
+    'wandb': WandbSaver,
+    'huggingface': HuggingfaceSaver
+}
+def create_saver(saver_type: str, data_path: str, **kwargs) -> BaseSaver:
+    if saver_type == 'custom':
+        raise NotImplementedError('Custom savers are not supported yet. Please use a different saver type.')
+    try:
+        saver_class = saver_type_map[saver_type]
+    except KeyError:
+        raise ValueError(f'Unknown saver type: {saver_type}. Must be one of {list(saver_type_map.keys())}')
+    return saver_class(data_path, **kwargs)

+class Tracker:
+    def __init__(self, data_path: Optional[str] = DEFAULT_DATA_PATH, overwrite_data_path: bool = False, dummy_mode: bool = False):
+        self.data_path = Path(data_path)
+        if not dummy_mode:
+            if overwrite_data_path:
+                if self.data_path.exists():
+                    shutil.rmtree(self.data_path)
+                self.data_path.mkdir(parents=True)
+            else:
+                assert not self.data_path.exists(), f'Data path {self.data_path} already exists. Set overwrite_data_path to True to overwrite.'
+                if not self.data_path.exists():
+                    self.data_path.mkdir(parents=True)
+        self.logger: BaseLogger = None
+        self.loader: Optional[BaseLoader] = None
+        self.savers: List[BaseSaver] = []
+        self.dummy_mode = dummy_mode
+
+    def init(self, full_config: BaseModel, extra_config: dict):
+        assert self.logger is not None, '`logger` must be set before `init` is called'
+        if self.dummy_mode:
+            # The only thing we need is a loader
+            if self.loader is not None:
+                self.loader.init(self.logger)
+            return
+        assert len(self.savers) > 0, '`savers` must be set before `init` is called'
+        self.logger.init(full_config, extra_config)
+        if self.loader is not None:
+            self.loader.init(self.logger)
+        for saver in self.savers:
+            saver.init(self.logger)
+
+    def add_logger(self, logger: BaseLogger):
+        self.logger = logger
+
+    def add_loader(self, loader: BaseLoader):
+        self.loader = loader
+
+    def add_saver(self, saver: BaseSaver):
+        self.savers.append(saver)
+
+    def log(self, *args, **kwargs):
+        if self.dummy_mode:
+            return
+        self.logger.log(*args, **kwargs)
+
+    def log_images(self, *args, **kwargs):
+        if self.dummy_mode:
+            return
+        self.logger.log_images(*args, **kwargs)
+
+    def log_file(self, *args, **kwargs):
+        if self.dummy_mode:
+            return
+        self.logger.log_file(*args, **kwargs)
+
+    def save_config(self, current_config_path: str, config_name = 'config.json'):
+        if self.dummy_mode:
+            return
+        # Save the config under config_name in the root folder of data_path
+        shutil.copy(current_config_path, self.data_path / config_name)
+        for saver in self.savers:
+            remote_path = Path(saver.save_meta_to) / config_name
+            saver.save_file(current_config_path, str(remote_path))
+
+    def _save_state_dict(self, trainer: Union[DiffusionPriorTrainer, DecoderTrainer], save_type: str, file_path: str, **kwargs) -> Path:
+        """
+        Gets the state dict to be saved and writes it to file_path.
+        If save_type is 'checkpoint', we save the entire trainer state dict.
+        If save_type is 'model', we save only the model state dict.
+        """
+        assert save_type in ['checkpoint', 'model']
+        if save_type == 'checkpoint':
+            trainer.save(file_path, overwrite=True, **kwargs)
+        elif save_type == 'model':
+            if isinstance(trainer, DiffusionPriorTrainer):
+                prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior
+                state_dict = trainer.unwrap_model(prior).state_dict()
+                torch.save(state_dict, file_path)
+            elif isinstance(trainer, DecoderTrainer):
+                decoder = trainer.accelerator.unwrap_model(trainer.decoder)
+                if trainer.use_ema:
+                    trainable_unets = decoder.unets
+                    decoder.unets = trainer.unets  # Swap EMA unets in
+                    state_dict = decoder.state_dict()
+                    decoder.unets = trainable_unets  # Swap back
+                else:
+                    state_dict = decoder.state_dict()
+                torch.save(state_dict, file_path)
+            else:
+                raise NotImplementedError('Saving this type of model with EMA mode enabled is not yet implemented. Actually, how did you get here?')
+        return Path(file_path)
+
+    def save(self, trainer, is_best: bool, is_latest: bool, **kwargs):
+        if self.dummy_mode:
+            return
+        if not is_best and not is_latest:
+            # Nothing to do
+            return
+        # Save the checkpoint and model to data_path
+        checkpoint_path = self.data_path / 'checkpoint.pth'
+        self._save_state_dict(trainer, 'checkpoint', checkpoint_path, **kwargs)
+        model_path = self.data_path / 'model.pth'
+        self._save_state_dict(trainer, 'model', model_path, **kwargs)
+        print("Saved cached models")
+        # Call the save methods on the savers
+        for saver in self.savers:
+            local_path = checkpoint_path if saver.save_type == 'checkpoint' else model_path
+            if saver.saving_latest and is_latest:
+                latest_checkpoint_path = saver.save_latest_to.format(**kwargs)
+                try:
+                    saver.save_file(local_path, latest_checkpoint_path, is_latest=True, **kwargs)
+                except Exception as e:
+                    self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs)
+                    print(f'Error saving checkpoint: {e}')
+            if saver.saving_best and is_best:
+                best_checkpoint_path = saver.save_best_to.format(**kwargs)
+                try:
+                    saver.save_file(local_path, best_checkpoint_path, is_best=True, **kwargs)
+                except Exception as e:
+                    self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs)
+                    print(f'Error saving checkpoint: {e}')
+
+    def recall(self):
+        if self.loader is not None:
+            return self.loader.recall()
+        else:
+            raise ValueError('No loader specified')

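The new `Tracker` is assembled in three steps (`add_logger`/`add_loader`/`add_saver`, then `init`), and every public method short-circuits when `dummy_mode` is set, presumably so non-main distributed processes can run the same training code without touching disk, wandb, or huggingface. A minimal runnable distillation of that gating pattern (`MiniTracker` is illustrative only, not a class from the repository):

```python
class MiniTracker:
    def __init__(self, dummy_mode: bool = False):
        self.dummy_mode = dummy_mode
        self.logger = None
        self.logged = []  # records what would have been logged

    def add_logger(self, logger):
        self.logger = logger

    def log(self, entry):
        if self.dummy_mode:
            return  # dummy trackers skip all side effects
        self.logged.append(entry)

real = MiniTracker(dummy_mode=False)
real.log({'loss': 0.1})

dummy = MiniTracker(dummy_mode=True)
dummy.log({'loss': 0.1})  # silently dropped
```

Pushing the `dummy_mode` check into the facade keeps the logger, loader, and saver implementations free of any rank-awareness.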
@@ -3,18 +3,220 @@ from torchvision import transforms as T
 from pydantic import BaseModel, validator, root_validator
 from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
+
+from x_clip import CLIP as XCLIP
+from coca_pytorch import CoCa
+
+from dalle2_pytorch.dalle2_pytorch import (
+    CoCaAdapter,
+    OpenAIClipAdapter,
+    Unet,
+    Decoder,
+    DiffusionPrior,
+    DiffusionPriorNetwork,
+    XClipAdapter
+)
+from dalle2_pytorch.trackers import Tracker, create_loader, create_logger, create_saver
+
+# helper functions
+
 def exists(val):
     return val is not None

 def default(val, d):
     return val if exists(val) else d
+
+def ListOrTuple(inner_type):
+    return Union[List[inner_type], Tuple[inner_type]]
+
+def SingularOrIterable(inner_type):
+    return Union[inner_type, ListOrTuple(inner_type)]
+
+# general pydantic classes
+
+class TrainSplitConfig(BaseModel):
+    train: float = 0.75
+    val: float = 0.15
+    test: float = 0.1
+
+    @root_validator
+    def validate_all(cls, fields):
+        actual_sum = sum([*fields.values()])
+        if actual_sum != 1.:
+            raise ValueError(f'{fields.keys()} must sum to 1.0. Found: {actual_sum}')
+        return fields
+
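The `root_validator` on `TrainSplitConfig` enforces that the dataset split fractions sum to one. A standalone sketch of that check in plain Python (no pydantic; the function name is illustrative):

```python
def validate_split(train: float, val: float, test: float) -> dict:
    # Mirrors TrainSplitConfig.validate_all: the three fractions must sum to 1.0
    fields = {'train': train, 'val': val, 'test': test}
    actual_sum = sum(fields.values())
    if actual_sum != 1.:
        raise ValueError(f'{list(fields)} must sum to 1.0. Found: {actual_sum}')
    return fields

splits = validate_split(0.75, 0.15, 0.1)   # the config's defaults pass
```

Note the comparison is exact float equality, so splits whose floating-point sum drifts off 1.0 would be rejected even if mathematically valid.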
+class TrackerLogConfig(BaseModel):
+    log_type: str = 'console'
+    verbose: bool = False
+
+    class Config:
+        # Each individual log type has its own arguments that will be passed through the config
+        extra = "allow"
+
+    def create(self, data_path: str):
+        kwargs = self.dict()
+        return create_logger(self.log_type, data_path, **kwargs)
+
+class TrackerLoadConfig(BaseModel):
+    load_from: Optional[str] = None
+
+    class Config:
+        extra = "allow"
+
+    def create(self, data_path: str):
+        kwargs = self.dict()
+        if self.load_from is None:
+            return None
+        return create_loader(self.load_from, data_path, **kwargs)
+
+class TrackerSaveConfig(BaseModel):
+    save_to: str = 'local'
+    save_all: bool = False
+    save_latest: bool = True
+    save_best: bool = True
+
+    class Config:
+        extra = "allow"
+
+    def create(self, data_path: str):
+        kwargs = self.dict()
+        return create_saver(self.save_to, data_path, **kwargs)
+
+class TrackerConfig(BaseModel):
+    data_path: str = '.tracker_data'
+    overwrite_data_path: bool = False
+    log: TrackerLogConfig
+    load: Optional[TrackerLoadConfig]
+    save: Union[List[TrackerSaveConfig], TrackerSaveConfig]
+
+    def create(self, full_config: BaseModel, extra_config: dict, dummy_mode: bool = False) -> Tracker:
+        tracker = Tracker(self.data_path, dummy_mode=dummy_mode, overwrite_data_path=self.overwrite_data_path)
+        # Add the logger
+        tracker.add_logger(self.log.create(self.data_path))
+        # Add the loader
+        if self.load is not None:
+            tracker.add_loader(self.load.create(self.data_path))
+        # Add the saver or savers
+        if isinstance(self.save, list):
+            for save_config in self.save:
+                tracker.add_saver(save_config.create(self.data_path))
+        else:
+            tracker.add_saver(self.save.create(self.data_path))
+        # Initialize all the components and verify that all data is valid
+        tracker.init(full_config, extra_config)
+        return tracker
+
+# diffusion prior pydantic classes
+
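Each tracker sub-config resolves a string key (`log_type`, `load_from`, `save_to`) to a concrete component through a `create_*` factory. The registry and `ConsoleLogger` below are illustrative stand-ins, not the library's actual implementation; a minimal sketch of the string-keyed dispatch pattern:

```python
class ConsoleLogger:
    # Hypothetical stand-in for a concrete logger component
    def __init__(self, data_path, verbose=False, **kwargs):
        self.data_path = data_path
        self.verbose = verbose

    def log(self, msg):
        if self.verbose:
            print(msg)

LOGGER_REGISTRY = {'console': ConsoleLogger}

def create_logger(log_type, data_path, **kwargs):
    # Dispatch on the string key, forwarding the remaining config as kwargs
    if log_type not in LOGGER_REGISTRY:
        raise ValueError(f'Unknown logger type: {log_type}')
    return LOGGER_REGISTRY[log_type](data_path, **kwargs)

logger = create_logger('console', '.tracker_data', verbose=True)
```

Because the pydantic configs set `extra = "allow"`, any additional fields flow through `**kwargs` into the chosen component's constructor.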
+class AdapterConfig(BaseModel):
+    make: str = "openai"
+    model: str = "ViT-L/14"
+    base_model_kwargs: Dict[str, Any] = None
+
+    def create(self):
+        if self.make == "openai":
+            return OpenAIClipAdapter(self.model)
+        elif self.make == "x-clip":
+            return XClipAdapter(XCLIP(**self.base_model_kwargs))
+        elif self.make == "coca":
+            return CoCaAdapter(CoCa(**self.base_model_kwargs))
+        else:
+            raise AttributeError("No adapter with that name is available.")
+
+class DiffusionPriorNetworkConfig(BaseModel):
+    dim: int
+    depth: int
+    num_timesteps: int = None
+    num_time_embeds: int = 1
+    num_image_embeds: int = 1
+    num_text_embeds: int = 1
+    dim_head: int = 64
+    heads: int = 8
+    ff_mult: int = 4
+    norm_out: bool = True
+    attn_dropout: float = 0.
+    ff_dropout: float = 0.
+    final_proj: bool = True
+    normformer: bool = False
+    rotary_emb: bool = True
+
+    def create(self):
+        kwargs = self.dict()
+        return DiffusionPriorNetwork(**kwargs)
+
+class DiffusionPriorConfig(BaseModel):
+    clip: AdapterConfig = None
+    net: DiffusionPriorNetworkConfig
+    image_embed_dim: int
+    image_size: int
+    image_channels: int = 3
+    timesteps: int = 1000
+    cond_drop_prob: float = 0.
+    loss_type: str = 'l2'
+    predict_x_start: bool = True
+    beta_schedule: str = 'cosine'
+    condition_on_text_encodings: bool = True
+
+    class Config:
+        extra = "allow"
+
+    def create(self):
+        kwargs = self.dict()
+
+        has_clip = exists(kwargs.pop('clip'))
+        kwargs.pop('net')
+
+        clip = None
+        if has_clip:
+            clip = self.clip.create()
+
+        diffusion_prior_network = self.net.create()
+        return DiffusionPrior(net = diffusion_prior_network, clip = clip, **kwargs)
+
+class DiffusionPriorTrainConfig(BaseModel):
+    epochs: int = 1
+    lr: float = 1.1e-4
+    wd: float = 6.02e-2
+    max_grad_norm: float = 0.5
+    use_ema: bool = True
+    ema_beta: float = 0.99
+    amp: bool = False
+    save_every: int = 10000  # what steps to save on
+
+class DiffusionPriorDataConfig(BaseModel):
+    image_url: str  # path to embeddings folder
+    meta_url: str   # path to metadata (captions) for images
+    splits: TrainSplitConfig
+    batch_size: int = 64
+
+class DiffusionPriorLoadConfig(BaseModel):
+    source: str = None
+    resume: bool = False
+
+class TrainDiffusionPriorConfig(BaseModel):
+    prior: DiffusionPriorConfig
+    data: DiffusionPriorDataConfig
+    train: DiffusionPriorTrainConfig
+    load: DiffusionPriorLoadConfig
+    tracker: TrackerConfig
+
+    @classmethod
+    def from_json_path(cls, json_path):
+        with open(json_path) as f:
+            config = json.load(f)
+        return cls(**config)
+
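`from_json_path` reads a JSON file and splats the resulting dict straight into the config class constructor, so the JSON structure must mirror the nested config fields. A self-contained sketch of the same pattern, using a plain class (`SimpleConfig` and its fields are hypothetical) and a temporary file:

```python
import json
import tempfile

class SimpleConfig:
    # Illustrative stand-in for a pydantic config class
    def __init__(self, epochs=1, lr=1.1e-4):
        self.epochs = epochs
        self.lr = lr

    @classmethod
    def from_json_path(cls, json_path):
        # Same pattern as TrainDiffusionPriorConfig.from_json_path:
        # open the file, parse the JSON, splat the dict into the constructor
        with open(json_path) as f:
            config = json.load(f)
        return cls(**config)

with tempfile.NamedTemporaryFile('w', suffix='.json', delete=False) as f:
    json.dump({'epochs': 5, 'lr': 3e-4}, f)
    path = f.name

cfg = SimpleConfig.from_json_path(path)
```

With pydantic, unknown top-level keys in the JSON would raise a validation error unless `extra = "allow"` is set on the class.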
+# decoder pydantic classes
+
 class UnetConfig(BaseModel):
     dim: int
-    dim_mults: List[int]
+    dim_mults: ListOrTuple(int)
     image_embed_dim: int = None
+    text_embed_dim: int = None
+    cond_on_text_encodings: bool = None
     cond_dim: int = None
     channels: int = 3
+    self_attn: ListOrTuple(int)
     attn_dim_head: int = 32
     attn_heads: int = 16

@@ -22,13 +224,30 @@ class UnetConfig(BaseModel):
         extra = "allow"

 class DecoderConfig(BaseModel):
+    unets: ListOrTuple(UnetConfig)
     image_size: int = None
-    image_sizes: Union[List[int], Tuple[int]] = None
+    image_sizes: ListOrTuple(int) = None
+    clip: Optional[AdapterConfig]  # The clip model to use if embeddings are not provided
     channels: int = 3
     timesteps: int = 1000
     loss_type: str = 'l2'
-    beta_schedule: str = 'cosine'
+    beta_schedule: ListOrTuple(str) = 'cosine'
     learned_variance: bool = True
+    image_cond_drop_prob: float = 0.1
+    text_cond_drop_prob: float = 0.5
+
+    def create(self):
+        decoder_kwargs = self.dict()
+
+        unet_configs = decoder_kwargs.pop('unets')
+        unets = [Unet(**config) for config in unet_configs]
+
+        has_clip = exists(decoder_kwargs.pop('clip'))
+        clip = None
+        if has_clip:
+            clip = self.clip.create()
+
+        return Decoder(unets, clip=clip, **decoder_kwargs)

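`DecoderConfig.create` (like `DiffusionPriorConfig.create`) follows a pop-then-construct pattern: dump the config to a dict, pop out the fields that need special handling, and splat the remainder into the model constructor. A plain-dict sketch of that pattern; `FakeDecoder` and `create_decoder` are hypothetical stand-ins:

```python
class FakeDecoder:
    # Hypothetical stand-in for the real Decoder model
    def __init__(self, unets, clip=None, **kwargs):
        self.unets = unets
        self.clip = clip
        self.kwargs = kwargs

def create_decoder(config: dict):
    decoder_kwargs = dict(config)               # work on a copy, like self.dict()
    unet_configs = decoder_kwargs.pop('unets')  # handled specially: built one by one
    clip = decoder_kwargs.pop('clip')           # handled specially: optional sub-model
    unets = [dict(c) for c in unet_configs]     # stand-in for Unet(**c)
    return FakeDecoder(unets, clip=clip, **decoder_kwargs)  # rest splatted through

decoder = create_decoder({'unets': [{'dim': 128}], 'clip': None, 'timesteps': 1000})
```

Popping before splatting matters: leaving `unets` or `clip` in the dict would pass raw config objects to the constructor and shadow the explicitly built ones.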
     @validator('image_sizes')
     def check_image_sizes(cls, image_sizes, values):

@@ -39,20 +258,10 @@ class DecoderConfig(BaseModel):
     class Config:
         extra = "allow"

-class TrainSplitConfig(BaseModel):
-    train: float = 0.75
-    val: float = 0.15
-    test: float = 0.1
-
-    @root_validator
-    def validate_all(cls, fields):
-        if sum([*fields.values()]) != 1.:
-            raise ValueError(f'{fields.keys()} must sum to 1.0')
-        return fields
-
 class DecoderDataConfig(BaseModel):
     webdataset_base_url: str            # path to a webdataset with jpg images
-    embeddings_url: str                 # path to .npy files with embeddings
+    img_embeddings_url: Optional[str]   # path to .npy files with embeddings
+    text_embeddings_url: Optional[str]  # path to .npy files with embeddings
     num_workers: int = 4
     batch_size: int = 64
     start_shard: int = 0

@@ -64,60 +273,6 @@ class DecoderDataConfig(BaseModel):
     resample_train: bool = False
     preprocessing: Dict[str, Any] = {'ToTensor': True}

-class DecoderTrainConfig(BaseModel):
-    epochs: int = 20
-    lr: float = 1e-4
-    wd: float = 0.01
-    max_grad_norm: float = 0.5
-    save_every_n_samples: int = 100000
-    n_sample_images: int = 6        # The number of example images to produce when sampling the train and test dataset
-    device: str = 'cuda:0'
-    epoch_samples: int = None       # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
-    validation_samples: int = None  # Same as above but for validation.
-    use_ema: bool = True
-    ema_beta: float = 0.99
-    amp: bool = False
-    save_all: bool = False          # Whether to preserve all checkpoints
-    save_latest: bool = True        # Whether to always save the latest checkpoint
-    save_best: bool = True          # Whether to save the best checkpoint
-    unet_training_mask: List[bool] = None  # If None, use all unets
-
-class DecoderEvaluateConfig(BaseModel):
-    n_evaluation_samples: int = 1000
-    FID: Dict[str, Any] = None
-    IS: Dict[str, Any] = None
-    KID: Dict[str, Any] = None
-    LPIPS: Dict[str, Any] = None
-
-class TrackerConfig(BaseModel):
-    tracker_type: str = 'console'  # Decoder currently supports console and wandb
-    data_path: str = './models'    # The path where files will be saved locally
-    init_config: Dict[str, Any] = None
-    wandb_entity: str = ''         # Only needs to be set if tracker_type is wandb
-    wandb_project: str = ''
-    verbose: bool = False          # Whether to print console logging for non-console trackers
-
-class DecoderLoadConfig(BaseModel):
-    source: str = None    # Supports file and wandb
-    run_path: str = ''    # Used only if source is wandb
-    file_path: str = ''   # The local filepath if source is file. If source is wandb, the relative path to the model file in wandb.
-    resume: bool = False  # If using wandb, whether to resume the run
-
-class TrainDecoderConfig(BaseModel):
-    unets: List[UnetConfig]
-    decoder: DecoderConfig
-    data: DecoderDataConfig
-    train: DecoderTrainConfig
-    evaluate: DecoderEvaluateConfig
-    tracker: TrackerConfig
-    load: DecoderLoadConfig
-
-    @classmethod
-    def from_json_path(cls, json_path):
-        with open(json_path) as f:
-            config = json.load(f)
-        return cls(**config)
-
     @property
     def img_preproc(self):
         def _get_transformation(transformation_name, **kwargs):

@@ -129,7 +284,79 @@ class TrainDecoderConfig(BaseModel):
             return T.ToTensor()

         transforms = []
-        for transform_name, transform_kwargs_or_bool in self.data.preprocessing.items():
+        for transform_name, transform_kwargs_or_bool in self.preprocessing.items():
             transform_kwargs = {} if not isinstance(transform_kwargs_or_bool, dict) else transform_kwargs_or_bool
             transforms.append(_get_transformation(transform_name, **transform_kwargs))
         return T.Compose(transforms)

+class DecoderTrainConfig(BaseModel):
+    epochs: int = 20
+    lr: SingularOrIterable(float) = 1e-4
+    wd: SingularOrIterable(float) = 0.01
+    find_unused_parameters: bool = True
+    max_grad_norm: SingularOrIterable(float) = 0.5
+    save_every_n_samples: int = 100000
+    n_sample_images: int = 6        # The number of example images to produce when sampling the train and test dataset
+    device: str = 'cuda:0'
+    epoch_samples: int = None       # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
+    validation_samples: int = None  # Same as above but for validation.
+    use_ema: bool = True
+    ema_beta: float = 0.999
+    amp: bool = False
+    unet_training_mask: ListOrTuple(bool) = None  # If None, use all unets
+
+class DecoderEvaluateConfig(BaseModel):
+    n_evaluation_samples: int = 1000
+    FID: Dict[str, Any] = None
+    IS: Dict[str, Any] = None
+    KID: Dict[str, Any] = None
+    LPIPS: Dict[str, Any] = None
+
+class DecoderLoadConfig(BaseModel):
+    source: str = None    # Supports file and wandb
+    run_path: str = ''    # Used only if source is wandb
+    file_path: str = ''   # The local filepath if source is file. If source is wandb, the relative path to the model file in wandb.
+    resume: bool = False  # If using wandb, whether to resume the run
+
+class TrainDecoderConfig(BaseModel):
+    decoder: DecoderConfig
+    data: DecoderDataConfig
+    train: DecoderTrainConfig
+    evaluate: DecoderEvaluateConfig
+    tracker: TrackerConfig
+    seed: int = 0
+
+    @classmethod
+    def from_json_path(cls, json_path):
+        with open(json_path) as f:
+            config = json.load(f)
+        return cls(**config)
+
+    @root_validator
+    def check_has_embeddings(cls, values):
+        # Makes sure that enough information is provided to get the embeddings specified for training
+        data_config, decoder_config = values.get('data'), values.get('decoder')
+
+        if not exists(data_config) or not exists(decoder_config):
+            # Then something else errored and we should just pass through
+            return values
+
+        using_text_embeddings = any([unet.cond_on_text_encodings for unet in decoder_config.unets])
+        using_clip = exists(decoder_config.clip)
+        img_emb_url = data_config.img_embeddings_url
+        text_emb_url = data_config.text_embeddings_url
+
+        if using_text_embeddings:
+            # Then we need some way to get the embeddings
+            assert using_clip or exists(text_emb_url), 'If text conditioning, either clip or text_embeddings_url must be provided'
+
+        if using_clip:
+            if using_text_embeddings:
+                assert not exists(text_emb_url) or not exists(img_emb_url), 'Loaded clip, but also provided text_embeddings_url and img_embeddings_url. This is redundant. Remove the clip model or the text embeddings'
+            else:
+                assert not exists(img_emb_url), 'Loaded clip, but also provided img_embeddings_url. This is redundant. Remove the clip model or the embeddings'
+
+        if text_emb_url:
+            assert using_text_embeddings, "Text embeddings are being loaded, but text embeddings are not being conditioned on. This will slow down the dataloader for no reason."
+
+        return values
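`check_has_embeddings` cross-validates the data and decoder configs: text conditioning requires a text-embedding source (CLIP or precomputed `.npy` files), and loading CLIP alongside precomputed embeddings is flagged as redundant. A simplified standalone sketch of that decision logic over plain values (no pydantic, and the clip-plus-both-URLs redundancy branch is omitted):

```python
def check_embedding_sources(using_text_cond, has_clip, img_emb_url, text_emb_url):
    # Simplified mirror of TrainDecoderConfig.check_has_embeddings
    if using_text_cond and not (has_clip or text_emb_url is not None):
        raise ValueError('If text conditioning, either clip or text_embeddings_url must be provided')
    if has_clip and not using_text_cond and img_emb_url is not None:
        raise ValueError('Loaded clip, but also provided img_embeddings_url. This is redundant.')
    if text_emb_url is not None and not using_text_cond:
        raise ValueError('Text embeddings are being loaded, but not conditioned on.')
    return True
```

Running the full validation at config-parse time, before any model is built, catches misconfigured training runs early instead of failing deep inside the dataloader.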
@@ -1,5 +1,6 @@
 import time
 import copy
+from pathlib import Path
 from math import ceil
 from functools import partial, wraps
 from collections.abc import Iterable

@@ -10,6 +11,12 @@ from torch.cuda.amp import autocast, GradScaler

 from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
 from dalle2_pytorch.optimizer import get_optimizer
+from dalle2_pytorch.version import __version__
+from packaging import version
+
+from ema_pytorch import EMA
+
+from accelerate import Accelerator
 import numpy as np

@@ -19,7 +26,9 @@ def exists(val):
     return val is not None

 def default(val, d):
-    return val if exists(val) else d
+    if exists(val):
+        return val
+    return d() if callable(d) else d

 def cast_tuple(val, length = 1):
     return val if isinstance(val, tuple) else ((val,) * length)
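The updated `default` helper accepts a callable fallback, so a default value that is expensive (or mutable, like a fresh list) is only constructed when it is actually needed. A small self-contained sketch of the behavior:

```python
def exists(val):
    return val is not None

def default(val, d):
    # Return val when present; otherwise call d if it is callable, else return it as-is
    if exists(val):
        return val
    return d() if callable(d) else d

a = default(5, 10)       # value present, fallback ignored
b = default(None, 10)    # plain fallback
c = default(None, list)  # callable fallback, invoked lazily to get a fresh list
```

Passing `list` rather than `[]` also avoids the shared-mutable-default pitfall, since each call gets its own list.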
@@ -128,111 +137,6 @@ def split_args_and_kwargs(*args, split_size = None, **kwargs):
         chunk_size_frac = chunk_size / batch_size
         yield chunk_size_frac, (chunked_args, chunked_kwargs)

-# print helpers
-
-def print_ribbon(s, symbol = '=', repeat = 40):
-    flank = symbol * repeat
-    return f'{flank} {s} {flank}'
-
-# saving and loading functions
-
-# for diffusion prior
-
-def load_diffusion_model(dprior_path, device):
-    dprior_path = Path(dprior_path)
-    assert dprior_path.exists(), 'Dprior model file does not exist'
-    loaded_obj = torch.load(str(dprior_path), map_location='cpu')
-
-    # Get hyperparameters of loaded model
-    dpn_config = loaded_obj['hparams']['diffusion_prior_network']
-    dp_config = loaded_obj['hparams']['diffusion_prior']
-    image_embed_dim = loaded_obj['image_embed_dim']['image_embed_dim']
-
-    # Create DiffusionPriorNetwork and DiffusionPrior with loaded hyperparameters
-
-    # DiffusionPriorNetwork
-    prior_network = DiffusionPriorNetwork(dim = image_embed_dim, **dpn_config).to(device)
-
-    # DiffusionPrior with text embeddings and image embeddings pre-computed
-    diffusion_prior = DiffusionPrior(net = prior_network, **dp_config, image_embed_dim = image_embed_dim).to(device)
-
-    # Load state dict from saved model
-    diffusion_prior.load_state_dict(loaded_obj['model'])
-
-    return diffusion_prior, loaded_obj
-
-def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim):
-    # Saving State Dict
-    print_ribbon('Saving checkpoint')
-
-    state_dict = dict(
-        model = model.state_dict(),
-        optimizer = optimizer.state_dict(),
-        scaler = scaler.state_dict(),
-        hparams = config,
-        image_embed_dim = {"image_embed_dim": image_embed_dim}
-    )
-    torch.save(state_dict, save_path + '/' + str(time.time()) + '_saved_model.pth')
-
-# exponential moving average wrapper
-
-class EMA(nn.Module):
-    def __init__(
-        self,
-        model,
-        beta = 0.9999,
-        update_after_step = 1000,
-        update_every = 10,
-    ):
-        super().__init__()
-        self.beta = beta
-        self.online_model = model
-        self.ema_model = copy.deepcopy(model)
-
-        self.update_every = update_every
-        self.update_after_step = update_after_step // update_every  # only start EMA after this step number, starting at 0
-
-        self.register_buffer('initted', torch.Tensor([False]))
-        self.register_buffer('step', torch.tensor([0.]))
-
-    def restore_ema_model_device(self):
-        device = self.initted.device
-        self.ema_model.to(device)
-
-    def copy_params_from_model_to_ema(self):
-        self.ema_model.state_dict(self.online_model.state_dict())
-
-    def update(self):
-        self.step += 1
-
-        if (self.step % self.update_every) != 0:
-            return
-
-        if self.step <= self.update_after_step:
-            self.copy_params_from_model_to_ema()
-            return
-
-        if not self.initted:
-            self.copy_params_from_model_to_ema()
-            self.initted.data.copy_(torch.Tensor([True]))
-
-        self.update_moving_average(self.ema_model, self.online_model)
-
-    def update_moving_average(self, ma_model, current_model):
-        def calculate_ema(beta, old, new):
-            if not exists(old):
-                return new
-            return old * beta + (1 - beta) * new
-
-        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
-            old_weight, up_weight = ma_params.data, current_params.data
-            ma_params.data = calculate_ema(self.beta, old_weight, up_weight)
-
-        for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
-            new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer)
-            ma_buffer.copy_(new_buffer_value)
-
-    def __call__(self, *args, **kwargs):
-        return self.ema_model(*args, **kwargs)

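The in-repo EMA wrapper removed above (superseded by the `ema-pytorch` package imported in this diff) maintained shadow weights with the standard update `ema = ema * beta + (1 - beta) * online`. A scalar sketch of that rule:

```python
def ema_update(old, new, beta=0.9999):
    # Standard exponential moving average step: keep beta of the old
    # shadow value, mix in (1 - beta) of the new online value
    if old is None:
        return new
    return old * beta + (1 - beta) * new

shadow = None
for online in [1.0, 1.0, 1.0]:
    shadow = ema_update(shadow, online, beta=0.9)
```

With a constant online value the shadow converges to it; with a high `beta` like 0.9999 the shadow moves slowly, which is why the removed wrapper only started averaging after `update_after_step` warmup steps.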
 # diffusion prior trainer

 def prior_sample_in_chunks(fn):

@@ -255,44 +159,190 @@ class DiffusionPriorTrainer(nn.Module):
         eps = 1e-6,
         max_grad_norm = None,
         amp = False,
+        group_wd_params = True,
+        device = None,
+        accelerator = None,
         **kwargs
     ):
         super().__init__()
         assert isinstance(diffusion_prior, DiffusionPrior)
+        assert not exists(accelerator) or isinstance(accelerator, Accelerator)
+        assert exists(accelerator) or exists(device), "You must supply some method of obtaining a device."
         ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)

+        # assign some helpful member vars
+        self.accelerator = accelerator
+        self.device = accelerator.device if exists(accelerator) else device
+        self.text_conditioned = diffusion_prior.condition_on_text_encodings
+
+        # save model
         self.diffusion_prior = diffusion_prior

-        # exponential moving average
-        self.use_ema = use_ema
-        if self.use_ema:
-            self.ema_diffusion_prior = EMA(diffusion_prior, **ema_kwargs)
-
         # optimizer and mixed precision stuff
         self.amp = amp
         self.scaler = GradScaler(enabled = amp)

+        self.optim_kwargs = dict(lr=lr, wd=wd, eps=eps, group_wd_params=group_wd_params)
+
         self.optimizer = get_optimizer(
-            diffusion_prior.parameters(),
-            lr = lr,
-            wd = wd,
-            eps = eps,
+            self.diffusion_prior.parameters(),
+            **self.optim_kwargs,
             **kwargs
         )

+        # distribute the model if using HFA
+        if exists(self.accelerator):
+            self.diffusion_prior, self.optimizer = self.accelerator.prepare(self.diffusion_prior, self.optimizer)
+
+        # exponential moving average stuff
+        self.use_ema = use_ema
+        if self.use_ema:
+            self.ema_diffusion_prior = EMA(self.unwrap_model(self.diffusion_prior), **ema_kwargs)
+
         # gradient clipping if needed
         self.max_grad_norm = max_grad_norm

-        self.register_buffer('step', torch.tensor([0.]))
+        # track steps internally
+        self.register_buffer('step', torch.tensor([0]))
+
+    # accelerator wrappers
+
+    def print(self, msg):
+        if exists(self.accelerator):
+            self.accelerator.print(msg)
+        else:
+            print(msg)
+
+    def unwrap_model(self, model):
+        if exists(self.accelerator):
+            return self.accelerator.unwrap_model(model)
+        else:
+            return model
+
+    def wait_for_everyone(self):
+        if exists(self.accelerator):
+            self.accelerator.wait_for_everyone()
+
+    def is_main_process(self):
+        if exists(self.accelerator):
+            return self.accelerator.is_main_process
+        else:
+            return True
+
+    def clip_grad_norm_(self, *args):
+        if exists(self.accelerator):
+            return self.accelerator.clip_grad_norm_(*args)
+        else:
+            return torch.nn.utils.clip_grad_norm_(*args)
+
+    def backprop(self, x):
+        if exists(self.accelerator):
+            self.accelerator.backward(x)
+        else:
+            try:
+                x.backward()
+            except Exception as e:
+                self.print(f"Caught error in backprop call: {e}")
+
+    # utility
+
+    def save(self, path, overwrite = True, **kwargs):
+        # ensure we sync gradients before continuing
+        self.wait_for_everyone()
+
+        # only save on the main process
+        if self.is_main_process():
+            self.print(f"Saving checkpoint at step: {self.step.item()}")
+            path = Path(path)
+            assert not (path.exists() and not overwrite)
+            path.parent.mkdir(parents = True, exist_ok = True)
+
+            save_obj = dict(
+                scaler = self.scaler.state_dict(),
+                optimizer = self.optimizer.state_dict(),
+                model = self.unwrap_model(self.diffusion_prior).state_dict(),  # unwrap the model from distribution if applicable
+                version = version.parse(__version__),
+                step = self.step.item(),
+                **kwargs
+            )
+
+            if self.use_ema:
+                save_obj = {
+                    **save_obj,
+                    'ema': self.ema_diffusion_prior.state_dict(),
+                    'ema_model': self.ema_diffusion_prior.ema_model.state_dict()  # save the ema model specifically for easy ema-only reload
+                }
+
+            torch.save(save_obj, str(path))
+
+    def load(self, path, overwrite_lr = True, strict = True):
+        """
+        Load a checkpoint of a diffusion prior trainer.
+
+        Will load the entire trainer, including the optimizer and EMA.
+
+        Params:
+            - path (str): a path to the DiffusionPriorTrainer checkpoint file
+            - overwrite_lr (bool): whether or not to overwrite the stored LR with the LR specified in the new trainer
+            - strict (bool): kwarg for `torch.nn.Module.load_state_dict`, will force an exact checkpoint match
+
+        Returns:
+            loaded_obj (dict): The loaded checkpoint dictionary
+        """
+
+        # all processes need to load checkpoint. no restriction here
+        path = Path(path)
+        assert path.exists()
+
+        loaded_obj = torch.load(str(path), map_location=self.device)
+
+        if version.parse(__version__) != loaded_obj['version']:
+            print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {__version__}')
+
+        # unwrap the model when loading from checkpoint
+        self.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
+        self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
+
+        self.scaler.load_state_dict(loaded_obj['scaler'])
+        self.optimizer.load_state_dict(loaded_obj['optimizer'])
+
+        if overwrite_lr:
+            new_lr = self.optim_kwargs["lr"]
+            self.print(f"Overriding LR to be {new_lr}")
+            for group in self.optimizer.param_groups:
+                group["lr"] = new_lr
+
+        if self.use_ema:
+            assert 'ema' in loaded_obj
+            self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
+            # below may not be necessary, but I had a suspicion that this wasn't being loaded correctly
+            self.ema_diffusion_prior.ema_model.load_state_dict(loaded_obj["ema_model"])
+
+        # sync and inform
+        self.wait_for_everyone()
+        self.print("Loaded model")
+
+        return loaded_obj
|
# model functionality
|
||||||
|
|
||||||
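The `overwrite_lr` branch in the loading code above rewrites the learning rate that was restored from the checkpointed optimizer. The pattern can be sketched without torch, since an optimizer's `param_groups` is simply a list of dicts, each with its own `"lr"` entry (the helper name below is illustrative, not part of the trainer's API):

```python
# Minimal sketch of the LR-override step performed after loading a checkpoint.
# `param_groups` mimics torch.optim.Optimizer.param_groups: a list of dicts,
# each carrying the "lr" restored from the checkpoint.

def override_lr(param_groups, new_lr):
    """Force every param group to the LR configured for the resumed run."""
    for group in param_groups:
        group["lr"] = new_lr
    return param_groups

groups = [{"lr": 1e-4}, {"lr": 5e-5}]  # LRs restored from the checkpoint
override_lr(groups, 3e-4)              # LR specified for the new trainer
```

The real trainer performs exactly this loop over `self.optimizer.param_groups` after calling `load_state_dict` on the optimizer.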
     def update(self):
+        # only continue with updates until all ranks finish
+        self.wait_for_everyone()
+
         if exists(self.max_grad_norm):
             self.scaler.unscale_(self.optimizer)
-            nn.utils.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
+            # utilize HFA clipping where applicable
+            self.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)

         self.scaler.step(self.optimizer)
         self.scaler.update()
@@ -307,17 +357,26 @@ class DiffusionPriorTrainer(nn.Module):
     @cast_torch_tensor
     @prior_sample_in_chunks
     def p_sample_loop(self, *args, **kwargs):
-        return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
+        model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
+        return model.p_sample_loop(*args, **kwargs)

     @torch.no_grad()
     @cast_torch_tensor
     @prior_sample_in_chunks
     def sample(self, *args, **kwargs):
-        return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
+        model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
+        return model.sample(*args, **kwargs)

     @torch.no_grad()
     def sample_batch_size(self, *args, **kwargs):
-        return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
+        model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
+        return model.sample_batch_size(*args, **kwargs)

+    @torch.no_grad()
+    @cast_torch_tensor
+    @prior_sample_in_chunks
+    def embed_text(self, *args, **kwargs):
+        return self.unwrap_model(self.diffusion_prior).clip.embed_text(*args, **kwargs)
+
     @cast_torch_tensor
     def forward(
@@ -335,8 +394,10 @@ class DiffusionPriorTrainer(nn.Module):
             total_loss += loss.item()

+            # backprop with accelerate if applicable
             if self.training:
-                self.scaler.scale(loss).backward()
+                self.backprop(self.scaler.scale(loss))

         return total_loss
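The forward pass above splits a batch into chunks, weights each chunk's loss by `chunk_size_frac` (the fraction of the full batch the chunk covers), and accumulates. A sketch of why that weighting recovers the mean loss over the whole batch, with a hypothetical `chunked_mean_loss` helper standing in for the trainer's loop:

```python
# Sketch of the loss weighting used in forward(): each chunk's mean loss is
# scaled by the fraction of the full batch it covers, so the accumulated
# total equals the mean loss over the unchunked batch.

def chunked_mean_loss(samples, loss_fn, max_chunk):
    total, n = 0.0, len(samples)
    for i in range(0, n, max_chunk):
        chunk = samples[i:i + max_chunk]
        chunk_size_frac = len(chunk) / n  # mirrors split_args_and_kwargs
        loss = loss_fn(chunk)             # mean loss over this chunk
        total += loss * chunk_size_frac
    return total

mean = lambda xs: sum(xs) / len(xs)
result = chunked_mean_loss([1, 2, 3, 4, 5], mean, 2)  # chunks of 2, 2, 1
```

With chunks `[1,2]`, `[3,4]`, `[5]` the weighted sum is `1.5*0.4 + 3.5*0.4 + 5*0.2 = 3.0`, the same as the unchunked mean.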
@@ -362,20 +423,23 @@ class DecoderTrainer(nn.Module):
     def __init__(
         self,
         decoder,
+        accelerator = None,
         use_ema = True,
         lr = 1e-4,
         wd = 1e-2,
         eps = 1e-8,
         max_grad_norm = 0.5,
         amp = False,
+        group_wd_params = True,
         **kwargs
     ):
         super().__init__()
         assert isinstance(decoder, Decoder)
         ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)

-        self.decoder = decoder
+        self.accelerator = default(accelerator, Accelerator)

-        self.num_unets = len(self.decoder.unets)
+        self.num_unets = len(decoder.unets)

         self.use_ema = use_ema
         self.ema_unets = nn.ModuleList([])
@@ -387,56 +451,106 @@ class DecoderTrainer(nn.Module):
         lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))

-        for ind, (unet, unet_lr, unet_wd, unet_eps) in enumerate(zip(self.decoder.unets, lr, wd, eps)):
+        assert all([unet_lr < 1e-3 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
+
+        optimizers = []
+
+        for unet, unet_lr, unet_wd, unet_eps in zip(decoder.unets, lr, wd, eps):
             optimizer = get_optimizer(
                 unet.parameters(),
                 lr = unet_lr,
                 wd = unet_wd,
                 eps = unet_eps,
+                group_wd_params = group_wd_params,
                 **kwargs
             )

-            setattr(self, f'optim{ind}', optimizer) # cannot use pytorch ModuleList for some reason with optimizers
+            optimizers.append(optimizer)

             if self.use_ema:
                 self.ema_unets.append(EMA(unet, **ema_kwargs))

-            scaler = GradScaler(enabled = amp)
-            setattr(self, f'scaler{ind}', scaler)
-
         # gradient clipping if needed

         self.max_grad_norm = max_grad_norm

         self.register_buffer('step', torch.tensor([0.]))

+        decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
+
+        self.decoder = decoder
+
+        for opt_ind, optimizer in zip(range(len(optimizers)), optimizers):
+            setattr(self, f'optim{opt_ind}', optimizer)
+
+    def save(self, path, overwrite = True, **kwargs):
+        path = Path(path)
+        assert not (path.exists() and not overwrite)
+        path.parent.mkdir(parents = True, exist_ok = True)
+
+        save_obj = dict(
+            model = self.accelerator.unwrap_model(self.decoder).state_dict(),
+            version = __version__,
+            step = self.step.item(),
+            **kwargs
+        )
+
+        for ind in range(0, self.num_unets):
+            optimizer_key = f'optim{ind}'
+            optimizer = getattr(self, optimizer_key)
+            save_obj = {**save_obj, optimizer_key: self.accelerator.unwrap_model(optimizer).state_dict()}
+
+        if self.use_ema:
+            save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
+
+        self.accelerator.save(save_obj, str(path))
+
+    def load_state_dict(self, loaded_obj, only_model = False, strict = True):
+        if version.parse(__version__) != version.parse(loaded_obj['version']):
+            self.accelerator.print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}')
+
+        self.accelerator.unwrap_model(self.decoder).load_state_dict(loaded_obj['model'], strict = strict)
+        self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
+
+        if only_model:
+            return loaded_obj
+
+        for ind in range(0, self.num_unets):
+            optimizer_key = f'optim{ind}'
+            optimizer = getattr(self, optimizer_key)
+
+            self.accelerator.unwrap_model(optimizer).load_state_dict(loaded_obj[optimizer_key])
+
+        if self.use_ema:
+            assert 'ema' in loaded_obj
+            self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
+
+    def load(self, path, only_model = False, strict = True):
+        path = Path(path)
+        assert path.exists()
+
+        loaded_obj = torch.load(str(path), map_location = 'cpu')
+
+        self.load_state_dict(loaded_obj, only_model = only_model, strict = strict)
+
+        return loaded_obj
+
     @property
     def unets(self):
         return nn.ModuleList([ema.ema_model for ema in self.ema_unets])

-    def scale(self, loss, *, unet_number):
-        assert 1 <= unet_number <= self.num_unets
-        index = unet_number - 1
-        scaler = getattr(self, f'scaler{index}')
-        return scaler.scale(loss)
-
     def update(self, unet_number = None):
         if self.num_unets == 1:
             unet_number = default(unet_number, 1)

         assert exists(unet_number) and 1 <= unet_number <= self.num_unets
         index = unet_number - 1
-        unet = self.decoder.unets[index]

         optimizer = getattr(self, f'optim{index}')
-        scaler = getattr(self, f'scaler{index}')

         if exists(self.max_grad_norm):
-            scaler.unscale_(optimizer)
-            nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
+            self.accelerator.clip_grad_norm_(self.decoder.parameters(), self.max_grad_norm) # Automatically unscales gradients

-        scaler.step(optimizer)
-        scaler.update()
+        optimizer.step()
         optimizer.zero_grad()

         if self.use_ema:
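The `save` method above stores one optimizer state per unet under the keys `optim0`, `optim1`, … (the code notes optimizers cannot live in a pytorch `ModuleList`, hence `setattr`/`getattr` with formatted keys). The checkpoint layout can be sketched with plain dicts; `build_save_obj` is a hypothetical helper, not the trainer's method:

```python
# Sketch of the per-unet checkpoint layout written by DecoderTrainer.save():
# one optimizer state dict per unet, keyed `optim{ind}`, alongside the model
# weights and (optionally) the EMA state.

def build_save_obj(model_state, optimizer_states, ema_state=None):
    save_obj = {"model": model_state}
    for ind, opt_state in enumerate(optimizer_states):
        save_obj[f"optim{ind}"] = opt_state
    if ema_state is not None:
        save_obj["ema"] = ema_state
    return save_obj

obj = build_save_obj({"w": 1}, [{"lr": 1e-4}, {"lr": 5e-5}], ema_state={"w": 1})
```

`load_state_dict` then iterates `range(self.num_unets)` and rebuilds the same `optim{ind}` keys to restore each optimizer.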
@@ -449,15 +563,17 @@ class DecoderTrainer(nn.Module):
     @cast_torch_tensor
     @decoder_sample_in_chunks
     def sample(self, *args, **kwargs):
+        distributed = self.accelerator.num_processes > 1
+        base_decoder = self.accelerator.unwrap_model(self.decoder)
+
         if kwargs.pop('use_non_ema', False) or not self.use_ema:
-            return self.decoder.sample(*args, **kwargs)
+            return base_decoder.sample(*args, **kwargs, distributed = distributed)

-        trainable_unets = self.decoder.unets
+        trainable_unets = self.accelerator.unwrap_model(self.decoder).unets
-        self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
+        base_decoder.unets = self.unets # swap in exponential moving averaged unets for sampling

-        output = self.decoder.sample(*args, **kwargs)
+        output = base_decoder.sample(*args, **kwargs, distributed = distributed)

-        self.decoder.unets = trainable_unets # restore original training unets
+        base_decoder.unets = trainable_unets # restore original training unets

         # cast the ema_model unets back to original device
         for ema in self.ema_unets:
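The `sample` method above temporarily swaps the exponentially-moving-averaged unets in place of the trainable ones, samples, then restores the originals. The swap-in/swap-out pattern can be sketched with stand-in objects (`FakeDecoder` and `sample_with_ema` are illustrative, not the library's API); a `try/finally` is used here so the restore happens even if sampling raises:

```python
# Sketch of the swap pattern in DecoderTrainer.sample(): temporarily replace
# the trainable unets with their EMA counterparts, run sampling, restore.

class FakeDecoder:
    def __init__(self, unets):
        self.unets = unets

    def sample(self):
        return list(self.unets)  # "samples" report which unets actually ran

def sample_with_ema(decoder, ema_unets):
    trainable_unets = decoder.unets
    decoder.unets = ema_unets            # swap in EMA weights for sampling
    try:
        return decoder.sample()
    finally:
        decoder.unets = trainable_unets  # always restore the training unets

decoder = FakeDecoder(["unet_train"])
out = sample_with_ema(decoder, ["unet_ema"])
```

The real code restores unconditionally after the sample call; wrapping the restore in `finally` is a small hardening of the same idea.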
@@ -465,6 +581,18 @@ class DecoderTrainer(nn.Module):

         return output

+    @torch.no_grad()
+    @cast_torch_tensor
+    @prior_sample_in_chunks
+    def embed_text(self, *args, **kwargs):
+        return self.accelerator.unwrap_model(self.decoder).clip.embed_text(*args, **kwargs)
+
+    @torch.no_grad()
+    @cast_torch_tensor
+    @prior_sample_in_chunks
+    def embed_image(self, *args, **kwargs):
+        return self.accelerator.unwrap_model(self.decoder).clip.embed_image(*args, **kwargs)
+
     @cast_torch_tensor
     def forward(
         self,
@@ -479,13 +607,14 @@ class DecoderTrainer(nn.Module):
         total_loss = 0.

         for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
-            with autocast(enabled = self.amp):
+            # with autocast(enabled = self.amp):
+            with self.accelerator.autocast():
                 loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
                 loss = loss * chunk_size_frac

            total_loss += loss.item()

             if self.training:
-                self.scale(loss, unet_number = unet_number).backward()
+                self.accelerator.backward(loss)

         return total_loss
@@ -1,4 +1,12 @@
 import time
+import importlib
+
+# helper functions
+
+def exists(val):
+    return val is not None
+
+# time helpers

 class Timer:
     def __init__(self):
@@ -9,3 +17,19 @@ class Timer:

     def elapsed(self):
         return time.time() - self.last_time
+
+# print helpers
+
+def print_ribbon(s, symbol = '=', repeat = 40):
+    flank = symbol * repeat
+    return f'{flank} {s} {flank}'
+
+# import helpers
+
+def import_or_print_error(pkg_name, err_str = None):
+    try:
+        return importlib.import_module(pkg_name)
+    except ModuleNotFoundError as e:
+        if exists(err_str):
+            print(err_str)
+        exit()
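The new `import_or_print_error` helper above wraps `importlib.import_module` so an optional dependency can fail with a friendly message instead of a traceback. A testable variant of the same idea that returns `None` rather than calling `exit()` (the name `import_or_none` is illustrative):

```python
import importlib

# Variant of the `import_or_print_error` helper added above; it returns None
# instead of calling exit() so the failure path can be observed in-process.

def import_or_none(pkg_name, err_str=None):
    try:
        return importlib.import_module(pkg_name)
    except ModuleNotFoundError:
        if err_str is not None:
            print(err_str)
        return None

math_mod = import_or_none("math")
missing = import_or_none("definitely_not_a_real_pkg_xyz", "please pip install it")
```

The library's version deliberately terminates the process, since a missing tracker or dataloader dependency makes the rest of the training script unusable.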
dalle2_pytorch/version.py (new file)
@@ -0,0 +1 @@
+__version__ = '0.16.2'
@@ -68,8 +68,8 @@ def group_dict_by_key(cond, d):
         return_val[ind][key] = d[key]
     return (*return_val,)

-def string_begins_with(prefix, str):
+def string_begins_with(prefix, string_input):
-    return str.startswith(prefix)
+    return string_input.startswith(prefix)

 def group_by_key_prefix(prefix, d):
     return group_dict_by_key(partial(string_begins_with, prefix), d)
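The rename above stops `string_begins_with` from shadowing the builtin `str`. These helpers exist to partition a kwargs dict by key prefix, which is how the trainers peel off `ema_`-prefixed arguments. A self-contained sketch of that idiom, assuming the semantics the call sites imply (the real `groupby_prefix_and_trim` may differ in detail):

```python
# Sketch of the prefix-grouping idiom the helpers above support: split a
# kwargs dict into prefix-matching keys (with the prefix stripped) and the
# rest, as done for 'ema_' arguments in the trainers.

def string_begins_with(prefix, string_input):
    # renamed parameter avoids shadowing the builtin `str`
    return string_input.startswith(prefix)

def groupby_prefix_and_trim(prefix, d):
    with_prefix = {k[len(prefix):]: v for k, v in d.items() if string_begins_with(prefix, k)}
    without_prefix = {k: v for k, v in d.items() if not string_begins_with(prefix, k)}
    return with_prefix, without_prefix

ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', {'ema_beta': 0.995, 'lr': 1e-4})
```

This is why a caller can pass `ema_beta = 0.995` to a trainer and have it arrive at the `EMA` wrapper as plain `beta`.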
@@ -16,10 +16,11 @@ from torchvision.utils import make_grid, save_image

 from einops import rearrange

-from dalle2_pytorch.train import EMA
 from dalle2_pytorch.vqgan_vae import VQGanVAE
 from dalle2_pytorch.optimizer import get_optimizer

+from ema_pytorch import EMA
+
 # helpers

 def exists(val):
@@ -97,7 +98,7 @@ class VQGanVAETrainer(nn.Module):
         valid_frac = 0.05,
         random_split_seed = 42,
         ema_beta = 0.995,
-        ema_update_after_step = 2000,
+        ema_update_after_step = 500,
         ema_update_every = 10,
         apply_grad_penalty_every = 4,
         amp = False
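The hunk above swaps the in-repo EMA for the external `ema-pytorch` package and shortens the warmup (`ema_update_after_step` from 2000 to 500). The core behavior those config names suggest is an exponential moving average gated by step count; here is a scalar sketch under assumed semantics (the gating details of `ema-pytorch` itself may differ):

```python
# Scalar sketch of the EMA schedule configured above (beta, update_after_step,
# update_every). Semantics are inferred from the config names, not from the
# ema-pytorch implementation.

def ema_step(ema, online, step, beta=0.995, update_after_step=500, update_every=10):
    if step <= update_after_step:
        return online                       # warmup: track the online value directly
    if step % update_every != 0:
        return ema                          # only blend every `update_every` steps
    return ema * beta + online * (1 - beta) # exponential moving average

v = ema_step(0.0, 10.0, step=100)   # within warmup -> copies the online value
w = ema_step(v, 20.0, step=510)     # past warmup, on an update step -> blends
```

With `beta = 0.995` the blend is `10 * 0.995 + 20 * 0.005 = 10.05`, showing how slowly the EMA drifts toward new weights.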
setup.py
@@ -1,4 +1,5 @@
 from setuptools import setup, find_packages
+exec(open('dalle2_pytorch/version.py').read())

 setup(
     name = 'dalle2-pytorch',
@@ -10,7 +11,7 @@ setup(
             'dream = dalle2_pytorch.cli:dream'
         ],
     },
-    version = '0.4.2',
+    version = __version__,
     license='MIT',
     description = 'DALL-E 2',
     author = 'Phil Wang',
@@ -23,14 +24,17 @@ setup(
         'text to image'
     ],
     install_requires=[
+        'accelerate',
         'click',
         'clip-anytorch',
         'coca-pytorch>=0.0.5',
+        'ema-pytorch>=0.0.7',
         'einops>=0.4',
         'einops-exts>=0.0.3',
         'embedding-reader',
         'kornia>=0.5.4',
         'numpy',
+        'packaging',
         'pillow',
         'pydantic',
         'resize-right>=0.0.2',
@@ -40,7 +44,6 @@ setup(
         'tqdm',
         'vector-quantize-pytorch',
         'x-clip>=0.4.4',
-        'youtokentome',
         'webdataset>=0.2.5',
         'fsspec>=2022.1.0',
         'torchmetrics[image]>=0.8.0'
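The setup.py change above replaces a hard-coded version string with `exec(open('dalle2_pytorch/version.py').read())`, so the package version lives in exactly one file. A sketch of the mechanism, using a temporary file in place of the real `version.py`:

```python
import os
import tempfile

# Sketch of the single-source-of-version trick adopted in setup.py above:
# exec the version file so __version__ becomes available in a namespace.

with tempfile.TemporaryDirectory() as d:
    version_file = os.path.join(d, "version.py")
    with open(version_file, "w") as f:
        f.write("__version__ = '0.16.2'\n")

    namespace = {}
    exec(open(version_file).read(), namespace)  # what setup.py does at build time
    resolved = namespace["__version__"]
```

In setup.py the `exec` runs in the module's own globals, which is why the later `version = __version__` line resolves.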
train_decoder.py
@@ -1,9 +1,13 @@
-from dalle2_pytorch import Unet, Decoder
+from pathlib import Path
-from dalle2_pytorch.trainer import DecoderTrainer, print_ribbon
+from typing import List
+
+from dalle2_pytorch.trainer import DecoderTrainer
 from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
-from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
+from dalle2_pytorch.trackers import Tracker
-from dalle2_pytorch.train_configs import TrainDecoderConfig
+from dalle2_pytorch.train_configs import DecoderConfig, TrainDecoderConfig
-from dalle2_pytorch.utils import Timer
+from dalle2_pytorch.utils import Timer, print_ribbon
+from dalle2_pytorch.dalle2_pytorch import Decoder, resize_image_to
+from clip import tokenize

 import torchvision
 import torch
@@ -11,6 +15,8 @@ from torchmetrics.image.fid import FrechetInceptionDistance
 from torchmetrics.image.inception import InceptionScore
 from torchmetrics.image.kid import KernelInceptionDistance
 from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
+from accelerate import Accelerator, DistributedDataParallelKwargs
+from accelerate.utils import dataclasses as accelerate_dataclasses
 import webdataset as wds
 import click

@@ -29,7 +35,8 @@ def exists(val):
 def create_dataloaders(
     available_shards,
     webdataset_base_url,
-    embeddings_url,
+    img_embeddings_url=None,
+    text_embeddings_url=None,
     shard_width=6,
     num_workers=4,
     batch_size=32,
@@ -41,6 +48,7 @@ def create_dataloaders(
     train_prop = 0.75,
     val_prop = 0.15,
     test_prop = 0.10,
+    seed = 0,
     **kwargs
 ):
     """
@@ -51,21 +59,22 @@ def create_dataloaders(
     num_test = round(test_prop*len(available_shards))
     num_val = len(available_shards) - num_train - num_test
     assert num_train + num_test + num_val == len(available_shards), f"{num_train} + {num_test} + {num_val} = {num_train + num_test + num_val} != {len(available_shards)}"
-    train_split, test_split, val_split = torch.utils.data.random_split(available_shards, [num_train, num_test, num_val], generator=torch.Generator().manual_seed(0))
+    train_split, test_split, val_split = torch.utils.data.random_split(available_shards, [num_train, num_test, num_val], generator=torch.Generator().manual_seed(seed))

     # The shard number in the webdataset file names has a fixed width. We zero pad the shard numbers so they correspond to a filename.
     train_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in train_split]
     test_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in test_split]
     val_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in val_split]

-    create_dataloader = lambda tar_urls, shuffle=False, resample=False, with_text=False, for_sampling=False: create_image_embedding_dataloader(
+    create_dataloader = lambda tar_urls, shuffle=False, resample=False, for_sampling=False: create_image_embedding_dataloader(
         tar_url=tar_urls,
         num_workers=num_workers,
         batch_size=batch_size if not for_sampling else n_sample_images,
-        embeddings_url=embeddings_url,
+        img_embeddings_url=img_embeddings_url,
+        text_embeddings_url=text_embeddings_url,
         index_width=index_width,
         shuffle_num = None,
-        extra_keys= ["txt"] if with_text else [],
+        extra_keys= ["txt"],
         shuffle_shards = shuffle,
         resample_shards = resample,
         img_preproc=img_preproc,
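As the comment in `create_dataloaders` above notes, shard numbers in webdataset file names have a fixed width, so each shard index is zero-padded to `shard_width` digits before being substituted into the base URL template. A sketch of that formatting (the bucket URL below is a made-up example):

```python
# Sketch of the shard-to-URL formatting in create_dataloaders: shard indices
# are zero-padded to `shard_width` digits, then substituted into the
# webdataset base URL template via str.format.

def shard_urls(base_url, shards, shard_width=6):
    return [base_url.format(str(shard).zfill(shard_width)) for shard in shards]

urls = shard_urls("s3://bucket/data/{}.tar", [0, 7, 123])  # hypothetical bucket
```

The same helper expression is applied three times above, once each for the train, test, and validation shard splits.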
@@ -74,8 +83,8 @@ def create_dataloaders(

     train_dataloader = create_dataloader(train_urls, shuffle=shuffle_train, resample=resample_train)
     train_sampling_dataloader = create_dataloader(train_urls, shuffle=False, for_sampling=True)
-    val_dataloader = create_dataloader(val_urls, shuffle=False, with_text=True)
+    val_dataloader = create_dataloader(val_urls, shuffle=False)
-    test_dataloader = create_dataloader(test_urls, shuffle=False, with_text=True)
+    test_dataloader = create_dataloader(test_urls, shuffle=False)
     test_sampling_dataloader = create_dataloader(test_urls, shuffle=False, for_sampling=True)
     return {
         "train": train_dataloader,
@@ -85,20 +94,6 @@ def create_dataloaders(
         "test_sampling": test_sampling_dataloader
     }

-def create_decoder(device, decoder_config, unets_config):
-    """Creates a sample decoder"""
-    unets = [Unet(**config.dict()) for config in unets_config]
-
-    decoder = Decoder(
-        unet=unets,
-        **decoder_config.dict()
-    )
-
-    decoder.to(device=device)
-    return decoder

 def get_dataset_keys(dataloader):
     """
     It is sometimes necessary to get the keys the dataloader is returning. Since the dataset is buried in the dataloader, we need to do a process to recover it.
@@ -113,74 +108,111 @@ def get_example_data(dataloader, device, n=5):
     Samples the dataloader and returns a zipped list of examples
     """
     images = []
-    embeddings = []
+    img_embeddings = []
+    text_embeddings = []
     captions = []
-    dataset_keys = get_dataset_keys(dataloader)
-    has_caption = "txt" in dataset_keys
-    for data in dataloader:
-        if has_caption:
-            img, emb, txt = data
-        else:
-            img, emb = data
-            txt = [""] * emb.shape[0]
+    for img, emb, txt in dataloader:
+        img_emb, text_emb = emb.get('img'), emb.get('text')
+        if img_emb is not None:
+            img_emb = img_emb.to(device=device, dtype=torch.float)
+            img_embeddings.extend(list(img_emb))
+        else:
+            # Then we add None img.shape[0] times
+            img_embeddings.extend([None]*img.shape[0])
+        if text_emb is not None:
+            text_emb = text_emb.to(device=device, dtype=torch.float)
+            text_embeddings.extend(list(text_emb))
+        else:
+            # Then we add None img.shape[0] times
+            text_embeddings.extend([None]*img.shape[0])
         img = img.to(device=device, dtype=torch.float)
-        emb = emb.to(device=device, dtype=torch.float)
         images.extend(list(img))
-        embeddings.extend(list(emb))
         captions.extend(list(txt))
         if len(images) >= n:
             break
-    print("Generated {} examples".format(len(images)))
-    return list(zip(images[:n], embeddings[:n], captions[:n]))
+    return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))
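The rewritten `get_example_data` above pads a missing embedding type with one `None` per sample, so every zipped example stays aligned and downstream code can check `img_embeddings[0] is None` to decide whether embeddings must be computed on the fly. A simplified sketch of that convention with plain lists (`collect_examples` is a made-up stand-in, and the real function handles image and text embeddings separately):

```python
# Sketch of the None-padding convention in get_example_data above: when a
# batch lacks an embedding, one None per sample keeps the zipped examples
# aligned so consumers can test `emb is None` per example.

def collect_examples(batches, n):
    images, img_embeddings, captions = [], [], []
    for img_batch, emb_batch, txt_batch in batches:
        images.extend(img_batch)
        if emb_batch is not None:
            img_embeddings.extend(emb_batch)
        else:
            img_embeddings.extend([None] * len(img_batch))  # pad per sample
        captions.extend(txt_batch)
        if len(images) >= n:
            break
    return list(zip(images[:n], img_embeddings[:n], captions[:n]))

examples = collect_examples([(["img0", "img1"], None, ["a cat", "a dog"])], n=2)
```

Without the padding, `zip` would silently truncate to the shortest list and drop examples.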
-def generate_samples(trainer, example_data, text_prepend=""):
+def generate_samples(trainer, example_data, condition_on_text_encodings=False, text_prepend=""):
     """
     Takes example data and generates images from the embeddings
     Returns three lists: real images, generated images, and captions
     """
-    real_images, embeddings, txts = zip(*example_data)
+    real_images, img_embeddings, text_embeddings, txts = zip(*example_data)
-    embeddings_tensor = torch.stack(embeddings)
+    sample_params = {}
-    samples = trainer.sample(embeddings_tensor)
+    if img_embeddings[0] is None:
+        # Generate image embeddings from clip
+        imgs_tensor = torch.stack(real_images)
+        img_embeddings, *_ = trainer.embed_image(imgs_tensor)
+        sample_params["image_embed"] = img_embeddings
+    else:
+        # Then we are using precomputed image embeddings
+        img_embeddings = torch.stack(img_embeddings)
+        sample_params["image_embed"] = img_embeddings
+    if condition_on_text_encodings:
+        if text_embeddings[0] is None:
+            # Generate text embeddings from text
+            tokenized_texts = tokenize(txts, truncate=True)
+            sample_params["text"] = tokenized_texts
+        else:
+            # Then we are using precomputed text embeddings
+            text_embeddings = torch.stack(text_embeddings)
+            sample_params["text_encodings"] = text_embeddings
+    samples = trainer.sample(**sample_params)
     generated_images = list(samples)
     captions = [text_prepend + txt for txt in txts]
     return real_images, generated_images, captions
-def generate_grid_samples(trainer, examples, text_prepend=""):
+def generate_grid_samples(trainer, examples, condition_on_text_encodings=False, text_prepend=""):
     """
     Generates samples and uses torchvision to put them in a side by side grid for easy viewing
     """
-    real_images, generated_images, captions = generate_samples(trainer, examples, text_prepend)
+    real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings, text_prepend)
+
+    real_image_size = real_images[0].shape[-1]
+    generated_image_size = generated_images[0].shape[-1]
+
+    # training images may be larger than the generated one
+    if real_image_size > generated_image_size:
+        real_images = [resize_image_to(image, generated_image_size) for image in real_images]
+
     grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
     return grid_images, captions
def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
|
def evaluate_trainer(trainer, dataloader, device, condition_on_text_encodings=False, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
|
||||||
"""
|
"""
|
||||||
Computes evaluation metrics for the decoder
|
Computes evaluation metrics for the decoder
|
||||||
"""
|
"""
|
||||||
metrics = {}
|
metrics = {}
|
||||||
# Prepare the data
|
# Prepare the data
|
||||||
examples = get_example_data(dataloader, device, n_evaluation_samples)
|
examples = get_example_data(dataloader, device, n_evaluation_samples)
|
||||||
real_images, generated_images, captions = generate_samples(trainer, examples)
|
if len(examples) == 0:
|
||||||
|
print("No data to evaluate. Check that your dataloader has shards.")
|
||||||
|
return metrics
|
||||||
|
real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings)
|
||||||
real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
|
real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
|
||||||
generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
|
generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
|
||||||
# Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
|
# Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
|
||||||
int_real_images = real_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
|
int_real_images = real_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
|
||||||
int_generated_images = generated_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
|
int_generated_images = generated_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
|
||||||
|
|
||||||
|
def null_sync(t, *args, **kwargs):
|
||||||
|
return [t]
|
||||||
|
|
||||||
if exists(FID):
|
if exists(FID):
|
||||||
fid = FrechetInceptionDistance(**FID)
|
fid = FrechetInceptionDistance(**FID, dist_sync_fn=null_sync)
|
||||||
fid.to(device=device)
|
fid.to(device=device)
|
||||||
fid.update(int_real_images, real=True)
|
fid.update(int_real_images, real=True)
|
||||||
fid.update(int_generated_images, real=False)
|
fid.update(int_generated_images, real=False)
|
||||||
metrics["FID"] = fid.compute().item()
|
metrics["FID"] = fid.compute().item()
|
||||||
if exists(IS):
|
if exists(IS):
|
||||||
inception = InceptionScore(**IS)
|
inception = InceptionScore(**IS, dist_sync_fn=null_sync)
|
||||||
inception.to(device=device)
|
inception.to(device=device)
|
||||||
inception.update(int_real_images)
|
inception.update(int_real_images)
|
||||||
is_mean, is_std = inception.compute()
|
is_mean, is_std = inception.compute()
|
||||||
metrics["IS_mean"] = is_mean.item()
|
metrics["IS_mean"] = is_mean.item()
|
||||||
metrics["IS_std"] = is_std.item()
|
metrics["IS_std"] = is_std.item()
|
||||||
if exists(KID):
|
if exists(KID):
|
||||||
kernel_inception = KernelInceptionDistance(**KID)
|
kernel_inception = KernelInceptionDistance(**KID, dist_sync_fn=null_sync)
|
||||||
kernel_inception.to(device=device)
|
kernel_inception.to(device=device)
|
||||||
kernel_inception.update(int_real_images, real=True)
|
kernel_inception.update(int_real_images, real=True)
|
||||||
kernel_inception.update(int_generated_images, real=False)
|
kernel_inception.update(int_generated_images, real=False)
|
||||||
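The `[0, 1] → uint8` conversion above (`mul(255).add(0.5).clamp(0, 255).type(torch.uint8)`) rounds to the nearest integer rather than truncating, because 0.5 is added before the float-to-int cast. A minimal sketch of the same arithmetic in NumPy (the helper name `to_uint8` is ours, not the repo's):

```python
import numpy as np

# Same chain as real_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8):
# adding 0.5 before the float -> int truncation implements round-to-nearest.
def to_uint8(x: np.ndarray) -> np.ndarray:
    return np.clip(x * 255 + 0.5, 0, 255).astype(np.uint8)

print(to_uint8(np.array([0.0, 0.5, 1.0])))  # -> 0, 128, 255
```

Without the `+ 0.5`, a pixel at 0.999 would map to 254 instead of 255; the clamp also protects against inputs slightly outside `[0, 1]`.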
@@ -191,68 +223,82 @@ def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID
         # Convert from [0, 1] to [-1, 1]
         renorm_real_images = real_images.mul(2).sub(1)
         renorm_generated_images = generated_images.mul(2).sub(1)
-        lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS)
+        lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS, dist_sync_fn=null_sync)
         lpips.to(device=device)
         lpips.update(renorm_real_images, renorm_generated_images)
         metrics["LPIPS"] = lpips.compute().item()

+    if trainer.accelerator.num_processes > 1:
+        # Then we should sync the metrics
+        metrics_order = sorted(metrics.keys())
+        metrics_tensor = torch.zeros(1, len(metrics), device=device, dtype=torch.float)
+        for i, metric_name in enumerate(metrics_order):
+            metrics_tensor[0, i] = metrics[metric_name]
+        metrics_tensor = trainer.accelerator.gather(metrics_tensor)
+        metrics_tensor = metrics_tensor.mean(dim=0)
+        for i, metric_name in enumerate(metrics_order):
+            metrics[metric_name] = metrics_tensor[i].item()
     return metrics

-def save_trainer(tracker, trainer, epoch, step, validation_losses, relative_paths):
+def save_trainer(tracker: Tracker, trainer: DecoderTrainer, epoch: int, sample: int, next_task: str, validation_losses: List[float], samples_seen: int, is_latest=True, is_best=False):
     """
     Logs the model with an appropriate method depending on the tracker
     """
-    if isinstance(relative_paths, str):
-        relative_paths = [relative_paths]
-    trainer_state_dict = {}
-    trainer_state_dict["trainer"] = trainer.state_dict()
-    trainer_state_dict['epoch'] = epoch
-    trainer_state_dict['step'] = step
-    trainer_state_dict['validation_losses'] = validation_losses
-    for relative_path in relative_paths:
-        tracker.save_state_dict(trainer_state_dict, relative_path)
+    tracker.save(trainer, is_best=is_best, is_latest=is_latest, epoch=epoch, sample=sample, next_task=next_task, validation_losses=validation_losses, samples_seen=samples_seen)

-def recall_trainer(tracker, trainer, recall_source=None, **load_config):
+def recall_trainer(tracker: Tracker, trainer: DecoderTrainer):
     """
     Loads the model with an appropriate method depending on the tracker
     """
-    print(print_ribbon(f"Loading model from {recall_source}"))
-    state_dict = tracker.recall_state_dict(recall_source, **load_config)
-    trainer.load_state_dict(state_dict["trainer"])
-    print("Model loaded")
-    return state_dict["epoch"], state_dict["step"], state_dict["validation_losses"]
+    trainer.accelerator.print(print_ribbon(f"Loading model from {type(tracker.loader).__name__}"))
+    state_dict = tracker.recall()
+    trainer.load_state_dict(state_dict, only_model=False, strict=True)
+    return state_dict.get("epoch", 0), state_dict.get("validation_losses", []), state_dict.get("next_task", "train"), state_dict.get("sample", 0), state_dict.get("samples_seen", 0)

 def train(
     dataloaders,
-    decoder,
-    tracker,
+    decoder: Decoder,
+    accelerator: Accelerator,
+    tracker: Tracker,
     inference_device,
-    load_config=None,
     evaluate_config=None,
     epoch_samples = None, # If the training dataset is resampling, we have to manually stop an epoch
     validation_samples = None,
     epochs = 20,
     n_sample_images = 5,
     save_every_n_samples = 100000,
-    save_all=False,
-    save_latest=True,
-    save_best=True,
     unet_training_mask=None,
+    condition_on_text_encodings=False,
     **kwargs
 ):
     """
     Trains a decoder on a dataset.
     """
-    trainer = DecoderTrainer( # TODO: Change the get_optimizer function so that it can take arbitrary named args so we can just put **kwargs as an argument here
-        decoder,
+    is_master = accelerator.process_index == 0
+    trainer = DecoderTrainer(
+        decoder=decoder,
+        accelerator=accelerator,
         **kwargs
     )

     # Set up starting model and parameters based on a recalled state dict
-    start_step = 0
     start_epoch = 0
     validation_losses = []
+    next_task = 'train'
+    sample = 0
+    samples_seen = 0
+    val_sample = 0
+    step = lambda: int(trainer.step.item())

-    if exists(load_config) and exists(load_config.source):
-        start_epoch, start_step, validation_losses = recall_trainer(tracker, trainer, recall_source=load_config.source, **load_config)
+    if tracker.loader is not None:
+        start_epoch, validation_losses, next_task, recalled_sample, samples_seen = recall_trainer(tracker, trainer)
+        if next_task == 'train':
+            sample = recalled_sample
+        if next_task == 'val':
+            val_sample = recalled_sample
+        accelerator.print(f"Loaded model from {type(tracker.loader).__name__} on epoch {start_epoch} having seen {samples_seen} samples with minimum validation loss {min(validation_losses) if len(validation_losses) > 0 else 'N/A'}")
+        accelerator.print(f"Starting training from task {next_task} at sample {sample} and validation sample {val_sample}")
     trainer.to(device=inference_device)

     if not exists(unet_training_mask):
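The new `recall_trainer` reads checkpoint fields with `dict.get` defaults, so checkpoints written before a field existed (e.g. no `next_task` or `samples_seen`) still resume cleanly. A toy sketch of that fallback behavior (the example `state_dict` values are made up):

```python
# An older checkpoint that predates the 'next_task'/'sample'/'samples_seen' fields.
state_dict = {"epoch": 3, "validation_losses": [0.12, 0.09]}

# Missing keys fall back to safe defaults instead of raising KeyError.
recalled = (
    state_dict.get("epoch", 0),
    state_dict.get("validation_losses", []),
    state_dict.get("next_task", "train"),
    state_dict.get("sample", 0),
    state_dict.get("samples_seen", 0),
)
print(recalled)  # (3, [0.12, 0.09], 'train', 0, 0)
```

This is why the diff can change the checkpoint schema without breaking resumption from runs saved by the older script.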
@@ -260,186 +306,281 @@ def train(
         unet_training_mask = [True] * trainer.num_unets
     assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"

-    print(print_ribbon("Generating Example Data", repeat=40))
-    print("This can take a while to load the shard lists...")
-    train_example_data = get_example_data(dataloaders["train_sampling"], inference_device, n_sample_images)
-    test_example_data = get_example_data(dataloaders["test_sampling"], inference_device, n_sample_images)
+    accelerator.print(print_ribbon("Generating Example Data", repeat=40))
+    accelerator.print("This can take a while to load the shard lists...")
+    if is_master:
+        train_example_data = get_example_data(dataloaders["train_sampling"], inference_device, n_sample_images)
+        accelerator.print("Generated training examples")
+        test_example_data = get_example_data(dataloaders["test_sampling"], inference_device, n_sample_images)
+        accelerator.print("Generated testing examples")

     send_to_device = lambda arr: [x.to(device=inference_device, dtype=torch.float) for x in arr]
-    step = start_step

+    sample_length_tensor = torch.zeros(1, dtype=torch.int, device=inference_device)
+    unet_losses_tensor = torch.zeros(TRAIN_CALC_LOSS_EVERY_ITERS, trainer.num_unets, dtype=torch.float, device=inference_device)
     for epoch in range(start_epoch, epochs):
-        print(print_ribbon(f"Starting epoch {epoch}", repeat=40))
+        accelerator.print(print_ribbon(f"Starting epoch {epoch}", repeat=40))

         timer = Timer()
+        last_sample = sample
+        last_snapshot = sample

-        sample = 0
-        last_sample = 0
-        last_snapshot = 0
+        if next_task == 'train':
+            for i, (img, emb, txt) in enumerate(dataloaders["train"]):
+                # We want to count the total number of samples across all processes
+                sample_length_tensor[0] = len(img)
+                all_samples = accelerator.gather(sample_length_tensor) # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
+                total_samples = all_samples.sum().item()
+                sample += total_samples
+                samples_seen += total_samples
+                img_emb = emb.get('img')
+                has_img_embedding = img_emb is not None
+                if has_img_embedding:
+                    img_emb, = send_to_device((img_emb,))
+                text_emb = emb.get('text')
+                has_text_embedding = text_emb is not None
+                if has_text_embedding:
+                    text_emb, = send_to_device((text_emb,))
+                img, = send_to_device((img,))

-        losses = []
-
-        for i, (img, emb) in enumerate(dataloaders["train"]):
-            step += 1
-            sample += img.shape[0]
-            img, emb = send_to_device((img, emb))
-
-            trainer.train()
-            for unet in range(1, trainer.num_unets+1):
-                # Check if this is a unet we are training
-                if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
-                    continue
-
-                loss = trainer.forward(img, image_embed=emb, unet_number=unet)
-                trainer.update(unet_number=unet)
-                losses.append(loss)
-
-            samples_per_sec = (sample - last_sample) / timer.elapsed()
-            timer.reset()
-            last_sample = sample
-
-            if i % TRAIN_CALC_LOSS_EVERY_ITERS == 0:
-                average_loss = sum(losses) / len(losses)
-                log_data = {
-                    "Training loss": average_loss,
-                    "Epoch": epoch,
-                    "Sample": sample,
-                    "Step": i,
-                    "Samples per second": samples_per_sec
-                }
-                tracker.log(log_data, step=step, verbose=True)
-                losses = []
+                trainer.train()
+                for unet in range(1, trainer.num_unets+1):
+                    # Check if this is a unet we are training
+                    if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
+                        continue
+
+                    forward_params = {}
+                    if has_img_embedding:
+                        forward_params['image_embed'] = img_emb
+                    else:
+                        # Forward pass automatically generates embedding
+                        pass
+                    if condition_on_text_encodings:
+                        if has_text_embedding:
+                            forward_params['text_encodings'] = text_emb
+                        else:
+                            # Then we need to pass the text instead
+                            tokenized_texts = tokenize(txt, truncate=True)
+                            forward_params['text'] = tokenized_texts
+                    loss = trainer.forward(img, **forward_params, unet_number=unet)
+                    trainer.update(unet_number=unet)
+                    unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet-1] = loss
+
+                samples_per_sec = (sample - last_sample) / timer.elapsed()
+                timer.reset()
+                last_sample = sample
+
+                if i % TRAIN_CALC_LOSS_EVERY_ITERS == 0:
+                    # We want to average losses across all processes
+                    unet_all_losses = accelerator.gather(unet_losses_tensor)
+                    mask = unet_all_losses != 0
+                    unet_average_loss = (unet_all_losses * mask).sum(dim=0) / mask.sum(dim=0)
+                    loss_map = { f"Unet {index} Training Loss": loss.item() for index, loss in enumerate(unet_average_loss) if loss != 0 }
+
+                    # gather decay rate on each UNet
+                    ema_decay_list = {f"Unet {index} EMA Decay": ema_unet.get_current_decay() for index, ema_unet in enumerate(trainer.ema_unets)}
+
+                    log_data = {
+                        "Epoch": epoch,
+                        "Sample": sample,
+                        "Step": i,
+                        "Samples per second": samples_per_sec,
+                        "Samples Seen": samples_seen,
+                        **ema_decay_list,
+                        **loss_map
+                    }
+
+                    if is_master:
+                        tracker.log(log_data, step=step())

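The added logging path gathers the per-unet loss buffer from all processes and then averages only the nonzero entries, so frozen unets and unfilled iteration slots do not drag the mean down. A small NumPy stand-in for the gathered buffer (the data here is invented; `accelerator.gather` is simulated by a pre-stacked array):

```python
import numpy as np

# Stand-in for the gathered (iters x processes, num_unets) loss buffer.
# Zero entries mark slots that were never written (e.g. a masked-out unet
# or an iteration slot not yet filled this logging window).
unet_all_losses = np.array([
    [0.5, 0.2],
    [0.7, 0.4],
    [0.0, 0.0],   # unfilled slot, must not count toward the average
])
mask = unet_all_losses != 0
unet_average_loss = (unet_all_losses * mask).sum(axis=0) / mask.sum(axis=0)
print(unet_average_loss)  # [0.6 0.3]
```

A plain `mean(axis=0)` would give `[0.4, 0.2]` here, underestimating both unets by a third.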
-            if last_snapshot + save_every_n_samples < sample: # This will miss by some amount every time, but it's not a big deal... I hope
-                last_snapshot = sample
-                # We need to know where the model should be saved
-                save_paths = []
-                if save_latest:
-                    save_paths.append("latest.pth")
-                if save_all:
-                    save_paths.append(f"checkpoints/epoch_{epoch}_step_{step}.pth")
-                save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths)
-
-            if exists(n_sample_images) and n_sample_images > 0:
-                trainer.eval()
-                train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
-                tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step)
-
-            if exists(epoch_samples) and sample >= epoch_samples:
-                break
-
-        trainer.eval()
-        print(print_ribbon(f"Starting Validation {epoch}", repeat=40))
-        with torch.no_grad():
+                if is_master and last_snapshot + save_every_n_samples < sample: # This will miss by some amount every time, but it's not a big deal... I hope
+                    # It is difficult to gather this kind of info on the accelerator, so we have to do it on the master
+                    print("Saving snapshot")
+                    last_snapshot = sample
+                    # We need to know where the model should be saved
+                    save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen)
+                    if exists(n_sample_images) and n_sample_images > 0:
+                        trainer.eval()
+                        train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
+                        tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
+
+                if epoch_samples is not None and sample >= epoch_samples:
+                    break
+            next_task = 'val'
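The snapshot condition `last_snapshot + save_every_n_samples < sample` fires on the first batch that overshoots the interval, which is why the inline comment admits it "will miss by some amount every time". A pure-Python simulation with hypothetical sizes (a 128-image global batch and a 1000-sample interval) shows the drift:

```python
# Hypothetical sizes: 128-image global batches, snapshot every 1000 samples.
save_every_n_samples = 1000
sample = last_snapshot = 0
snapshots = []
for _ in range(50):
    sample += 128
    if last_snapshot + save_every_n_samples < sample:
        last_snapshot = sample
        snapshots.append(sample)
print(snapshots[:3])  # [1024, 2048, 3072]
```

Each snapshot lands at the first multiple of the batch size past the threshold, so the effective interval is rounded up to a whole number of batches; the overshoot never accumulates because `last_snapshot` is reset to the actual sample count.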
         sample = 0
-        average_loss = 0
+        all_average_val_losses = None
+        if next_task == 'val':
+            trainer.eval()
+            accelerator.print(print_ribbon(f"Starting Validation {epoch}", repeat=40))
+            last_val_sample = val_sample
+            val_sample_length_tensor = torch.zeros(1, dtype=torch.int, device=inference_device)
+            average_val_loss_tensor = torch.zeros(1, trainer.num_unets, dtype=torch.float, device=inference_device)
             timer = Timer()
+            accelerator.wait_for_everyone()
+            i = 0
             for i, (img, emb, txt) in enumerate(dataloaders["val"]):
-                sample += img.shape[0]
-                img, emb = send_to_device((img, emb))
+                val_sample_length_tensor[0] = len(img)
+                all_samples = accelerator.gather(val_sample_length_tensor)
+                total_samples = all_samples.sum().item()
+                val_sample += total_samples
+                img_emb = emb.get('img')
+                has_img_embedding = img_emb is not None
+                if has_img_embedding:
+                    img_emb, = send_to_device((img_emb,))
+                text_emb = emb.get('text')
+                has_text_embedding = text_emb is not None
+                if has_text_embedding:
+                    text_emb, = send_to_device((text_emb,))
+                img, = send_to_device((img,))

                 for unet in range(1, len(decoder.unets)+1):
-                    loss = trainer.forward(img.float(), image_embed=emb.float(), unet_number=unet)
-                    average_loss += loss
+                    if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
+                        # No need to evaluate an unchanging unet
+                        continue
+
+                    forward_params = {}
+                    if has_img_embedding:
+                        forward_params['image_embed'] = img_emb.float()
+                    else:
+                        # Forward pass automatically generates embedding
+                        pass
+                    if condition_on_text_encodings:
+                        if has_text_embedding:
+                            forward_params['text_encodings'] = text_emb.float()
+                        else:
+                            # Then we need to pass the text instead
+                            tokenized_texts = tokenize(txt, truncate=True)
+                            forward_params['text'] = tokenized_texts
+                    loss = trainer.forward(img.float(), **forward_params, unet_number=unet)
+                    average_val_loss_tensor[0, unet-1] += loss

                 if i % VALID_CALC_LOSS_EVERY_ITERS == 0:
-                    print(f"Epoch {epoch}/{epochs} - {sample / timer.elapsed():.2f} samples/sec")
-                    print(f"Loss: {average_loss / (i+1)}")
-                    print("")
+                    samples_per_sec = (val_sample - last_val_sample) / timer.elapsed()
+                    timer.reset()
+                    last_val_sample = val_sample
+                    accelerator.print(f"Epoch {epoch}/{epochs} Val Step {i} - Sample {val_sample} - {samples_per_sec:.2f} samples/sec")
+                    accelerator.print(f"Loss: {(average_val_loss_tensor / (i+1))}")
+                    accelerator.print("")

-                if exists(validation_samples) and sample >= validation_samples:
+                if validation_samples is not None and val_sample >= validation_samples:
                     break
+            print(f"Rank {accelerator.state.process_index} finished validation after {i} steps")
+            accelerator.wait_for_everyone()
+            average_val_loss_tensor /= i+1
+            # Gather all the average loss tensors
+            all_average_val_losses = accelerator.gather(average_val_loss_tensor)
+            if is_master:
+                unet_average_val_loss = all_average_val_losses.mean(dim=0)
+                val_loss_map = { f"Unet {index} Validation Loss": loss.item() for index, loss in enumerate(unet_average_val_loss) if loss != 0 }
+                tracker.log(val_loss_map, step=step())
+            next_task = 'eval'

-        average_loss /= i+1
-        log_data = {
-            "Validation loss": average_loss
-        }
-        tracker.log(log_data, step=step, verbose=True)
+        if next_task == 'eval':
+            if exists(evaluate_config):
+                accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
+                evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings)
+                if is_master:
+                    tracker.log(evaluation, step=step())
+            next_task = 'sample'
+            val_sample = 0

-        # Compute evaluation metrics
-        if exists(evaluate_config):
-            print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
-            evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config)
-            tracker.log(evaluation, step=step, verbose=True)
+        if next_task == 'sample':
+            if is_master:
+                # Generate examples and save the model if we are the master
+                # Generate sample images
+                print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
+                test_images, test_captions = generate_grid_samples(trainer, test_example_data, condition_on_text_encodings, "Test: ")
+                train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
+                tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step())
+                tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())

-        # Generate sample images
-        print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
-        test_images, test_captions = generate_grid_samples(trainer, test_example_data, "Test: ")
-        train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
-        tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step)
-        tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step)
+                print(print_ribbon(f"Starting Saving {epoch}", repeat=40))
+                is_best = False
+                if all_average_val_losses is not None:
+                    average_loss = all_average_val_losses.mean(dim=0).item()
+                    if len(validation_losses) == 0 or average_loss < min(validation_losses):
+                        is_best = True
+                    validation_losses.append(average_loss)
+                save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen, is_best=is_best)
+            next_task = 'train'

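The rewritten epoch body is effectively a resumable state machine: each phase runs under `if next_task == ...` and advances `next_task` when it finishes, so a checkpoint saved mid-epoch records exactly which phase to re-enter. A toy sketch of the phase ordering (the `next_phase` helper is ours, not the repo's; the real script mutates `next_task` inline):

```python
# The checkpoint stores next_task so a resumed run re-enters the right phase.
PHASES = ['train', 'val', 'eval', 'sample']

def next_phase(task: str) -> str:
    # advance to the following phase, wrapping from 'sample' back to 'train'
    return PHASES[(PHASES.index(task) + 1) % len(PHASES)]

print(next_phase('val'))     # eval
print(next_phase('sample'))  # train
```

The wrap from `'sample'` back to `'train'` is what lets the next epoch start cleanly after the save step.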
-        print(print_ribbon(f"Starting Saving {epoch}", repeat=40))
-        # Get the same paths
-        save_paths = []
-        if save_latest:
-            save_paths.append("latest.pth")
-        if save_best and (len(validation_losses) == 0 or average_loss < min(validation_losses)):
-            save_paths.append("best.pth")
-        validation_losses.append(average_loss)
-        save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths)
-
-def create_tracker(config, tracker_type=None, data_path=None, **kwargs):
-    """
-    Creates a tracker of the specified type and initializes special features based on the full config
-    """
+def create_tracker(accelerator: Accelerator, config: TrainDecoderConfig, config_path: str, dummy: bool = False) -> Tracker:
     tracker_config = config.tracker
-    init_config = {}
-
-    if exists(tracker_config.init_config):
-        init_config["config"] = tracker_config.init_config
-
-    if tracker_type == "console":
-        tracker = ConsoleTracker(**init_config)
-    elif tracker_type == "wandb":
-        # We need to initialize the resume state here
-        load_config = config.load
-        if load_config.source == "wandb" and load_config.resume:
-            # Then we are resuming the run load_config["run_path"]
-            run_id = load_config.run_path.split("/")[-1]
-            init_config["id"] = run_id
-            init_config["resume"] = "must"
-
-        init_config["entity"] = tracker_config.wandb_entity
-        init_config["project"] = tracker_config.wandb_project
-        tracker = WandbTracker(data_path)
-        tracker.init(**init_config)
-    else:
-        raise ValueError(f"Tracker type {tracker_type} not supported by decoder trainer")
+    accelerator_config = {
+        "Distributed": accelerator.distributed_type != accelerate_dataclasses.DistributedType.NO,
+        "DistributedType": accelerator.distributed_type,
+        "NumProcesses": accelerator.num_processes,
+        "MixedPrecision": accelerator.mixed_precision
+    }
+    tracker: Tracker = tracker_config.create(config, accelerator_config, dummy_mode=dummy)
+    tracker.save_config(config_path, config_name='decoder_config.json')
     return tracker

-def initialize_training(config):
-    # Create the save path
-    if "cuda" in config.train.device:
-        assert torch.cuda.is_available(), "CUDA is not available"
-    device = torch.device(config.train.device)
-    torch.cuda.set_device(device)
-    all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
+def initialize_training(config: TrainDecoderConfig, config_path):
+    # Make sure if we are not loading, distributed models are initialized to the same values
+    torch.manual_seed(config.seed)

+    # Set up accelerator for configurable distributed training
+    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
+    accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
+
+    # Set up data
+    all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
+    world_size = accelerator.num_processes
+    rank = accelerator.process_index
+    shards_per_process = len(all_shards) // world_size
+    assert shards_per_process > 0, "Not enough shards to split evenly"
+    my_shards = all_shards[rank * shards_per_process: (rank + 1) * shards_per_process]
     dataloaders = create_dataloaders (
-        available_shards=all_shards,
-        img_preproc = config.img_preproc,
+        available_shards=my_shards,
+        img_preproc = config.data.img_preproc,
         train_prop = config.data.splits.train,
         val_prop = config.data.splits.val,
         test_prop = config.data.splits.test,
         n_sample_images=config.train.n_sample_images,
-        **config.data.dict()
+        **config.data.dict(),
+        rank = rank,
+        seed = config.seed,
     )

-    decoder = create_decoder(device, config.decoder, config.unets)
+    # Create the decoder model and print basic info
+    decoder = config.decoder.create()
     num_parameters = sum(p.numel() for p in decoder.parameters())
-    print(print_ribbon("Loaded Config", repeat=40))
-    print(f"Number of parameters: {num_parameters}")

-    tracker = create_tracker(config, **config.tracker.dict())
+    # Create and initialize the tracker if we are the master
+    tracker = create_tracker(accelerator, config, config_path, dummy = rank!=0)

-    train(dataloaders, decoder,
+    has_img_embeddings = config.data.img_embeddings_url is not None
+    has_text_embeddings = config.data.text_embeddings_url is not None
+    conditioning_on_text = any([unet.cond_on_text_encodings for unet in config.decoder.unets])
+
+    has_clip_model = config.decoder.clip is not None
+    data_source_string = ""
+
+    if has_img_embeddings:
+        data_source_string += "precomputed image embeddings"
+    elif has_clip_model:
+        data_source_string += "clip image embeddings generation"
+    else:
+        raise ValueError("No image embeddings source specified")
+    if conditioning_on_text:
+        if has_text_embeddings:
+            data_source_string += " and precomputed text embeddings"
+        elif has_clip_model:
+            data_source_string += " and clip text encoding generation"
+        else:
+            raise ValueError("No text embeddings source specified")
+
+    accelerator.print(print_ribbon("Loaded Config", repeat=40))
+    accelerator.print(f"Running training with {accelerator.num_processes} processes and {accelerator.distributed_type} distributed training")
+    accelerator.print(f"Training using {data_source_string}. {'conditioned on text' if conditioning_on_text else 'not conditioned on text'}")
+    accelerator.print(f"Number of parameters: {num_parameters}")
+    train(dataloaders, decoder, accelerator,
         tracker=tracker,
-        inference_device=device,
-        load_config=config.load,
+        inference_device=accelerator.device,
         evaluate_config=config.evaluate,
+        condition_on_text_encodings=conditioning_on_text,
         **config.train.dict(),
     )

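The new shard assignment gives each process a contiguous, non-overlapping slice of the shard list. A pure-Python rerun of the slicing with hypothetical numbers (shards 0..9 over 4 processes):

```python
# Hypothetical config: shards 0..9 split over 4 processes, viewed from rank 1.
all_shards = list(range(0, 10))
world_size, rank = 4, 1
shards_per_process = len(all_shards) // world_size  # floor division: 2
assert shards_per_process > 0, "Not enough shards to split evenly"
my_shards = all_shards[rank * shards_per_process: (rank + 1) * shards_per_process]
print(my_shards)  # [2, 3]
```

Note the floor division: with 10 shards and 4 processes, the last `10 % 4 = 2` shards (8 and 9) are assigned to no rank and are simply never read.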
@@ -447,10 +588,9 @@ def initialize_training(config):
 @click.command()
 @click.option("--config_file", default="./train_decoder_config.json", help="Path to config file")
 def main(config_file):
-    print("Recalling config from {}".format(config_file))
-    config = TrainDecoderConfig.from_json_path(config_file)
-    initialize_training(config)
+    config_file_path = Path(config_file)
+    config = TrainDecoderConfig.from_json_path(str(config_file_path))
+    initialize_training(config, config_path=config_file_path)


 if __name__ == "__main__":
     main()
|
|||||||
@@ -1,75 +1,135 @@
# TODO: add start, num_data_points, eval_every and group to config
# TODO: switch back to repo's wandb

START = 0
NUM_DATA_POINTS = 250e6
EVAL_EVERY = 1000
GROUP = "distributed"

import os
import click
import wandb

import torch
from torch import nn
from torch.utils.data import DataLoader

import numpy as np

from accelerate import Accelerator

from dalle2_pytorch.dataloaders import get_reader, make_splits
from dalle2_pytorch.utils import Timer
from dalle2_pytorch.train_configs import (
    DiffusionPriorTrainConfig,
    TrainDiffusionPriorConfig,
)
from dalle2_pytorch.trackers import BaseTracker, WandbTracker
from dalle2_pytorch import DiffusionPriorTrainer

# helpers

cos = nn.CosineSimilarity(dim=1, eps=1e-6)


def exists(val):
    return val is not None


def make_model(
    prior_config, train_config, device: str = None, accelerator: Accelerator = None
):
    # create model from config
    diffusion_prior = prior_config.create()

    # instantiate the trainer
    trainer = DiffusionPriorTrainer(
        diffusion_prior=diffusion_prior,
        lr=train_config.lr,
        wd=train_config.wd,
        max_grad_norm=train_config.max_grad_norm,
        amp=train_config.amp,
        use_ema=train_config.use_ema,
        device=device,
        accelerator=accelerator,
    )

    return trainer


# eval functions


def eval_model(
    trainer: DiffusionPriorTrainer,
    dataloader: DataLoader,
    text_conditioned: bool,
    loss_type: str,
    tracker_context: str,
    tracker: BaseTracker = None,
    use_ema: bool = True,
):
    trainer.eval()
    if trainer.is_main_process():
        click.secho(f"Measuring performance on {tracker_context}", fg="green", blink=True)

    with torch.no_grad():
        total_loss = 0.0
        total_samples = 0.0

        for image_embeddings, text_data in dataloader:
            image_embeddings = image_embeddings.to(trainer.device)
            text_data = text_data.to(trainer.device)

            batches = image_embeddings.shape[0]

            input_args = dict(image_embed=image_embeddings)

            if text_conditioned:
                input_args = dict(**input_args, text=text_data)
            else:
                input_args = dict(**input_args, text_embed=text_data)

            if use_ema:
                loss = trainer.ema_diffusion_prior(**input_args)
            else:
                loss = trainer(**input_args)

            total_loss += loss * batches
            total_samples += batches

        avg_loss = total_loss / total_samples
        stats = {f"{tracker_context}-{loss_type}": avg_loss}
        trainer.print(stats)

        if exists(tracker):
            tracker.log(stats, step=trainer.step.item() + 1)


def report_cosine_sims(
    trainer: DiffusionPriorTrainer,
    dataloader: DataLoader,
    text_conditioned: bool,
    tracker: BaseTracker,
    tracker_context: str = "validation",
):
    trainer.eval()
    if trainer.is_main_process():
        click.secho("Measuring Cosine-Similarity", fg="green", blink=True)

    for test_image_embeddings, text_data in dataloader:
        test_image_embeddings = test_image_embeddings.to(trainer.device)
        text_data = text_data.to(trainer.device)

        # we are text conditioned, we produce an embedding from the tokenized text
        if text_conditioned:
            text_embedding, text_encodings, text_mask = trainer.embed_text(text_data)
            text_cond = dict(
                text_embed=text_embedding, text_encodings=text_encodings, mask=text_mask
            )
        else:
            text_embedding = text_data
            text_cond = dict(text_embed=text_embedding)
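`eval_model` above weights each batch's mean loss by its batch size before averaging, so a short final batch does not skew the result; the arithmetic in isolation (toy numbers):

```python
# batch-size-weighted mean, as in eval_model: sum(loss_i * n_i) / sum(n_i)
batch_losses = [(0.5, 32), (0.7, 32), (0.9, 8)]  # (mean loss, batch size), toy values

total_loss = 0.0
total_samples = 0.0
for loss, batches in batch_losses:
    total_loss += loss * batches
    total_samples += batches

avg_loss = total_loss / total_samples  # 45.6 / 72
```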
@@ -80,8 +140,9 @@ def report_cosine_sims(diffusion_prior, dataloader, text_conditioned):
        # roll the text to simulate "unrelated" captions
        rolled_idx = torch.roll(torch.arange(text_embedding.shape[0]), 1)
        text_embed_shuffled = text_embed_shuffled[rolled_idx]
        text_embed_shuffled = text_embed_shuffled / text_embed_shuffled.norm(
            dim=1, keepdim=True
        )

        if text_conditioned:
            text_encodings_shuffled = text_encodings[rolled_idx]
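The roll-by-one above pairs each image with a neighbouring sample's caption to fabricate an "unrelated" baseline; the same two steps (roll, then L2-normalize) in NumPy with toy embeddings:

```python
import numpy as np

text_embedding = np.array([[3.0, 4.0], [1.0, 0.0], [0.0, 2.0]])

# roll the indices by one to simulate "unrelated" captions
rolled_idx = np.roll(np.arange(text_embedding.shape[0]), 1)
text_embed_shuffled = text_embedding[rolled_idx]

# L2-normalize each row, mirroring x / x.norm(dim=1, keepdim=True)
text_embed_shuffled = text_embed_shuffled / np.linalg.norm(
    text_embed_shuffled, axis=1, keepdims=True
)
```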
@@ -90,276 +151,276 @@ def report_cosine_sims(diffusion_prior, dataloader, text_conditioned):
        else:
            text_encodings_shuffled = None
            text_mask_shuffled = None

        text_cond_shuffled = dict(
            text_embed=text_embed_shuffled,
            text_encodings=text_encodings_shuffled,
            mask=text_mask_shuffled,
        )

        # prepare the text embedding
        text_embed = text_embedding / text_embedding.norm(dim=1, keepdim=True)

        # prepare image embeddings
        test_image_embeddings = test_image_embeddings / test_image_embeddings.norm(
            dim=1, keepdim=True
        )

        # predict on the unshuffled text embeddings
        predicted_image_embeddings = trainer.p_sample_loop(
            test_image_embeddings.shape, text_cond
        )
        predicted_image_embeddings = (
            predicted_image_embeddings
            / predicted_image_embeddings.norm(dim=1, keepdim=True)
        )

        # predict on the shuffled embeddings
        predicted_unrelated_embeddings = trainer.p_sample_loop(
            test_image_embeddings.shape, text_cond_shuffled
        )
        predicted_unrelated_embeddings = (
            predicted_unrelated_embeddings
            / predicted_unrelated_embeddings.norm(dim=1, keepdim=True)
        )

        # calculate similarities
        original_similarity = cos(text_embed, test_image_embeddings).cpu().numpy()
        predicted_similarity = cos(text_embed, predicted_image_embeddings).cpu().numpy()
        unrelated_similarity = (
            cos(text_embed, predicted_unrelated_embeddings).cpu().numpy()
        )
        predicted_img_similarity = (
            cos(test_image_embeddings, predicted_image_embeddings).cpu().numpy()
        )

        stats = {
            f"{tracker_context}/baseline similarity": np.mean(original_similarity),
            f"{tracker_context}/similarity with text": np.mean(predicted_similarity),
            f"{tracker_context}/similarity with original image": np.mean(
                predicted_img_similarity
            ),
            f"{tracker_context}/similarity with unrelated caption": np.mean(
                unrelated_similarity
            ),
            f"{tracker_context}/difference from baseline similarity": np.mean(
                predicted_similarity - original_similarity
            ),
        }

        for k, v in stats.items():
            trainer.print(f"{tracker_context}/{k}: {v}")

        if exists(tracker):
            tracker.log(stats, step=trainer.step.item() + 1)
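All the similarity stats logged above reduce to row-wise cosine similarity; a NumPy stand-in for `nn.CosineSimilarity(dim=1, eps=1e-6)` (toy vectors, with the eps handling only approximated):

```python
import numpy as np


def cosine_sim(a: np.ndarray, b: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    # row-wise cosine similarity; eps guards against zero-norm rows
    denom = np.maximum(np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1), eps)
    return (a * b).sum(axis=1) / denom


text_embed = np.array([[1.0, 0.0], [0.0, 1.0]])
image_embed = np.array([[1.0, 0.0], [1.0, 0.0]])

sims = cosine_sim(text_embed, image_embed)  # first row aligned, second orthogonal
```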
# training script


def train(
    trainer: DiffusionPriorTrainer,
    train_loader: DataLoader,
    eval_loader: DataLoader,
    test_loader: DataLoader,
    config: DiffusionPriorTrainConfig,
):
    # distributed tracking with wandb
    if trainer.accelerator.num_processes > 1:
        os.environ["WANDB_START_METHOD"] = "thread"

    tracker = wandb.init(
        name=f"RANK:{trainer.device}",
        entity=config.tracker.wandb_entity,
        project=config.tracker.wandb_project,
        config=config.dict(),
        group=GROUP,
    )

    # sync after tracker init
    trainer.wait_for_everyone()

    # init a timer
    timer = Timer()

    # do training
    for img, txt in train_loader:
        trainer.train()
        current_step = trainer.step.item() + 1

        # place data on device
        img = img.to(trainer.device)
        txt = txt.to(trainer.device)

        # pass to model
        loss = trainer(text=txt, image_embed=img)

        # display & log loss (will only print from main process)
        trainer.print(f"Step {current_step}: Loss {loss}")

        # perform backprop & apply EMA updates
        trainer.update()

        # track samples/sec/rank
        samples_per_sec = img.shape[0] / timer.elapsed()

        # samples seen
        samples_seen = (
            config.data.batch_size * trainer.accelerator.num_processes * current_step
        )

        # ema decay
        ema_decay = trainer.ema_diffusion_prior.get_current_decay()

        # log on all processes for debugging
        tracker.log(
            {
                "tracking/samples-sec": samples_per_sec,
                "tracking/samples-seen": samples_seen,
                "tracking/ema-decay": ema_decay,
                "metrics/training-loss": loss,
            },
            step=current_step,
        )

        # metric tracking & checkpointing (outside of timer's scope)
        if current_step % EVAL_EVERY == 0:
            eval_model(
                trainer=trainer,
                dataloader=eval_loader,
                text_conditioned=config.prior.condition_on_text_encodings,
                loss_type=config.prior.loss_type,
                tracker_context="metrics/online-model-validation",
                tracker=tracker,
                use_ema=False,
            )

            eval_model(
                trainer=trainer,
                dataloader=eval_loader,
                text_conditioned=config.prior.condition_on_text_encodings,
                loss_type=config.prior.loss_type,
                tracker_context="metrics/ema-model-validation",
                tracker=tracker,
                use_ema=True,
            )

            report_cosine_sims(
                trainer=trainer,
                dataloader=eval_loader,
                text_conditioned=config.prior.condition_on_text_encodings,
                tracker=tracker,
                tracker_context="metrics",
            )

        if current_step % config.train.save_every == 0:
            trainer.save(f"{config.tracker.data_path}/chkpt_step_{current_step}.pth")

        # reset timer for next round
        timer.reset()

    # evaluate on test data
    eval_model(
        trainer=trainer,
        dataloader=test_loader,
        text_conditioned=config.prior.condition_on_text_encodings,
        loss_type=config.prior.loss_type,
        tracker_context="test",
        tracker=tracker,
    )

    report_cosine_sims(
        trainer,
        test_loader,
        config.prior.condition_on_text_encodings,
        tracker,
        tracker_context="test",
    )
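The throughput bookkeeping in the loop above is plain arithmetic: each optimizer step consumes `batch_size` samples on every rank, so the total seen is `batch_size * num_processes * step`. A toy check with hypothetical values:

```python
batch_size = 320   # per-rank batch size (illustrative)
num_processes = 8  # world size (illustrative)


def samples_seen(current_step: int) -> int:
    # mirrors: config.data.batch_size * trainer.accelerator.num_processes * current_step
    return batch_size * num_processes * current_step
```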
def initialize_training(config, accelerator=None):
    """
    Parse the configuration file, and prepare everything necessary for training
    """
    # get a device
    if accelerator:
        device = accelerator.device
        click.secho(f"Accelerating on: {device}", fg="yellow")
    else:
        if torch.cuda.is_available():
            click.secho("GPU detected, defaulting to cuda:0", fg="yellow")
            device = "cuda:0"
        else:
            click.secho("No GPU detected...using cpu", fg="yellow")
            device = "cpu"

    # make the trainer (will automatically distribute if possible & configured)
    trainer = make_model(config.prior, config.train, device, accelerator).to(device)

    # reload from checkpoint
    if config.load.resume:
        click.secho(f"Loading checkpoint: {config.load.source}", fg="cyan")
        trainer.load(config.load.source)

    # fetch and prepare data
    if trainer.is_main_process():
        click.secho("Grabbing data from source", fg="blue", blink=True)

    img_reader = get_reader(
        text_conditioned=trainer.text_conditioned,
        img_url=config.data.image_url,
        meta_url=config.data.meta_url,
    )

    train_loader, eval_loader, test_loader = make_splits(
        text_conditioned=trainer.text_conditioned,
        batch_size=config.data.batch_size,
        num_data_points=NUM_DATA_POINTS,
        train_split=config.data.splits.train,
        eval_split=config.data.splits.val,
        image_reader=img_reader,
        rank=accelerator.state.process_index if exists(accelerator) else 0,
        world_size=accelerator.state.num_processes if exists(accelerator) else 1,
        start=START,
    )

    # wait for everyone to load data before continuing
    trainer.wait_for_everyone()

    # start training
    train(
        trainer=trainer,
        train_loader=train_loader,
        eval_loader=eval_loader,
        test_loader=test_loader,
        config=config,
    )
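`make_splits` above is handed each rank's `process_index` and the world size so the reader can shard `NUM_DATA_POINTS` across processes; the per-rank slicing can be sketched like this (a simplification, not the actual `embedding_reader` logic):

```python
def rank_shard(num_data_points: int, rank: int, world_size: int, start: int = 0):
    # evenly split [start, start + num_data_points) across ranks,
    # giving any remainder to the lowest-numbered ranks
    per_rank, remainder = divmod(num_data_points, world_size)
    begin = start + rank * per_rank + min(rank, remainder)
    end = begin + per_rank + (1 if rank < remainder else 0)
    return begin, end
```

With 10 points over 3 ranks the shards are `(0, 4)`, `(4, 7)`, `(7, 10)`: contiguous, disjoint, and covering the whole range.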
@click.command()
@click.option("--hfa", default=True)
@click.option("--config_path", default="configs/prior.json")
def main(hfa, config_path):
    # start HFA if requested
    if hfa:
        accelerator = Accelerator()
    else:
        accelerator = None

    # load the configuration file on main process
    if not exists(accelerator) or accelerator.is_main_process:
        click.secho(f"Loading configuration from {config_path}", fg="green")

    config = TrainDiffusionPriorConfig.from_json_path(config_path)

    # send config to get processed
    initialize_training(config, accelerator)


if __name__ == "__main__":
    main()