mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
108 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2e85e736f3 | ||
|
|
f5760bdb92 | ||
|
|
c453f468b1 | ||
|
|
98f0c17759 | ||
|
|
a5b9fd6ca8 | ||
|
|
4b994601ae | ||
|
|
fddf66e91e | ||
|
|
c8422ffd5d | ||
|
|
2aadc23c7c | ||
|
|
c098f57e09 | ||
|
|
0021535c26 | ||
|
|
56883910fb | ||
|
|
893f270012 | ||
|
|
f545ce18f4 | ||
|
|
fc7abf624d | ||
|
|
67f0740777 | ||
|
|
138079ca83 | ||
|
|
f5a906f5d3 | ||
|
|
0215237fc6 | ||
|
|
461b91c5c1 | ||
|
|
58892135d9 | ||
|
|
e37072a48c | ||
|
|
41ca896413 | ||
|
|
fe19b508ca | ||
|
|
6651eafa93 | ||
|
|
e6bb75e5ab | ||
|
|
b4c3e5b854 | ||
|
|
b7f9607258 | ||
|
|
2219348a6e | ||
|
|
9eea9b9862 | ||
|
|
5d958713c0 | ||
|
|
0f31980362 | ||
|
|
bee5bf3815 | ||
|
|
350a3d6045 | ||
|
|
1a81670718 | ||
|
|
934c9728dc | ||
|
|
ce4b0107c1 | ||
|
|
64c2f9c4eb | ||
|
|
22cc613278 | ||
|
|
83517849e5 | ||
|
|
708809ed6c | ||
|
|
9cc475f6e7 | ||
|
|
ffd342e9d0 | ||
|
|
f8bfd3493a | ||
|
|
9025345e29 | ||
|
|
8cc278447e | ||
|
|
38cd62010c | ||
|
|
1cc288af39 | ||
|
|
a851168633 | ||
|
|
1ffeecd0ca | ||
|
|
3df899f7a4 | ||
|
|
09534119a1 | ||
|
|
6f8b90d4d7 | ||
|
|
b588286288 | ||
|
|
b693e0be03 | ||
|
|
a0bed30a84 | ||
|
|
387c5bf774 | ||
|
|
a13d2d89c5 | ||
|
|
44d4b1bba9 | ||
|
|
f12a7589c5 | ||
|
|
b8af2210df | ||
|
|
f4fe6c570d | ||
|
|
645e207441 | ||
|
|
00743b3a0b | ||
|
|
01589aff6a | ||
|
|
7ecfd76cc0 | ||
|
|
6161b61c55 | ||
|
|
1ed0f9d80b | ||
|
|
f326a95e26 | ||
|
|
d7a0a2ce4b | ||
|
|
f23fab7ef7 | ||
|
|
857b9fbf1e | ||
|
|
8864fd0aa7 | ||
|
|
72bf159331 | ||
|
|
e5e47cfecb | ||
|
|
fa533962bd | ||
|
|
276abf337b | ||
|
|
ae42d03006 | ||
|
|
4d346e98d9 | ||
|
|
2b1fd1ad2e | ||
|
|
82a2ef37d9 | ||
|
|
5c397c9d66 | ||
|
|
0f4edff214 | ||
|
|
501a8c7c46 | ||
|
|
4e49373fc5 | ||
|
|
49de72040c | ||
|
|
271a376eaf | ||
|
|
e527002472 | ||
|
|
c12e067178 | ||
|
|
c6629c431a | ||
|
|
7ac2fc79f2 | ||
|
|
a1ef023193 | ||
|
|
d49eca62fa | ||
|
|
8aab69b91e | ||
|
|
b432df2f7b | ||
|
|
ebaa0d28c2 | ||
|
|
8b0d459b25 | ||
|
|
0064661729 | ||
|
|
b895f52843 | ||
|
|
80497e9839 | ||
|
|
f526f14d7c | ||
|
|
8997f178d6 | ||
|
|
022c94e443 | ||
|
|
430961cb97 | ||
|
|
721f9687c1 | ||
|
|
e0524a6aff | ||
|
|
c85e0d5c35 | ||
|
|
db0642c4cd |
9
.gitignore
vendored
9
.gitignore
vendored
@@ -1,3 +1,12 @@
|
||||
# default experiment tracker data
|
||||
.tracker-data/
|
||||
|
||||
# Configuration Files
|
||||
configs/*
|
||||
!configs/*.example
|
||||
!configs/*_defaults.py
|
||||
!configs/README.md
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
|
||||
100
README.md
100
README.md
@@ -12,7 +12,7 @@ This model is SOTA for text-to-image for now.
|
||||
|
||||
Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication with the <a href="https://laion.ai/">LAION</a> community | <a href="https://www.youtube.com/watch?v=AIOE1l1W0Tw">Yannic Interview</a>
|
||||
|
||||
There was enough interest for a <a href="https://github.com/lucidrains/dalle2-jax">Jax version</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
|
||||
As of 5/23/22, it is no longer SOTA. SOTA will be <a href="https://github.com/lucidrains/imagen-pytorch">here</a>. Jax versions as well as text-to-video project will be shifted towards the Imagen architecture, as it is way simpler.
|
||||
|
||||
## Status
|
||||
|
||||
@@ -24,6 +24,30 @@ There was enough interest for a <a href="https://github.com/lucidrains/dalle2-ja
|
||||
|
||||
*ongoing at 21k steps*
|
||||
|
||||
- <a href="https://twitter.com/Buntworthy/status/1529475416775434240?t=0GEge3Kr9I36cjcUVCQUTg">Justin Pinkney</a> successfully trained the diffusion prior in the repository for his CLIP to Stylegan2 text-to-image application
|
||||
|
||||
## Pre-Trained Models
|
||||
|
||||
- LAION is training prior models. Checkpoints are available on <a href="https://huggingface.co/zenglishuci/conditioned-prior">🤗huggingface</a> and the training statistics are available on <a href="https://wandb.ai/nousr_laion/conditioned-prior/reports/LAION-DALLE2-PyTorch-Prior--VmlldzoyMDI2OTIx">🐝WANDB</a>.
|
||||
- Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/jkrtg0so?workspace=user-veldrovive">In-progress test run</a> 🚧
|
||||
- Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/3d5rytsa?workspace=">Another test run with sparse attention</a>
|
||||
- DALL-E 2 🚧
|
||||
|
||||
## Appreciation
|
||||
|
||||
This library would not have gotten to this working state without the help of
|
||||
|
||||
- <a href="https://github.com/nousr">Zion</a> for the distributed training code for the diffusion prior
|
||||
- <a href="https://github.com/Veldrovive">Aidan</a> for the distributed training code for the decoder as well as the dataloaders
|
||||
- <a href="https://github.com/krish240574">Kumar</a> for working on the initial diffusion training script
|
||||
- <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management
|
||||
- <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying of critical bugs
|
||||
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice
|
||||
- <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship
|
||||
- <a href="https://huggingface.co">🤗 Huggingface</a> and in particular <a href="https://github.com/sgugger">Sylvain</a> for the <a href="https://github.com/huggingface/accelerate">Accelerate</a> library
|
||||
|
||||
... and many others. Thank you! 🙏
|
||||
|
||||
## Install
|
||||
|
||||
```bash
|
||||
@@ -936,7 +960,7 @@ from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embed
|
||||
|
||||
# Create a dataloader directly.
|
||||
dataloader = create_image_embedding_dataloader(
|
||||
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses braket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
|
||||
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses bracket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
|
||||
embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
|
||||
num_workers=4,
|
||||
batch_size=32,
|
||||
@@ -993,33 +1017,6 @@ The most significant parameters for the script are as follows:
|
||||
|
||||
- `clip`, default = `None` # Signals the prior to use pre-computed embeddings
|
||||
|
||||
#### Loading and Saving the DiffusionPrior model
|
||||
|
||||
Two methods are provided, load_diffusion_model and save_diffusion_model, the names being self-explanatory.
|
||||
|
||||
```python
|
||||
from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model
|
||||
```
|
||||
|
||||
##### Loading
|
||||
|
||||
load_diffusion_model(dprior_path, device)
|
||||
dprior_path : path to saved model(.pth)
|
||||
device : the cuda device you're running on
|
||||
|
||||
##### Saving
|
||||
|
||||
save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim)
|
||||
save_path : path to save at
|
||||
model : object of Diffusion_Prior
|
||||
optimizer : optimizer object - see train_diffusion_prior.py for how to create one.
|
||||
e.g: optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
|
||||
scaler : a GradScaler object.
|
||||
e.g: scaler = GradScaler(enabled=amp)
|
||||
config : config object created in train_diffusion_prior.py - see file for example.
|
||||
image_embed_dim - the dimension of the image_embedding
|
||||
e.g: 768
|
||||
|
||||
## CLI (wip)
|
||||
|
||||
```bash
|
||||
@@ -1064,22 +1061,18 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
- [x] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
|
||||
- [x] cross embed layers for downsampling, as an option
|
||||
- [x] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
|
||||
- [x] use pydantic for config drive training
|
||||
- [x] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
|
||||
- [x] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
|
||||
- [x] allow for creation of diffusion prior model off pydantic config classes - consider the same for tracker configs
|
||||
- [x] bring in skip-layer excitations (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training (doesnt work well)
|
||||
- [x] test out grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697 (keeping, seems to be fine)
|
||||
- [x] allow for unet to be able to condition non-cross attention style as well
|
||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
|
||||
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
|
||||
- [ ] train on a toy task, offer in colab
|
||||
- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
|
||||
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
|
||||
- [ ] speed up inference, read up on papers (ddim or diffusion-gan, etc)
|
||||
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
|
||||
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
|
||||
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
|
||||
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
|
||||
- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
|
||||
- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
|
||||
- [ ] decoder needs one day worth of refactor for tech debt
|
||||
- [ ] allow for unet to be able to condition non-cross attention style as well
|
||||
- [ ] for all model classes with hyperparameters that changes the network architecture, make it requirement that they must expose a config property, and write a simple function that asserts that it restores the object correctly
|
||||
- [ ] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
|
||||
- [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89
|
||||
- [ ] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
|
||||
|
||||
## Citations
|
||||
|
||||
@@ -1122,8 +1115,9 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
```bibtex
|
||||
@inproceedings{Tu2022MaxViTMV,
|
||||
title = {MaxViT: Multi-Axis Vision Transformer},
|
||||
author = {Zhe-Wei Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
|
||||
year = {2022}
|
||||
author = {Zhengzhong Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
|
||||
year = {2022},
|
||||
url = {https://arxiv.org/abs/2204.01697}
|
||||
}
|
||||
```
|
||||
|
||||
@@ -1177,4 +1171,22 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{Saharia2022,
|
||||
title = {Imagen: unprecedented photorealism × deep level of language understanding},
|
||||
author = {Chitwan Saharia*, William Chan*, Saurabh Saxena†, Lala Li†, Jay Whang†, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S. Sara Mahdavi, Rapha Gontijo Lopes, Tim Salimans, Jonathan Ho†, David Fleet†, Mohammad Norouzi*},
|
||||
year = {2022}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{Choi2022PerceptionPT,
|
||||
title = {Perception Prioritized Training of Diffusion Models},
|
||||
author = {Jooyoung Choi and Jungbeom Lee and Chaehun Shin and Sungwon Kim and Hyunwoo J. Kim and Sung-Hoon Yoon},
|
||||
journal = {ArXiv},
|
||||
year = {2022},
|
||||
volume = {abs/2204.00227}
|
||||
}
|
||||
```
|
||||
|
||||
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
|
||||
|
||||
111
configs/README.md
Normal file
111
configs/README.md
Normal file
@@ -0,0 +1,111 @@
|
||||
## DALLE2 Training Configurations
|
||||
|
||||
For more complex configuration, we provide the option of using a configuration file instead of command line arguments.
|
||||
|
||||
### Decoder Trainer
|
||||
|
||||
The decoder trainer has 7 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.example.json).
|
||||
|
||||
**<ins>Unet</ins>:**
|
||||
|
||||
This is a single unet config, which belongs as an array nested under the decoder config as a list of `unets`
|
||||
|
||||
| Option | Required | Default | Description |
|
||||
| ------ | -------- | ------- | ----------- |
|
||||
| `dim` | Yes | N/A | The starting channels of the unet. |
|
||||
| `image_embed_dim` | Yes | N/A | The dimension of the image embeddings. |
|
||||
| `dim_mults` | No | `(1, 2, 4, 8)` | The growth factors of the channels. |
|
||||
|
||||
Any parameter from the `Unet` constructor can also be given here.
|
||||
|
||||
**<ins>Decoder</ins>:**
|
||||
|
||||
Defines the configuration options for the decoder model. The unets defined above will automatically be inserted.
|
||||
| Option | Required | Default | Description |
|
||||
| ------ | -------- | ------- | ----------- |
|
||||
| `unets` | Yes | N/A | A list of unets, using the configuration above |
|
||||
| `image_sizes` | Yes | N/A | The resolution of the image after each upsampling step. The length of this array should be the number of unets defined. |
|
||||
| `image_size` | Yes | N/A | Not used. Can be any number. |
|
||||
| `timesteps` | No | `1000` | The number of diffusion timesteps used for generation. |
|
||||
| `loss_type` | No | `l2` | The loss function. Options are `l1`, `huber`, or `l2`. |
|
||||
| `beta_schedule` | No | `cosine` | The noising schedule. Options are `cosine`, `linear`, `quadratic`, `jsd`, or `sigmoid`. |
|
||||
| `learned_variance` | No | `True` | Whether to learn the variance. |
|
||||
|
||||
Any parameter from the `Decoder` constructor can also be given here.
|
||||
|
||||
**<ins>Data</ins>:**
|
||||
|
||||
Settings for creation of the dataloaders.
|
||||
| Option | Required | Default | Description |
|
||||
| ------ | -------- | ------- | ----------- |
|
||||
| `webdataset_base_url` | Yes | N/A | The url of a shard in the webdataset with the shard replaced with `{}`[^1]. |
|
||||
| `embeddings_url` | No | N/A | The url of the folder containing embeddings shards. Not required if embeddings are in webdataset. |
|
||||
| `num_workers` | No | `4` | The number of workers used in the dataloader. |
|
||||
| `batch_size` | No | `64` | The batch size. |
|
||||
| `start_shard` | No | `0` | Defines the start of the shard range the dataset will recall. |
|
||||
| `end_shard` | No | `9999999` | Defines the end of the shard range the dataset will recall. |
|
||||
| `shard_width` | No | `6` | Defines the width of one webdataset shard number[^2]. |
|
||||
| `index_width` | No | `4` | Defines the width of the index of a file inside a shard[^3]. |
|
||||
| `splits` | No | `{ "train": 0.75, "val": 0.15, "test": 0.1 }` | Defines the proportion of shards that will be allocated to the training, validation, and testing datasets. |
|
||||
| `shuffle_train` | No | `True` | Whether to shuffle the shards of the training dataset. |
|
||||
| `resample_train` | No | `False` | If true, shards will be randomly sampled with replacement from the datasets making the epoch length infinite if a limit is not set. Cannot be enabled if `shuffle_train` is enabled. |
|
||||
| `preprocessing` | No | `{ "ToTensor": True }` | Defines preprocessing applied to images from the datasets. |
|
||||
|
||||
[^1]: If your shard files have the paths `protocol://path/to/shard/00104.tar`, then the base url would be `protocol://path/to/shard/{}.tar`. If you are using a protocol like `s3`, you need to pipe the tars. For example `pipe:s3cmd get s3://bucket/path/{}.tar -`.
|
||||
|
||||
[^2]: This refers to the string length of the shard number for your webdataset shards. For instance, if your webdataset shard has the filename `00104.tar`, your shard length is 5.
|
||||
|
||||
[^3]: Inside the webdataset `tar`, you have files named something like `001045945.jpg`. 5 of these characters refer to the shard, and 4 refer to the index of the file in the webdataset (shard is `001041` and index is `5945`). The `index_width` in this case is 4.
|
||||
|
||||
**<ins>Train</ins>:**
|
||||
|
||||
Settings for controlling the training hyperparameters.
|
||||
| Option | Required | Default | Description |
|
||||
| ------ | -------- | ------- | ----------- |
|
||||
| `epochs` | No | `20` | The number of epochs in the training run. |
|
||||
| `lr` | No | `1e-4` | The learning rate. |
|
||||
| `wd` | No | `0.01` | The weight decay. |
|
||||
| `max_grad_norm`| No | `0.5` | The grad norm clipping. |
|
||||
| `save_every_n_samples` | No | `100000` | Samples will be generated and a checkpoint will be saved every `save_every_n_samples` samples. |
|
||||
| `device` | No | `cuda:0` | The device to train on. |
|
||||
| `epoch_samples` | No | `None` | Limits the number of samples iterated through in each epoch. This must be set if resampling. None means no limit. |
|
||||
| `validation_samples` | No | `None` | The number of samples to use for validation. None mean the entire validation set. |
|
||||
| `use_ema` | No | `True` | Whether to use exponential moving average models for sampling. |
|
||||
| `ema_beta` | No | `0.99` | The ema coefficient. |
|
||||
| `save_all` | No | `False` | If True, preserves a checkpoint for every epoch. |
|
||||
| `save_latest` | No | `True` | If True, overwrites the `latest.pth` every time the model is saved. |
|
||||
| `save_best` | No | `True` | If True, overwrites the `best.pth` every time the model has a lower validation loss than all previous models. |
|
||||
| `unet_training_mask` | No | `None` | A boolean array of the same length as the number of unets. If false, the unet is frozen. A value of `None` trains all unets. |
|
||||
|
||||
**<ins>Evaluate</ins>:**
|
||||
|
||||
Defines which evaluation metrics will be used to test the model.
|
||||
Each metric can be enabled by setting its configuration. The configuration keys for each metric are defined by the torchmetrics constructors which will be linked.
|
||||
| Option | Required | Default | Description |
|
||||
| ------ | -------- | ------- | ----------- |
|
||||
| `n_evaluation_samples` | No | `1000` | The number of samples to generate to test the model. |
|
||||
| `FID` | No | `None` | Setting to an object enables the [Frechet Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/frechet_inception_distance.html) metric.
|
||||
| `IS` | No | `None` | Setting to an object enables the [Inception Score](https://torchmetrics.readthedocs.io/en/stable/image/inception_score.html) metric.
|
||||
| `KID` | No | `None` | Setting to an object enables the [Kernel Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/kernel_inception_distance.html) metric. |
|
||||
| `LPIPS` | No | `None` | Setting to an object enables the [Learned Perceptual Image Patch Similarity](https://torchmetrics.readthedocs.io/en/stable/image/learned_perceptual_image_patch_similarity.html) metric. |
|
||||
|
||||
**<ins>Tracker</ins>:**
|
||||
|
||||
Selects which tracker to use and configures it.
|
||||
| Option | Required | Default | Description |
|
||||
| ------ | -------- | ------- | ----------- |
|
||||
| `tracker_type` | No | `console` | Which tracker to use. Currently accepts `console` or `wandb`. |
|
||||
| `data_path` | No | `./models` | Where the tracker will store local data. |
|
||||
| `verbose` | No | `False` | Enables console logging for non-console trackers. |
|
||||
|
||||
Other configuration options are required for the specific trackers. To see which are required, reference the initializer parameters of each [tracker](../dalle2_pytorch/trackers.py).
|
||||
|
||||
**<ins>Load</ins>:**
|
||||
|
||||
Selects where to load a pretrained model from.
|
||||
| Option | Required | Default | Description |
|
||||
| ------ | -------- | ------- | ----------- |
|
||||
| `source` | No | `None` | Supports `file` or `wandb`. |
|
||||
| `resume` | No | `False` | If the tracker support resuming the run, resume it. |
|
||||
|
||||
Other configuration options are required for loading from a specific source. To see which are required, reference the load methods at the top of the [tracker file](../dalle2_pytorch/trackers.py).
|
||||
99
configs/train_decoder_config.example.json
Normal file
99
configs/train_decoder_config.example.json
Normal file
@@ -0,0 +1,99 @@
|
||||
{
|
||||
"decoder": {
|
||||
"unets": [
|
||||
{
|
||||
"dim": 128,
|
||||
"image_embed_dim": 768,
|
||||
"cond_dim": 64,
|
||||
"channels": 3,
|
||||
"dim_mults": [1, 2, 4, 8],
|
||||
"attn_dim_head": 32,
|
||||
"attn_heads": 16
|
||||
}
|
||||
],
|
||||
"image_sizes": [64],
|
||||
"channels": 3,
|
||||
"timesteps": 1000,
|
||||
"loss_type": "l2",
|
||||
"beta_schedule": ["cosine"],
|
||||
"learned_variance": true
|
||||
},
|
||||
"data": {
|
||||
"webdataset_base_url": "pipe:s3cmd get s3://bucket/path/{}.tar -",
|
||||
"embeddings_url": "s3://bucket/embeddings/path/",
|
||||
"num_workers": 4,
|
||||
"batch_size": 64,
|
||||
"start_shard": 0,
|
||||
"end_shard": 9999999,
|
||||
"shard_width": 6,
|
||||
"index_width": 4,
|
||||
"splits": {
|
||||
"train": 0.75,
|
||||
"val": 0.15,
|
||||
"test": 0.1
|
||||
},
|
||||
"shuffle_train": true,
|
||||
"resample_train": false,
|
||||
"preprocessing": {
|
||||
"RandomResizedCrop": {
|
||||
"size": [128, 128],
|
||||
"scale": [0.75, 1.0],
|
||||
"ratio": [1.0, 1.0]
|
||||
},
|
||||
"ToTensor": true
|
||||
}
|
||||
},
|
||||
"train": {
|
||||
"epochs": 20,
|
||||
"lr": 1e-4,
|
||||
"wd": 0.01,
|
||||
"max_grad_norm": 0.5,
|
||||
"save_every_n_samples": 100000,
|
||||
"n_sample_images": 6,
|
||||
"device": "cuda:0",
|
||||
"epoch_samples": null,
|
||||
"validation_samples": null,
|
||||
"use_ema": true,
|
||||
"ema_beta": 0.99,
|
||||
"amp": false,
|
||||
"save_all": false,
|
||||
"save_latest": true,
|
||||
"save_best": true,
|
||||
"unet_training_mask": [true]
|
||||
},
|
||||
"evaluate": {
|
||||
"n_evaluation_samples": 1000,
|
||||
"FID": {
|
||||
"feature": 64
|
||||
},
|
||||
"IS": {
|
||||
"feature": 64,
|
||||
"splits": 10
|
||||
},
|
||||
"KID": {
|
||||
"feature": 64,
|
||||
"subset_size": 10
|
||||
},
|
||||
"LPIPS": {
|
||||
"net_type": "vgg",
|
||||
"reduction": "mean"
|
||||
}
|
||||
},
|
||||
"tracker": {
|
||||
"tracker_type": "console",
|
||||
"data_path": "./models",
|
||||
|
||||
"wandb_entity": "",
|
||||
"wandb_project": "",
|
||||
|
||||
"verbose": false
|
||||
},
|
||||
"load": {
|
||||
"source": null,
|
||||
|
||||
"run_path": "",
|
||||
"file_path": "",
|
||||
|
||||
"resume": false
|
||||
}
|
||||
}
|
||||
70
configs/train_prior_config.example.json
Normal file
70
configs/train_prior_config.example.json
Normal file
@@ -0,0 +1,70 @@
|
||||
{
|
||||
"prior": {
|
||||
"clip": {
|
||||
"make": "x-clip",
|
||||
"model": "ViT-L/14",
|
||||
"base_model_kwargs": {
|
||||
"dim_text": 768,
|
||||
"dim_image": 768,
|
||||
"dim_latent": 768
|
||||
}
|
||||
},
|
||||
"net": {
|
||||
"dim": 768,
|
||||
"depth": 12,
|
||||
"num_timesteps": 1000,
|
||||
"num_time_embeds": 1,
|
||||
"num_image_embeds": 1,
|
||||
"num_text_embeds": 1,
|
||||
"dim_head": 64,
|
||||
"heads": 12,
|
||||
"ff_mult": 4,
|
||||
"norm_out": true,
|
||||
"attn_dropout": 0.0,
|
||||
"ff_dropout": 0.0,
|
||||
"final_proj": true,
|
||||
"normformer": true,
|
||||
"rotary_emb": true
|
||||
},
|
||||
"image_embed_dim": 768,
|
||||
"image_size": 224,
|
||||
"image_channels": 3,
|
||||
"timesteps": 1000,
|
||||
"cond_drop_prob": 0.1,
|
||||
"loss_type": "l2",
|
||||
"predict_x_start": true,
|
||||
"beta_schedule": "cosine",
|
||||
"condition_on_text_encodings": true
|
||||
},
|
||||
"data": {
|
||||
"image_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/",
|
||||
"text_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/",
|
||||
"meta_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/laion2B-en-metadata/",
|
||||
"batch_size": 256,
|
||||
"splits": {
|
||||
"train": 0.9,
|
||||
"val": 1e-7,
|
||||
"test": 0.0999999
|
||||
}
|
||||
},
|
||||
"train": {
|
||||
"epochs": 1,
|
||||
"lr": 1.1e-4,
|
||||
"wd": 6.02e-2,
|
||||
"max_grad_norm": 0.5,
|
||||
"use_ema": true,
|
||||
"amp": false,
|
||||
"save_every": 10000
|
||||
},
|
||||
"load": {
|
||||
"source": null,
|
||||
"resume": false
|
||||
},
|
||||
"tracker": {
|
||||
"tracker_type": "wandb",
|
||||
"data_path": "./prior_checkpoints",
|
||||
"wandb_entity": "laion",
|
||||
"wandb_project": "diffusion-prior",
|
||||
"verbose": true
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
from dalle2_pytorch.version import __version__
|
||||
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
|
||||
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
|
||||
from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import math
|
||||
from tqdm import tqdm
|
||||
from inspect import isfunction
|
||||
import random
|
||||
from tqdm.auto import tqdm
|
||||
from functools import partial, wraps
|
||||
from contextlib import contextmanager
|
||||
from collections import namedtuple
|
||||
@@ -11,7 +11,7 @@ import torch.nn.functional as F
|
||||
from torch import nn, einsum
|
||||
import torchvision.transforms as T
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops import rearrange, repeat, reduce
|
||||
from einops.layers.torch import Rearrange
|
||||
from einops_exts import rearrange_many, repeat_many, check_shape
|
||||
from einops_exts.torch import EinopsToAndFrom
|
||||
@@ -56,9 +56,12 @@ def maybe(fn):
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
return d() if callable(d) else d
|
||||
|
||||
def cast_tuple(val, length = 1):
|
||||
if isinstance(val, list):
|
||||
val = tuple(val)
|
||||
|
||||
return val if isinstance(val, tuple) else ((val,) * length)
|
||||
|
||||
def module_device(module):
|
||||
@@ -310,11 +313,6 @@ def extract(a, t, x_shape):
|
||||
out = a.gather(-1, t)
|
||||
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
||||
|
||||
def noise_like(shape, device, repeat=False):
|
||||
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
|
||||
noise = lambda: torch.randn(shape, device=device)
|
||||
return repeat_noise() if repeat else noise()
|
||||
|
||||
def meanflat(x):
|
||||
return x.mean(dim = tuple(range(1, len(x.shape))))
|
||||
|
||||
@@ -369,7 +367,7 @@ def quadratic_beta_schedule(timesteps):
|
||||
scale = 1000 / timesteps
|
||||
beta_start = scale * 0.0001
|
||||
beta_end = scale * 0.02
|
||||
return torch.linspace(beta_start**2, beta_end**2, timesteps, dtype = torch.float64) ** 2
|
||||
return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype = torch.float64) ** 2
|
||||
|
||||
|
||||
def sigmoid_beta_schedule(timesteps):
|
||||
@@ -380,8 +378,8 @@ def sigmoid_beta_schedule(timesteps):
|
||||
return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
|
||||
|
||||
|
||||
class BaseGaussianDiffusion(nn.Module):
|
||||
def __init__(self, *, beta_schedule, timesteps, loss_type):
|
||||
class NoiseScheduler(nn.Module):
|
||||
def __init__(self, *, beta_schedule, timesteps, loss_type, p2_loss_weight_gamma = 0., p2_loss_weight_k = 1):
|
||||
super().__init__()
|
||||
|
||||
if beta_schedule == "cosine":
|
||||
@@ -446,6 +444,11 @@ class BaseGaussianDiffusion(nn.Module):
|
||||
register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
||||
register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
|
||||
|
||||
# p2 loss reweighting
|
||||
|
||||
self.has_p2_loss_reweighting = p2_loss_weight_gamma > 0.
|
||||
register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma)
|
||||
|
||||
def q_posterior(self, x_start, x_t, t):
|
||||
posterior_mean = (
|
||||
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
|
||||
@@ -469,11 +472,10 @@ class BaseGaussianDiffusion(nn.Module):
|
||||
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
|
||||
)
|
||||
|
||||
def sample(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
def p2_reweigh_loss(self, loss, times):
|
||||
if not self.has_p2_loss_reweighting:
|
||||
return loss
|
||||
return loss * extract(self.p2_loss_weight, times, loss.shape)
|
||||
|
||||
# diffusion prior
|
||||
|
||||
@@ -684,8 +686,7 @@ class Attention(nn.Module):
|
||||
|
||||
# attention
|
||||
|
||||
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
|
||||
attn = sim.softmax(dim = -1)
|
||||
attn = sim.softmax(dim = -1, dtype = torch.float32)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
# aggregate values
|
||||
@@ -859,7 +860,7 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
|
||||
return pred_image_embed
|
||||
|
||||
class DiffusionPrior(BaseGaussianDiffusion):
|
||||
class DiffusionPrior(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
net,
|
||||
@@ -880,13 +881,17 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
|
||||
clip_adapter_overrides = dict()
|
||||
):
|
||||
super().__init__(
|
||||
super().__init__()
|
||||
|
||||
self.noise_scheduler = NoiseScheduler(
|
||||
beta_schedule = beta_schedule,
|
||||
timesteps = timesteps,
|
||||
loss_type = loss_type
|
||||
)
|
||||
|
||||
if exists(clip):
|
||||
assert image_channels == clip.image_channels, f'channels of image ({image_channels}) should be equal to the channels that CLIP accepts ({clip.image_channels})'
|
||||
|
||||
if isinstance(clip, CLIP):
|
||||
clip = XClipAdapter(clip, **clip_adapter_overrides)
|
||||
elif isinstance(clip, CoCa):
|
||||
@@ -918,6 +923,13 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
self.training_clamp_l2norm = training_clamp_l2norm
|
||||
self.init_image_embed_l2norm = init_image_embed_l2norm
|
||||
|
||||
# device tracker
|
||||
self.register_buffer('_dummy', torch.tensor([True]), persistent = False)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self._dummy.device
|
||||
|
||||
def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.):
|
||||
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
|
||||
|
||||
@@ -928,7 +940,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
# not 100% sure of this above line - for any spectators, let me know in the github issues (or through a pull request) if you know how to correctly do this
|
||||
# i'll be rereading https://arxiv.org/abs/2111.14822, where i think a similar approach is taken
|
||||
else:
|
||||
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
|
||||
x_recon = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
|
||||
|
||||
if clip_denoised and not self.predict_x_start:
|
||||
x_recon.clamp_(-1., 1.)
|
||||
@@ -936,21 +948,21 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
if self.predict_x_start and self.sampling_clamp_l2norm:
|
||||
x_recon = l2norm(x_recon) * self.image_embed_scale
|
||||
|
||||
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
||||
model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t)
|
||||
return model_mean, posterior_variance, posterior_log_variance
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample(self, x, t, text_cond = None, clip_denoised = True, repeat_noise = False, cond_scale = 1.):
|
||||
def p_sample(self, x, t, text_cond = None, clip_denoised = True, cond_scale = 1.):
|
||||
b, *_, device = *x.shape, x.device
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
|
||||
noise = noise_like(x.shape, device, repeat_noise)
|
||||
noise = torch.randn_like(x)
|
||||
# no noise when t == 0
|
||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_loop(self, shape, text_cond, cond_scale = 1.):
|
||||
device = self.betas.device
|
||||
device = self.device
|
||||
|
||||
b = shape[0]
|
||||
image_embed = torch.randn(shape, device=device)
|
||||
@@ -958,7 +970,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
if self.init_image_embed_l2norm:
|
||||
image_embed = l2norm(image_embed) * self.image_embed_scale
|
||||
|
||||
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
|
||||
for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps):
|
||||
times = torch.full((b,), i, device = device, dtype = torch.long)
|
||||
image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)
|
||||
|
||||
@@ -967,7 +979,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
def p_losses(self, image_embed, times, text_cond, noise = None):
|
||||
noise = default(noise, lambda: torch.randn_like(image_embed))
|
||||
|
||||
image_embed_noisy = self.q_sample(x_start = image_embed, t = times, noise = noise)
|
||||
image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise)
|
||||
|
||||
pred = self.net(
|
||||
image_embed_noisy,
|
||||
@@ -981,7 +993,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
|
||||
target = noise if not self.predict_x_start else image_embed
|
||||
|
||||
loss = self.loss_fn(pred, target)
|
||||
loss = self.noise_scheduler.loss_fn(pred, target)
|
||||
return loss
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -992,7 +1004,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
|
||||
img = torch.randn(shape, device = device)
|
||||
|
||||
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
|
||||
for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = self.noise_scheduler.num_timesteps):
|
||||
img = self.p_sample(img, torch.full((batch_size,), i, device = device, dtype = torch.long), text_cond = text_cond, cond_scale = cond_scale)
|
||||
return img
|
||||
|
||||
@@ -1064,7 +1076,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
# timestep conditioning from ddpm
|
||||
|
||||
batch, device = image_embed.shape[0], image_embed.device
|
||||
times = torch.randint(0, self.num_timesteps, (batch,), device = device, dtype = torch.long)
|
||||
times = torch.randint(0, self.noise_scheduler.num_timesteps, (batch,), device = device, dtype = torch.long)
|
||||
|
||||
# scale image embed (Katherine)
|
||||
|
||||
@@ -1079,8 +1091,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
def Upsample(dim):
|
||||
return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
|
||||
|
||||
def Downsample(dim):
|
||||
return nn.Conv2d(dim, dim, 4, 2, 1)
|
||||
def Downsample(dim, *, dim_out = None):
|
||||
dim_out = default(dim_out, dim)
|
||||
return nn.Conv2d(dim, dim_out, 4, 2, 1)
|
||||
|
||||
class SinusoidalPosEmb(nn.Module):
|
||||
def __init__(self, dim):
|
||||
@@ -1102,13 +1115,20 @@ class Block(nn.Module):
|
||||
groups = 8
|
||||
):
|
||||
super().__init__()
|
||||
self.block = nn.Sequential(
|
||||
nn.Conv2d(dim, dim_out, 3, padding = 1),
|
||||
nn.GroupNorm(groups, dim_out),
|
||||
nn.SiLU()
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
self.project = nn.Conv2d(dim, dim_out, 3, padding = 1)
|
||||
self.norm = nn.GroupNorm(groups, dim_out)
|
||||
self.act = nn.SiLU()
|
||||
|
||||
def forward(self, x, scale_shift = None):
|
||||
x = self.project(x)
|
||||
x = self.norm(x)
|
||||
|
||||
if exists(scale_shift):
|
||||
scale, shift = scale_shift
|
||||
x = x * (scale + 1) + shift
|
||||
|
||||
x = self.act(x)
|
||||
return x
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(
|
||||
@@ -1127,7 +1147,7 @@ class ResnetBlock(nn.Module):
|
||||
if exists(time_cond_dim):
|
||||
self.time_mlp = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(time_cond_dim, dim_out)
|
||||
nn.Linear(time_cond_dim, dim_out * 2)
|
||||
)
|
||||
|
||||
self.cross_attn = None
|
||||
@@ -1147,11 +1167,14 @@ class ResnetBlock(nn.Module):
|
||||
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
|
||||
|
||||
def forward(self, x, cond = None, time_emb = None):
|
||||
h = self.block1(x)
|
||||
|
||||
scale_shift = None
|
||||
if exists(self.time_mlp) and exists(time_emb):
|
||||
time_emb = self.time_mlp(time_emb)
|
||||
h = rearrange(time_emb, 'b c -> b c 1 1') + h
|
||||
time_emb = rearrange(time_emb, 'b c -> b c 1 1')
|
||||
scale_shift = time_emb.chunk(2, dim = 1)
|
||||
|
||||
h = self.block1(x, scale_shift = scale_shift)
|
||||
|
||||
if exists(self.cross_attn):
|
||||
assert exists(cond)
|
||||
@@ -1218,8 +1241,7 @@ class CrossAttention(nn.Module):
|
||||
mask = rearrange(mask, 'b j -> b 1 1 j')
|
||||
sim = sim.masked_fill(~mask, max_neg_value)
|
||||
|
||||
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
|
||||
attn = sim.softmax(dim = -1)
|
||||
attn = sim.softmax(dim = -1, dtype = torch.float32)
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
@@ -1328,12 +1350,15 @@ class Unet(nn.Module):
|
||||
cond_on_text_encodings = False,
|
||||
max_text_len = 256,
|
||||
cond_on_image_embeds = False,
|
||||
add_image_embeds_to_time = True, # alerted by @mhh0318 to a phrase in the paper - "Specifically, we modify the architecture described in Nichol et al. (2021) by projecting and adding CLIP embeddings to the existing timestep embedding"
|
||||
init_dim = None,
|
||||
init_conv_kernel_size = 7,
|
||||
resnet_groups = 8,
|
||||
num_resnet_blocks = 2,
|
||||
init_cross_embed_kernel_sizes = (3, 7, 15),
|
||||
cross_embed_downsample = False,
|
||||
cross_embed_downsample_kernel_sizes = (2, 4),
|
||||
memory_efficient = False,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
@@ -1353,7 +1378,7 @@ class Unet(nn.Module):
|
||||
self.channels_out = default(channels_out, channels)
|
||||
|
||||
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
|
||||
init_dim = default(init_dim, dim // 3 * 2)
|
||||
init_dim = default(init_dim, dim)
|
||||
|
||||
self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1)
|
||||
|
||||
@@ -1380,11 +1405,16 @@ class Unet(nn.Module):
|
||||
nn.Linear(time_cond_dim, time_cond_dim)
|
||||
)
|
||||
|
||||
self.image_to_cond = nn.Sequential(
|
||||
self.image_to_tokens = nn.Sequential(
|
||||
nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
|
||||
Rearrange('b (n d) -> b n d', n = num_image_tokens)
|
||||
) if cond_on_image_embeds and image_embed_dim != cond_dim else nn.Identity()
|
||||
|
||||
self.to_image_hiddens = nn.Sequential(
|
||||
nn.Linear(image_embed_dim, time_cond_dim),
|
||||
nn.GELU()
|
||||
) if cond_on_image_embeds and add_image_embeds_to_time else None
|
||||
|
||||
self.norm_cond = nn.LayerNorm(cond_dim)
|
||||
self.norm_mid_cond = nn.LayerNorm(cond_dim)
|
||||
|
||||
@@ -1405,6 +1435,7 @@ class Unet(nn.Module):
|
||||
# for classifier free guidance
|
||||
|
||||
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
|
||||
self.null_image_hiddens = nn.Parameter(torch.randn(1, time_cond_dim))
|
||||
|
||||
self.max_text_len = max_text_len
|
||||
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
|
||||
@@ -1416,6 +1447,7 @@ class Unet(nn.Module):
|
||||
# resnet block klass
|
||||
|
||||
resnet_groups = cast_tuple(resnet_groups, len(in_out))
|
||||
num_resnet_blocks = cast_tuple(num_resnet_blocks, len(in_out))
|
||||
|
||||
assert len(resnet_groups) == len(in_out)
|
||||
|
||||
@@ -1431,16 +1463,17 @@ class Unet(nn.Module):
|
||||
self.ups = nn.ModuleList([])
|
||||
num_resolutions = len(in_out)
|
||||
|
||||
for ind, ((dim_in, dim_out), groups) in enumerate(zip(in_out, resnet_groups)):
|
||||
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks)):
|
||||
is_first = ind == 0
|
||||
is_last = ind >= (num_resolutions - 1)
|
||||
layer_cond_dim = cond_dim if not is_first else None
|
||||
|
||||
self.downs.append(nn.ModuleList([
|
||||
ResnetBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
|
||||
downsample_klass(dim_in, dim_out = dim_out) if memory_efficient else None,
|
||||
ResnetBlock(dim_out if memory_efficient else dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
|
||||
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||
ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
|
||||
downsample_klass(dim_out) if not is_last else nn.Identity()
|
||||
nn.ModuleList([ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
|
||||
downsample_klass(dim_out) if not is_last and not memory_efficient else None
|
||||
]))
|
||||
|
||||
mid_dim = dims[-1]
|
||||
@@ -1449,19 +1482,19 @@ class Unet(nn.Module):
|
||||
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
|
||||
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
|
||||
|
||||
for ind, ((dim_in, dim_out), groups) in enumerate(zip(reversed(in_out[1:]), reversed(resnet_groups))):
|
||||
is_last = ind >= (num_resolutions - 2)
|
||||
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(reversed(in_out), reversed(resnet_groups), reversed(num_resnet_blocks))):
|
||||
is_last = ind >= (len(in_out) - 1)
|
||||
layer_cond_dim = cond_dim if not is_last else None
|
||||
|
||||
self.ups.append(nn.ModuleList([
|
||||
ResnetBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
|
||||
Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||
ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
|
||||
Upsample(dim_in)
|
||||
nn.ModuleList([ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
|
||||
Upsample(dim_in) if not is_last or memory_efficient else nn.Identity()
|
||||
]))
|
||||
|
||||
self.final_conv = nn.Sequential(
|
||||
ResnetBlock(dim, dim, groups = resnet_groups[0]),
|
||||
ResnetBlock(dim * 2, dim, groups = resnet_groups[0]),
|
||||
nn.Conv2d(dim, self.channels_out, 1)
|
||||
)
|
||||
|
||||
@@ -1533,6 +1566,7 @@ class Unet(nn.Module):
|
||||
# initial convolution
|
||||
|
||||
x = self.init_conv(x)
|
||||
r = x.clone() # final residual
|
||||
|
||||
# time conditioning
|
||||
|
||||
@@ -1546,7 +1580,23 @@ class Unet(nn.Module):
|
||||
image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
|
||||
text_keep_mask = prob_mask_like((batch_size,), 1 - text_cond_drop_prob, device = device)
|
||||
|
||||
image_keep_mask, text_keep_mask = rearrange_many((image_keep_mask, text_keep_mask), 'b -> b 1 1')
|
||||
text_keep_mask = rearrange(text_keep_mask, 'b -> b 1 1')
|
||||
|
||||
# image embedding to be summed to time embedding
|
||||
# discovered by @mhh0318 in the paper
|
||||
|
||||
if exists(image_embed) and exists(self.to_image_hiddens):
|
||||
image_hiddens = self.to_image_hiddens(image_embed)
|
||||
image_keep_mask_hidden = rearrange(image_keep_mask, 'b -> b 1')
|
||||
null_image_hiddens = self.null_image_hiddens.to(image_hiddens.dtype)
|
||||
|
||||
image_hiddens = torch.where(
|
||||
image_keep_mask_hidden,
|
||||
image_hiddens,
|
||||
null_image_hiddens
|
||||
)
|
||||
|
||||
t = t + image_hiddens
|
||||
|
||||
# mask out image embedding depending on condition dropout
|
||||
# for classifier free guidance
|
||||
@@ -1554,11 +1604,12 @@ class Unet(nn.Module):
|
||||
image_tokens = None
|
||||
|
||||
if self.cond_on_image_embeds:
|
||||
image_tokens = self.image_to_cond(image_embed)
|
||||
image_keep_mask_embed = rearrange(image_keep_mask, 'b -> b 1 1')
|
||||
image_tokens = self.image_to_tokens(image_embed)
|
||||
null_image_embed = self.null_image_embed.to(image_tokens.dtype) # for some reason pytorch AMP not working
|
||||
|
||||
image_tokens = torch.where(
|
||||
image_keep_mask,
|
||||
image_keep_mask_embed,
|
||||
image_tokens,
|
||||
null_image_embed
|
||||
)
|
||||
@@ -1613,12 +1664,20 @@ class Unet(nn.Module):
|
||||
|
||||
hiddens = []
|
||||
|
||||
for block1, sparse_attn, block2, downsample in self.downs:
|
||||
x = block1(x, c, t)
|
||||
for pre_downsample, init_block, sparse_attn, resnet_blocks, post_downsample in self.downs:
|
||||
if exists(pre_downsample):
|
||||
x = pre_downsample(x)
|
||||
|
||||
x = init_block(x, c, t)
|
||||
x = sparse_attn(x)
|
||||
x = block2(x, c, t)
|
||||
|
||||
for resnet_block in resnet_blocks:
|
||||
x = resnet_block(x, c, t)
|
||||
|
||||
hiddens.append(x)
|
||||
x = downsample(x)
|
||||
|
||||
if exists(post_downsample):
|
||||
x = post_downsample(x)
|
||||
|
||||
x = self.mid_block1(x, mid_c, t)
|
||||
|
||||
@@ -1627,20 +1686,24 @@ class Unet(nn.Module):
|
||||
|
||||
x = self.mid_block2(x, mid_c, t)
|
||||
|
||||
for block1, sparse_attn, block2, upsample in self.ups:
|
||||
x = torch.cat((x, hiddens.pop()), dim=1)
|
||||
x = block1(x, c, t)
|
||||
for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
|
||||
x = torch.cat((x, hiddens.pop()), dim = 1)
|
||||
x = init_block(x, c, t)
|
||||
x = sparse_attn(x)
|
||||
x = block2(x, c, t)
|
||||
|
||||
for resnet_block in resnet_blocks:
|
||||
x = resnet_block(x, c, t)
|
||||
|
||||
x = upsample(x)
|
||||
|
||||
x = torch.cat((x, r), dim = 1)
|
||||
return self.final_conv(x)
|
||||
|
||||
class LowresConditioner(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
downsample_first = True,
|
||||
blur_sigma = 0.1,
|
||||
blur_sigma = 0.6,
|
||||
blur_kernel_size = 3,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -1664,13 +1727,25 @@ class LowresConditioner(nn.Module):
|
||||
# when training, blur the low resolution conditional image
|
||||
blur_sigma = default(blur_sigma, self.blur_sigma)
|
||||
blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)
|
||||
|
||||
# allow for drawing a random sigma between lo and hi float values
|
||||
if isinstance(blur_sigma, tuple):
|
||||
blur_sigma = tuple(map(float, blur_sigma))
|
||||
blur_sigma = random.uniform(*blur_sigma)
|
||||
|
||||
# allow for drawing a random kernel size between lo and hi int values
|
||||
if isinstance(blur_kernel_size, tuple):
|
||||
blur_kernel_size = tuple(map(int, blur_kernel_size))
|
||||
kernel_size_lo, kernel_size_hi = blur_kernel_size
|
||||
blur_kernel_size = random.randrange(kernel_size_lo, kernel_size_hi + 1)
|
||||
|
||||
cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
|
||||
|
||||
cond_fmap = resize_image_to(cond_fmap, target_image_size)
|
||||
|
||||
return cond_fmap
|
||||
|
||||
class Decoder(BaseGaussianDiffusion):
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
unet,
|
||||
@@ -1683,35 +1758,43 @@ class Decoder(BaseGaussianDiffusion):
|
||||
image_cond_drop_prob = 0.1,
|
||||
text_cond_drop_prob = 0.5,
|
||||
loss_type = 'l2',
|
||||
beta_schedule = 'cosine',
|
||||
beta_schedule = None,
|
||||
predict_x_start = False,
|
||||
predict_x_start_for_latent_diffusion = False,
|
||||
image_sizes = None, # for cascading ddpm, image size at each stage
|
||||
random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
|
||||
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
|
||||
blur_sigma = 0.1, # cascading ddpm - blur sigma
|
||||
blur_sigma = (0.1, 0.2), # cascading ddpm - blur sigma
|
||||
blur_kernel_size = 3, # cascading ddpm - blur kernel size
|
||||
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
|
||||
clip_denoised = True,
|
||||
clip_x_start = True,
|
||||
clip_adapter_overrides = dict(),
|
||||
learned_variance = True,
|
||||
learned_variance_constrain_frac = False,
|
||||
vb_loss_weight = 0.001,
|
||||
unconditional = False
|
||||
unconditional = False, # set to True for generating images without conditioning
|
||||
auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
|
||||
use_dynamic_thres = False, # from the Imagen paper
|
||||
dynamic_thres_percentile = 0.9,
|
||||
p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
|
||||
p2_loss_weight_k = 1
|
||||
):
|
||||
super().__init__(
|
||||
beta_schedule = beta_schedule,
|
||||
timesteps = timesteps,
|
||||
loss_type = loss_type
|
||||
)
|
||||
super().__init__()
|
||||
|
||||
self.unconditional = unconditional
|
||||
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
|
||||
|
||||
assert self.unconditional or (exists(clip) ^ exists(image_size)), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
|
||||
# text conditioning
|
||||
|
||||
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
|
||||
self.condition_on_text_encodings = condition_on_text_encodings
|
||||
|
||||
# clip
|
||||
|
||||
self.clip = None
|
||||
if exists(clip):
|
||||
assert not unconditional, 'clip must not be given if doing unconditional image training'
|
||||
assert channels == clip.image_channels, f'channels of image ({channels}) should be equal to the channels that CLIP accepts ({clip.image_channels})'
|
||||
|
||||
if isinstance(clip, CLIP):
|
||||
clip = XClipAdapter(clip, **clip_adapter_overrides)
|
||||
elif isinstance(clip, CoCa):
|
||||
@@ -1721,24 +1804,34 @@ class Decoder(BaseGaussianDiffusion):
|
||||
assert isinstance(clip, BaseClipAdapter)
|
||||
|
||||
self.clip = clip
|
||||
self.clip_image_size = clip.image_size
|
||||
self.channels = clip.image_channels
|
||||
else:
|
||||
self.clip_image_size = image_size
|
||||
self.channels = channels
|
||||
|
||||
self.condition_on_text_encodings = condition_on_text_encodings
|
||||
# determine image size, with image_size and image_sizes taking precedence
|
||||
|
||||
if exists(image_size) or exists(image_sizes):
|
||||
assert exists(image_size) ^ exists(image_sizes), 'only one of image_size or image_sizes must be given'
|
||||
image_size = default(image_size, lambda: image_sizes[-1])
|
||||
elif exists(clip):
|
||||
image_size = clip.image_size
|
||||
else:
|
||||
raise Error('either image_size, image_sizes, or clip must be given to decoder')
|
||||
|
||||
# channels
|
||||
|
||||
self.channels = channels
|
||||
|
||||
# automatically take care of ensuring that first unet is unconditional
|
||||
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
|
||||
|
||||
unets = cast_tuple(unet)
|
||||
num_unets = len(unets)
|
||||
|
||||
vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))
|
||||
|
||||
# whether to use learned variance, defaults to True for the first unet in the cascade, as in paper
|
||||
|
||||
learned_variance = pad_tuple_to_length(cast_tuple(learned_variance), len(unets), fillvalue = False)
|
||||
self.learned_variance = learned_variance
|
||||
self.learned_variance_constrain_frac = learned_variance_constrain_frac # whether to constrain the output of the network (the interpolation fraction) from 0 to 1
|
||||
self.vb_loss_weight = vb_loss_weight
|
||||
|
||||
# construct unets and vaes
|
||||
@@ -1758,8 +1851,8 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
one_unet = one_unet.cast_model_parameters(
|
||||
lowres_cond = not is_first,
|
||||
cond_on_image_embeds = is_first and not unconditional,
|
||||
cond_on_text_encodings = one_unet.cond_on_text_encodings and not unconditional,
|
||||
cond_on_image_embeds = not unconditional and is_first,
|
||||
cond_on_text_encodings = not unconditional and (is_first or one_unet.cond_on_text_encodings),
|
||||
channels = unet_channels,
|
||||
channels_out = unet_channels_out
|
||||
)
|
||||
@@ -1767,9 +1860,30 @@ class Decoder(BaseGaussianDiffusion):
|
||||
self.unets.append(one_unet)
|
||||
self.vaes.append(one_vae.copy_for_eval())
|
||||
|
||||
# create noise schedulers per unet
|
||||
|
||||
if not exists(beta_schedule):
|
||||
beta_schedule = ('cosine', *(('cosine',) * max(num_unets - 2, 0)), *(('linear',) * int(num_unets > 1)))
|
||||
|
||||
beta_schedule = cast_tuple(beta_schedule, num_unets)
|
||||
p2_loss_weight_gamma = cast_tuple(p2_loss_weight_gamma, num_unets)
|
||||
|
||||
self.noise_schedulers = nn.ModuleList([])
|
||||
|
||||
for unet_beta_schedule, unet_p2_loss_weight_gamma in zip(beta_schedule, p2_loss_weight_gamma):
|
||||
noise_scheduler = NoiseScheduler(
|
||||
beta_schedule = unet_beta_schedule,
|
||||
timesteps = timesteps,
|
||||
loss_type = loss_type,
|
||||
p2_loss_weight_gamma = unet_p2_loss_weight_gamma,
|
||||
p2_loss_weight_k = p2_loss_weight_k
|
||||
)
|
||||
|
||||
self.noise_schedulers.append(noise_scheduler)
|
||||
|
||||
# unet image sizes
|
||||
|
||||
image_sizes = default(image_sizes, (self.clip_image_size,))
|
||||
image_sizes = default(image_sizes, (image_size,))
|
||||
image_sizes = tuple(sorted(set(image_sizes)))
|
||||
|
||||
assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
|
||||
@@ -1806,6 +1920,24 @@ class Decoder(BaseGaussianDiffusion):
|
||||
self.clip_denoised = clip_denoised
|
||||
self.clip_x_start = clip_x_start
|
||||
|
||||
# dynamic thresholding settings, if clipping denoised during sampling
|
||||
|
||||
self.use_dynamic_thres = use_dynamic_thres
|
||||
self.dynamic_thres_percentile = dynamic_thres_percentile
|
||||
|
||||
# normalize and unnormalize image functions
|
||||
|
||||
self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
|
||||
self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
|
||||
|
||||
# device tracker
|
||||
|
||||
self.register_buffer('_dummy', torch.Tensor([True]), persistent = False)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self._dummy.device
|
||||
|
||||
def get_unet(self, unet_number):
|
||||
assert 0 < unet_number <= len(self.unets)
|
||||
index = unet_number - 1
|
||||
@@ -1829,7 +1961,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
for unet, device in zip(self.unets, devices):
|
||||
unet.to(device)
|
||||
|
||||
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
|
||||
def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
|
||||
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
|
||||
|
||||
pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img))
|
||||
@@ -1840,46 +1972,63 @@ class Decoder(BaseGaussianDiffusion):
|
||||
if predict_x_start:
|
||||
x_recon = pred
|
||||
else:
|
||||
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
|
||||
x_recon = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
|
||||
|
||||
if clip_denoised:
|
||||
x_recon.clamp_(-1., 1.)
|
||||
# s is the threshold amount
|
||||
# static thresholding would just be s = 1
|
||||
s = 1.
|
||||
if self.use_dynamic_thres:
|
||||
s = torch.quantile(
|
||||
rearrange(x_recon, 'b ... -> b (...)').abs(),
|
||||
self.dynamic_thres_percentile,
|
||||
dim = -1
|
||||
)
|
||||
|
||||
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
||||
s.clamp_(min = 1.)
|
||||
s = s.view(-1, *((1,) * (x_recon.ndim - 1)))
|
||||
|
||||
# clip by threshold, depending on whether static or dynamic
|
||||
x_recon = x_recon.clamp(-s, s) / s
|
||||
|
||||
model_mean, posterior_variance, posterior_log_variance = noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t)
|
||||
|
||||
if learned_variance:
|
||||
# if learned variance, posterio variance and posterior log variance are predicted by the network
|
||||
# by an interpolation of the max and min log beta values
|
||||
# eq 15 - https://arxiv.org/abs/2102.09672
|
||||
min_log = extract(self.posterior_log_variance_clipped, t, x.shape)
|
||||
max_log = extract(torch.log(self.betas), t, x.shape)
|
||||
min_log = extract(noise_scheduler.posterior_log_variance_clipped, t, x.shape)
|
||||
max_log = extract(torch.log(noise_scheduler.betas), t, x.shape)
|
||||
var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)
|
||||
|
||||
if self.learned_variance_constrain_frac:
|
||||
var_interp_frac = var_interp_frac.sigmoid()
|
||||
|
||||
posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
|
||||
posterior_variance = posterior_log_variance.exp()
|
||||
|
||||
return model_mean, posterior_variance, posterior_log_variance
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True, repeat_noise = False):
|
||||
def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True):
|
||||
b, *_, device = *x.shape, x.device
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, learned_variance = learned_variance)
|
||||
noise = noise_like(x.shape, device, repeat_noise)
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance)
|
||||
noise = torch.randn_like(x)
|
||||
# no noise when t == 0
|
||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
|
||||
device = self.betas.device
|
||||
def p_sample_loop(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
|
||||
device = self.device
|
||||
|
||||
b = shape[0]
|
||||
img = torch.randn(shape, device = device)
|
||||
|
||||
if not is_latent_diffusion:
|
||||
lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
|
||||
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
|
||||
|
||||
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
|
||||
for i in tqdm(reversed(range(0, noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = noise_scheduler.num_timesteps):
|
||||
img = self.p_sample(
|
||||
unet,
|
||||
img,
|
||||
@@ -1890,25 +2039,26 @@ class Decoder(BaseGaussianDiffusion):
|
||||
cond_scale = cond_scale,
|
||||
lowres_cond_img = lowres_cond_img,
|
||||
predict_x_start = predict_x_start,
|
||||
noise_scheduler = noise_scheduler,
|
||||
learned_variance = learned_variance,
|
||||
clip_denoised = clip_denoised
|
||||
)
|
||||
|
||||
unnormalize_img = unnormalize_zero_to_one(img)
|
||||
unnormalize_img = self.unnormalize_img(img)
|
||||
return unnormalize_img
|
||||
|
||||
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
|
||||
def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
|
||||
# normalize to [-1, 1]
|
||||
|
||||
if not is_latent_diffusion:
|
||||
x_start = normalize_neg_one_to_one(x_start)
|
||||
lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
|
||||
x_start = self.normalize_img(x_start)
|
||||
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
|
||||
|
||||
# get x_t
|
||||
|
||||
x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)
|
||||
x_noisy = noise_scheduler.q_sample(x_start = x_start, t = times, noise = noise)
|
||||
|
||||
model_output = unet(
|
||||
x_noisy,
|
||||
@@ -1928,7 +2078,12 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
target = noise if not predict_x_start else x_start
|
||||
|
||||
loss = self.loss_fn(pred, target)
|
||||
loss = noise_scheduler.loss_fn(pred, target, reduction = 'none')
|
||||
loss = reduce(loss, 'b ... -> b (...)', 'mean')
|
||||
|
||||
loss = noise_scheduler.p2_reweigh_loss(loss, times)
|
||||
|
||||
loss = loss.mean()
|
||||
|
||||
if not learned_variance:
|
||||
# return simple loss if not using learned variance
|
||||
@@ -1941,8 +2096,8 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
# if learning the variance, also include the extra weight kl loss
|
||||
|
||||
true_mean, _, true_log_variance_clipped = self.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
|
||||
true_mean, _, true_log_variance_clipped = noise_scheduler.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
|
||||
|
||||
# kl loss with detached model predicted mean, for stability reasons as in paper
|
||||
|
||||
@@ -1974,7 +2129,8 @@ class Decoder(BaseGaussianDiffusion):
|
||||
text_encodings = None,
|
||||
batch_size = 1,
|
||||
cond_scale = 1.,
|
||||
stop_at_unet_number = None
|
||||
stop_at_unet_number = None,
|
||||
distributed = False,
|
||||
):
|
||||
assert self.unconditional or exists(image_embed), 'image embed must be present on sampling from decoder unless if trained unconditionally'
|
||||
|
||||
@@ -1991,9 +2147,9 @@ class Decoder(BaseGaussianDiffusion):
|
||||
img = None
|
||||
is_cuda = next(self.parameters()).is_cuda
|
||||
|
||||
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance)):
|
||||
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers)):
|
||||
|
||||
context = self.one_unet_in_gpu(unet = unet) if is_cuda else null_context()
|
||||
context = self.one_unet_in_gpu(unet = unet) if is_cuda and not distributed else null_context()
|
||||
|
||||
with context:
|
||||
lowres_cond_img = None
|
||||
@@ -2019,7 +2175,8 @@ class Decoder(BaseGaussianDiffusion):
|
||||
learned_variance = learned_variance,
|
||||
clip_denoised = not is_latent_diffusion,
|
||||
lowres_cond_img = lowres_cond_img,
|
||||
is_latent_diffusion = is_latent_diffusion
|
||||
is_latent_diffusion = is_latent_diffusion,
|
||||
noise_scheduler = noise_scheduler
|
||||
)
|
||||
|
||||
img = vae.decode(img)
|
||||
@@ -2045,6 +2202,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
unet = self.get_unet(unet_number)
|
||||
|
||||
vae = self.vaes[unet_index]
|
||||
noise_scheduler = self.noise_schedulers[unet_index]
|
||||
target_image_size = self.image_sizes[unet_index]
|
||||
predict_x_start = self.predict_x_start[unet_index]
|
||||
random_crop_size = self.random_crop_sizes[unet_index]
|
||||
@@ -2054,7 +2212,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
check_shape(image, 'b c h w', c = self.channels)
|
||||
assert h >= target_image_size and w >= target_image_size
|
||||
|
||||
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
|
||||
times = torch.randint(0, noise_scheduler.num_timesteps, (b,), device = device, dtype = torch.long)
|
||||
|
||||
if not exists(image_embed) and not self.unconditional:
|
||||
assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
|
||||
@@ -2085,7 +2243,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
image = vae.encode(image)
|
||||
lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
|
||||
|
||||
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion)
|
||||
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
|
||||
|
||||
# main class
|
||||
|
||||
|
||||
75
dalle2_pytorch/dataloaders/README.md
Normal file
75
dalle2_pytorch/dataloaders/README.md
Normal file
@@ -0,0 +1,75 @@
|
||||
## Dataloaders
|
||||
In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.
|
||||
|
||||
### Decoder: Image Embedding Dataset
|
||||
When training the decoder (and up samplers if training together) in isolation, you will need to load images and corresponding image embeddings. This dataset can read two similar types of datasets. First, it can read a [webdataset](https://github.com/webdataset/webdataset) that contains `.jpg` and `.npy` files in the `.tar`s that contain the images and associated image embeddings respectively. Alternatively, you can also specify a source for the embeddings outside of the webdataset. In this case, the path to the embeddings should contain `.npy` files with the same shard numbers as the webdataset and there should be a correspondence between the filename of the `.jpg` and the index of the embedding in the `.npy`. So, for example, `0001.tar` from the webdataset with image `00010509.jpg` (the first 4 digits are the shard number and the last 4 are the index) in it should be paralleled by a `img_emb_0001.npy` which contains a NumPy array with the embedding at index 509.
|
||||
|
||||
Generating a dataset of this type:
|
||||
1. Use [img2dataset](https://github.com/rom1504/img2dataset) to generate a webdataset.
|
||||
2. Use [clip-retrieval](https://github.com/rom1504/clip-retrieval) to convert the images to embeddings.
|
||||
3. Use [embedding-dataset-reordering](https://github.com/Veldrovive/embedding-dataset-reordering) to reorder the embeddings into the expected format.
|
||||
|
||||
Usage:
|
||||
```python
|
||||
from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embedding_dataloader
|
||||
|
||||
# Create a dataloader directly.
|
||||
dataloader = create_image_embedding_dataloader(
|
||||
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses bracket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
|
||||
embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
|
||||
num_workers=4,
|
||||
batch_size=32,
|
||||
shard_width=4, # If a file in the webdataset shard 3 is named 0003039.jpg, we know the shard width is 4 and the last three digits are the index
|
||||
shuffle_num=200, # Does a shuffle of the data with a buffer size of 200
|
||||
shuffle_shards=True, # Shuffle the order the shards are read in
|
||||
resample_shards=False, # Sample shards with replacement. If true, an epoch will be infinite unless stopped manually
|
||||
)
|
||||
for img, emb in dataloader:
|
||||
print(img.shape) # torch.Size([32, 3, 256, 256])
|
||||
print(emb.shape) # torch.Size([32, 512])
|
||||
# Train decoder only as shown above
|
||||
|
||||
# Or create a dataset without a loader so you can configure it manually
|
||||
dataset = ImageEmbeddingDataset(
|
||||
urls="/path/or/url/to/webdataset/{0000..9999}.tar",
|
||||
embedding_folder_url="path/or/url/to/embeddings/folder",
|
||||
shard_width=4,
|
||||
shuffle_shards=True,
|
||||
resample=False
|
||||
)
|
||||
```
|
||||
|
||||
### Diffusion Prior: Prior Embedding Dataset
|
||||
When training the prior it is much more efficient to work with pre-computed embeddings. The `PriorEmbeddingDataset` class enables you to leverage the same script (with minimal modification) for both embedding-only and text-conditioned prior training. This saves you from having to worry about a lot of the boilerplate code.
|
||||
|
||||
To utilize the `PriorEmbeddingDataset`, all you need to do is make a single call to `get_reader()` which will create `EmbeddingReader` object(s) for you. Afterwards, you can utilize `make_splits()` to cleanly create DataLoader objects from for your training run.
|
||||
|
||||
If you are training in a distributed manner, `make_splits()` accepts `rank` and `world_size` arguments to properly distribute to each process. The defaults for these values are `rank=0` and `world_size=1`, so single-process training can safely ignore these parameters.
|
||||
|
||||
Usage:
|
||||
```python
|
||||
from dalle2_pytorch.dataloaders import get_reader, make_splits
|
||||
|
||||
# grab embeddings from some specified location
|
||||
IMG_URL = "data/img_emb/"
|
||||
META_URL = "data/meta/"
|
||||
|
||||
reader = get_reader(text_conditioned=True, img_url=IMG_URL, meta_url=META_URL)
|
||||
|
||||
# some config for training
|
||||
TRAIN_ARGS = {
|
||||
"world_size": 3,
|
||||
"text_conditioned": True,
|
||||
"start": 0,
|
||||
"num_data_points": 10000,
|
||||
"batch_size": 2,
|
||||
"train_split": 0.5,
|
||||
"eval_split": 0.25,
|
||||
"image_reader": reader,
|
||||
}
|
||||
|
||||
# specifying a rank will handle allocation internally
|
||||
rank0_train, rank0_eval, rank0_test = make_splits(rank=0, **TRAIN_ARGS)
|
||||
rank1_train, rank1_eval, rank1_test = make_splits(rank=1, **TRAIN_ARGS)
|
||||
rank2_train, rank2_eval, rank2_test = make_splits(rank=2, **TRAIN_ARGS)
|
||||
```
|
||||
@@ -1,2 +1,2 @@
|
||||
from dalle2_pytorch.dataloaders.decoder_loader import ImageEmbeddingDataset, create_image_embedding_dataloader
|
||||
from dalle2_pytorch.dataloaders.embedding_wrapper import make_splits
|
||||
from dalle2_pytorch.dataloaders.prior_loader import make_splits, get_reader, PriorEmbeddingDataset
|
||||
|
||||
@@ -3,6 +3,7 @@ import webdataset as wds
|
||||
import torch
|
||||
import numpy as np
|
||||
import fsspec
|
||||
import shutil
|
||||
|
||||
def get_shard(filename):
|
||||
"""
|
||||
@@ -20,7 +21,7 @@ def get_example_file(fs, path, file_format):
|
||||
"""
|
||||
return fs.glob(os.path.join(path, f"*.{file_format}"))[0]
|
||||
|
||||
def embedding_inserter(samples, embeddings_url, shard_width, handler=wds.handlers.reraise_exception):
|
||||
def embedding_inserter(samples, embeddings_url, index_width, sample_key='npy', handler=wds.handlers.reraise_exception):
|
||||
"""Given a datum of {"__key__": str, "__url__": str, ...} adds the cooresponding embedding and yields"""
|
||||
previous_tar_url = None
|
||||
current_embeddings = None
|
||||
@@ -50,8 +51,12 @@ def embedding_inserter(samples, embeddings_url, shard_width, handler=wds.handler
|
||||
previous_tar_url = tar_url
|
||||
current_embeddings = load_corresponding_embeds(tar_url)
|
||||
|
||||
embedding_index = int(key[shard_width:])
|
||||
sample["npy"] = current_embeddings[embedding_index]
|
||||
embedding_index = int(key[-index_width:])
|
||||
embedding = current_embeddings[embedding_index]
|
||||
# We need to check if this sample is nonzero. If it is, this embedding is not valid and we should continue to the next loop
|
||||
if torch.count_nonzero(embedding) == 0:
|
||||
raise RuntimeError(f"Webdataset had a sample, but no embedding was found. ImgShard: {key[:-index_width]} - Index: {key[-index_width:]}")
|
||||
sample[sample_key] = embedding
|
||||
yield sample
|
||||
except Exception as exn: # From wds implementation
|
||||
if handler(exn):
|
||||
@@ -60,15 +65,39 @@ def embedding_inserter(samples, embeddings_url, shard_width, handler=wds.handler
|
||||
break
|
||||
insert_embedding = wds.filters.pipelinefilter(embedding_inserter)
|
||||
|
||||
def verify_keys(samples, handler=wds.handlers.reraise_exception):
|
||||
def unassociated_shard_skipper(tarfiles, embeddings_url, handler=wds.handlers.reraise_exception):
|
||||
"""Finds if the is a corresponding embedding for the tarfile at { url: [URL] }"""
|
||||
embeddings_fs, embeddings_path = fsspec.core.url_to_fs(embeddings_url)
|
||||
embedding_files = embeddings_fs.ls(embeddings_path)
|
||||
get_embedding_shard = lambda embedding_file: int(embedding_file.split("_")[-1].split(".")[0])
|
||||
embedding_shards = set([get_embedding_shard(filename) for filename in embedding_files]) # Sets have O(1) check for member
|
||||
|
||||
get_tar_shard = lambda tar_file: int(tar_file.split("/")[-1].split(".")[0])
|
||||
for tarfile in tarfiles:
|
||||
try:
|
||||
webdataset_shard = get_tar_shard(tarfile["url"])
|
||||
# If this shard has an associated embeddings file, we pass it through. Otherwise we iterate until we do have one
|
||||
if webdataset_shard in embedding_shards:
|
||||
yield tarfile
|
||||
except Exception as exn: # From wds implementation
|
||||
if handler(exn):
|
||||
continue
|
||||
else:
|
||||
break
|
||||
skip_unassociated_shards = wds.filters.pipelinefilter(unassociated_shard_skipper)
|
||||
|
||||
def join_embeddings(samples, handler=wds.handlers.reraise_exception):
|
||||
"""
|
||||
Requires that both the image and embedding are present in the sample
|
||||
This is important to do as a user may forget they do not have embeddings in their webdataset and neglect to add them using the embedding_folder_url parameter.
|
||||
Takes the img_emb and text_emb keys and turns them into one key "emb": { "text": text_emb, "img": img_emb }
|
||||
either or both of text_emb and img_emb may not be in the sample so we only add the ones that exist
|
||||
"""
|
||||
for sample in samples:
|
||||
try:
|
||||
assert "jpg" in sample, f"Sample {sample['__key__']} missing image"
|
||||
assert "npy" in sample, f"Sample {sample['__key__']} missing embedding. Did you set embedding_folder_url?"
|
||||
sample['emb'] = {}
|
||||
if 'text_emb' in sample:
|
||||
sample['emb']['text'] = sample['text_emb']
|
||||
if 'img_emb' in sample:
|
||||
sample['emb']['img'] = sample['img_emb']
|
||||
yield sample
|
||||
except Exception as exn: # From wds implementation
|
||||
if handler(exn):
|
||||
@@ -76,6 +105,23 @@ def verify_keys(samples, handler=wds.handlers.reraise_exception):
|
||||
else:
|
||||
break
|
||||
|
||||
def verify_keys(samples, required_keys, handler=wds.handlers.reraise_exception):
|
||||
"""
|
||||
Requires that both the image and embedding are present in the sample
|
||||
This is important to do as a user may forget they do not have embeddings in their webdataset and neglect to add them using the embedding_folder_url parameter.
|
||||
"""
|
||||
for sample in samples:
|
||||
try:
|
||||
for key in required_keys:
|
||||
assert key in sample, f"Sample {sample['__key__']} missing {key}. Has keys {sample.keys()}"
|
||||
yield sample
|
||||
except Exception as exn: # From wds implementation
|
||||
if handler(exn):
|
||||
continue
|
||||
else:
|
||||
break
|
||||
key_verifier = wds.filters.pipelinefilter(verify_keys)
|
||||
|
||||
class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
|
||||
"""
|
||||
A fluid interface wrapper for DataPipline that returns image embedding pairs
|
||||
@@ -85,8 +131,11 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
|
||||
def __init__(
|
||||
self,
|
||||
urls,
|
||||
embedding_folder_url=None,
|
||||
shard_width=None,
|
||||
img_embedding_folder_url=None,
|
||||
text_embedding_folder_url=None,
|
||||
index_width=None,
|
||||
img_preproc=None,
|
||||
extra_keys=[],
|
||||
handler=wds.handlers.reraise_exception,
|
||||
resample=False,
|
||||
shuffle_shards=True
|
||||
@@ -97,13 +146,36 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
|
||||
:param urls: A url pointing to the tar files of the webdataset formatted as /path/to/webdataset/{0000..9999}.tar
|
||||
:param embedding_folder_url: Required if webdataset does not contain embeddings. A url pointing to the npy files of the embeddings. Should have the same number of shards as the webdataset.
|
||||
Webdataset image keys should align with the index of the embedding. This means missing image indices must have a corresponding embedding of all zeros.
|
||||
:param shard_width: The number of digits in the shard number. This is used to align the embedding index with the image index.
|
||||
For example, if a file in the webdataset shard 3 is named 0003039.jpg, we know the shard with this 4 and the last three digits are the index.
|
||||
:param index_width: The number of digits in the index. This is used to align the embedding index with the image index.
|
||||
For example, if a file in the webdataset shard 3 is named 0003039.jpg, we know the shard is 4 digits and the last 3 digits are the index_width.
|
||||
:param img_preproc: This function is run on the img before it is batched and returned. Useful for data augmentation or converting to torch tensor.
|
||||
:param handler: A webdataset handler.
|
||||
:param resample: If true, resample webdataset shards with replacement. You need to set your own epoch size if this is true since it will resample infinitely.
|
||||
:param shuffle_shards: If true, shuffle the shards before resampling. This cannot be true if resample is true.
|
||||
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
keys = ["jpg", "emb"] + extra_keys
|
||||
# if img_embedding_folder_url is not None:
|
||||
# keys.append("img_emb")
|
||||
# if text_embedding_folder_url is not None:
|
||||
# keys.append("text_emb")
|
||||
# keys.extend(extra_keys)
|
||||
self.key_map = {key: i for i, key in enumerate(keys)}
|
||||
self.resampling = resample
|
||||
self.img_preproc = img_preproc
|
||||
# If s3, check if s3fs is installed and s3cmd is installed and check if the data is piped instead of straight up
|
||||
if (isinstance(urls, str) and "s3:" in urls) or (isinstance(urls, list) and any(["s3:" in url for url in urls])):
|
||||
# Then this has an s3 link for the webdataset and we need extra packages
|
||||
if shutil.which("s3cmd") is None:
|
||||
raise RuntimeError("s3cmd is required for s3 webdataset")
|
||||
if (img_embedding_folder_url is not None and "s3:" in img_embedding_folder_url) or (text_embedding_folder_url is not None and "s3:" in text_embedding_folder_url):
|
||||
# Then the embeddings are being loaded from s3 and fsspec requires s3fs
|
||||
try:
|
||||
import s3fs
|
||||
except ImportError:
|
||||
raise RuntimeError("s3fs is required to load embeddings from s3")
|
||||
# Add the shardList and randomize or resample if requested
|
||||
if resample:
|
||||
assert not shuffle_shards, "Cannot both resample and shuffle"
|
||||
@@ -112,28 +184,48 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
|
||||
self.append(wds.SimpleShardList(urls))
|
||||
if shuffle_shards:
|
||||
self.append(wds.filters.shuffle(1000))
|
||||
|
||||
self.append(wds.split_by_node)
|
||||
self.append(wds.split_by_worker)
|
||||
|
||||
if img_embedding_folder_url is not None:
|
||||
# There may be webdataset shards that do not have a embedding shard associated with it. If we do not skip these, they would cause issues.
|
||||
self.append(skip_unassociated_shards(embeddings_url=img_embedding_folder_url, handler=handler))
|
||||
if text_embedding_folder_url is not None:
|
||||
self.append(skip_unassociated_shards(embeddings_url=text_embedding_folder_url, handler=handler))
|
||||
|
||||
self.append(wds.tarfile_to_samples(handler=handler))
|
||||
self.append(wds.decode("torchrgb"))
|
||||
if embedding_folder_url is not None:
|
||||
assert shard_width is not None, "Reading embeddings separately requires shard length to be given"
|
||||
self.append(insert_embedding(embeddings_url=embedding_folder_url, shard_width=shard_width, handler=handler))
|
||||
self.append(verify_keys)
|
||||
self.append(wds.to_tuple("jpg", "npy"))
|
||||
self.append(wds.decode("pilrgb", handler=handler))
|
||||
if img_embedding_folder_url is not None:
|
||||
# Then we are loading image embeddings for a remote source
|
||||
assert index_width is not None, "Reading embeddings separately requires index width length to be given"
|
||||
self.append(insert_embedding(embeddings_url=img_embedding_folder_url, index_width=index_width, sample_key='img_emb', handler=handler))
|
||||
if text_embedding_folder_url is not None:
|
||||
# Then we are loading image embeddings for a remote source
|
||||
assert index_width is not None, "Reading embeddings separately requires index width length to be given"
|
||||
self.append(insert_embedding(embeddings_url=text_embedding_folder_url, index_width=index_width, sample_key='text_emb', handler=handler))
|
||||
self.append(join_embeddings)
|
||||
self.append(key_verifier(required_keys=keys, handler=handler))
|
||||
# Apply preprocessing
|
||||
self.append(wds.map(self.preproc))
|
||||
self.append(wds.to_tuple(*keys))
|
||||
|
||||
def preproc(self, sample):
|
||||
"""Applies the preprocessing for images"""
|
||||
if self.img_preproc is not None:
|
||||
sample["jpg"] = self.img_preproc(sample["jpg"])
|
||||
return sample
|
||||
|
||||
def create_image_embedding_dataloader(
|
||||
tar_url,
|
||||
num_workers,
|
||||
batch_size,
|
||||
embeddings_url=None,
|
||||
shard_width=None,
|
||||
img_embeddings_url=None,
|
||||
text_embeddings_url=None,
|
||||
index_width=None,
|
||||
shuffle_num = None,
|
||||
shuffle_shards = True,
|
||||
resample_shards = False,
|
||||
handler=wds.handlers.warn_and_continue
|
||||
img_preproc=None,
|
||||
extra_keys=[],
|
||||
handler=wds.handlers.reraise_exception#warn_and_continue
|
||||
):
|
||||
"""
|
||||
Convenience function to create an image embedding dataseta and dataloader in one line
|
||||
@@ -143,8 +235,8 @@ def create_image_embedding_dataloader(
|
||||
:param batch_size: The batch size to use for the dataloader
|
||||
:param embeddings_url: Required if webdataset does not contain embeddings. A url pointing to the npy files of the embeddings. Should have the same number of shards as the webdataset.
|
||||
Webdataset image keys should align with the index of the embedding. This means missing image indices must have a corresponding embedding of all zeros.
|
||||
:param shard_width: The number of digits in the shard number. This is used to align the embedding index with the image index.
|
||||
For example, if a file in the webdataset shard 3 is named 0003039.jpg, we know the shard width is 4 and the last three digits are the index.
|
||||
:param index_width: The number of digits in the index. This is used to align the embedding index with the image index.
|
||||
For example, if a file in the webdataset shard 3 is named 0003039.jpg, we know the shard is 4 digits and the last 3 digits are the index_width.
|
||||
:param shuffle_num: If not None, shuffle the dataset with this size buffer after sampling.
|
||||
:param shuffle_shards: If true, shuffle the shards before sampling. This cannot be true if resample is true.
|
||||
:param resample_shards: If true, resample webdataset shards with replacement. You need to set your own epoch size if this is true since it will resample infinitely.
|
||||
@@ -152,10 +244,13 @@ def create_image_embedding_dataloader(
|
||||
"""
|
||||
ds = ImageEmbeddingDataset(
|
||||
tar_url,
|
||||
embeddings_url,
|
||||
shard_width=shard_width,
|
||||
img_embedding_folder_url=img_embeddings_url,
|
||||
text_embedding_folder_url=text_embeddings_url,
|
||||
index_width=index_width,
|
||||
shuffle_shards=shuffle_shards,
|
||||
resample=resample_shards,
|
||||
extra_keys=extra_keys,
|
||||
img_preproc=img_preproc,
|
||||
handler=handler
|
||||
)
|
||||
if shuffle_num is not None and shuffle_num > 0:
|
||||
@@ -167,4 +262,4 @@ def create_image_embedding_dataloader(
|
||||
prefetch_factor=2, # This might be good to have high so the next npy file is prefetched
|
||||
pin_memory=True,
|
||||
shuffle=False
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,180 +0,0 @@
|
||||
from torch.utils.data import IterableDataset
|
||||
from torch import from_numpy
|
||||
from clip import tokenize
|
||||
from embedding_reader import EmbeddingReader
|
||||
|
||||
|
||||
class PriorEmbeddingLoader(IterableDataset):
|
||||
def __init__(
|
||||
self,
|
||||
text_conditioned: bool,
|
||||
batch_size: int,
|
||||
start: int,
|
||||
stop: int,
|
||||
image_reader,
|
||||
text_reader: EmbeddingReader = None,
|
||||
device: str = "cpu",
|
||||
) -> None:
|
||||
super(PriorEmbeddingLoader).__init__()
|
||||
|
||||
self.text_conditioned = text_conditioned
|
||||
|
||||
if not self.text_conditioned:
|
||||
self.text_reader = text_reader
|
||||
|
||||
self.image_reader = image_reader
|
||||
self.batch_size = batch_size
|
||||
self.start = start
|
||||
self.stop = stop
|
||||
self.device = device
|
||||
|
||||
def __iter__(self):
|
||||
self.n = 0
|
||||
loader_args = dict(
|
||||
batch_size=self.batch_size,
|
||||
start=self.start,
|
||||
end=self.stop,
|
||||
show_progress=False,
|
||||
)
|
||||
if self.text_conditioned:
|
||||
self.loader = self.image_reader(**loader_args)
|
||||
else:
|
||||
self.loader = zip(
|
||||
self.image_reader(**loader_args), self.text_reader(**loader_args)
|
||||
)
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
try:
|
||||
return self.get_sample()
|
||||
except StopIteration:
|
||||
raise StopIteration
|
||||
|
||||
def get_sample(self):
|
||||
"""
|
||||
pre-proocess data from either reader into a common format
|
||||
"""
|
||||
self.n += 1
|
||||
|
||||
if self.text_conditioned:
|
||||
image_embedding, caption = next(self.loader)
|
||||
|
||||
image_embedding = from_numpy(image_embedding).to(self.device)
|
||||
tokenized_caption = tokenize(
|
||||
caption["caption"].to_list(), truncate=True
|
||||
).to(self.device)
|
||||
|
||||
return image_embedding, tokenized_caption
|
||||
|
||||
else:
|
||||
(image_embedding, _), (text_embedding, _) = next(self.loader)
|
||||
|
||||
image_embedding = from_numpy(image_embedding).to(self.device)
|
||||
text_embedding = from_numpy(text_embedding).to(self.device)
|
||||
|
||||
return image_embedding, text_embedding
|
||||
|
||||
|
||||
def make_splits(
|
||||
text_conditioned: bool,
|
||||
batch_size: int,
|
||||
num_data_points: int,
|
||||
train_split: float,
|
||||
eval_split: float,
|
||||
device: str,
|
||||
img_url: str,
|
||||
meta_url: str = None,
|
||||
txt_url: str = None,
|
||||
):
|
||||
|
||||
assert img_url is not None, "Must supply some image embeddings"
|
||||
|
||||
if text_conditioned:
|
||||
assert meta_url is not None, "Must supply metadata url if text-conditioning"
|
||||
image_reader = EmbeddingReader(
|
||||
embeddings_folder=img_url,
|
||||
file_format="parquet_npy",
|
||||
meta_columns=["caption"],
|
||||
metadata_folder=meta_url,
|
||||
)
|
||||
|
||||
# compute split points
|
||||
if num_data_points > image_reader.count:
|
||||
print("Specified point count is larger than the number of points available...defaulting to max length of reader.")
|
||||
num_data_points = image_reader.count
|
||||
|
||||
train_set_size = int(train_split * num_data_points)
|
||||
eval_set_size = int(eval_split * num_data_points)
|
||||
eval_stop = int(train_set_size + eval_set_size)
|
||||
|
||||
train_loader = PriorEmbeddingLoader(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
batch_size=batch_size,
|
||||
start=0,
|
||||
stop=train_set_size,
|
||||
device=device,
|
||||
)
|
||||
eval_loader = PriorEmbeddingLoader(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
batch_size=batch_size,
|
||||
start=train_set_size,
|
||||
stop=eval_stop,
|
||||
device=device,
|
||||
)
|
||||
test_loader = PriorEmbeddingLoader(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
batch_size=batch_size,
|
||||
start=eval_stop,
|
||||
stop=int(num_data_points),
|
||||
device=device,
|
||||
)
|
||||
|
||||
else:
|
||||
assert (
|
||||
txt_url is not None
|
||||
), "Must supply text embedding url if not text-conditioning"
|
||||
|
||||
image_reader = EmbeddingReader(img_url, file_format="npy")
|
||||
text_reader = EmbeddingReader(txt_url, file_format="npy")
|
||||
|
||||
# compute split points
|
||||
if num_data_points > image_reader.count:
|
||||
print("Specified point count is larger than the number of points available...defaulting to max length of reader.")
|
||||
num_data_points = image_reader.count
|
||||
|
||||
train_set_size = int(train_split * num_data_points)
|
||||
eval_set_size = int(eval_split * num_data_points)
|
||||
eval_stop = int(train_set_size + eval_set_size)
|
||||
|
||||
train_loader = PriorEmbeddingLoader(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
text_reader=text_reader,
|
||||
batch_size=batch_size,
|
||||
start=0,
|
||||
stop=train_set_size,
|
||||
device=device,
|
||||
)
|
||||
eval_loader = PriorEmbeddingLoader(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
text_reader=text_reader,
|
||||
batch_size=batch_size,
|
||||
start=train_set_size,
|
||||
stop=eval_stop,
|
||||
device=device,
|
||||
)
|
||||
test_loader = PriorEmbeddingLoader(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
text_reader=text_reader,
|
||||
batch_size=batch_size,
|
||||
start=eval_stop,
|
||||
stop=int(num_data_points),
|
||||
device=device,
|
||||
)
|
||||
|
||||
return train_loader, eval_loader, test_loader
|
||||
273
dalle2_pytorch/dataloaders/prior_loader.py
Normal file
273
dalle2_pytorch/dataloaders/prior_loader.py
Normal file
@@ -0,0 +1,273 @@
|
||||
from math import ceil
|
||||
from clip import tokenize
|
||||
from embedding_reader import EmbeddingReader
|
||||
from torch import from_numpy
|
||||
from torch.utils.data import IterableDataset, DataLoader
|
||||
|
||||
|
||||
class PriorEmbeddingDataset(IterableDataset):
|
||||
"""
|
||||
PriorEmbeddingDataset is a wrapper of EmbeddingReader.
|
||||
|
||||
It enables one to simplify the logic necessary to yield samples from
|
||||
the different EmbeddingReader configurations available.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_conditioned: bool,
|
||||
batch_size: int,
|
||||
start: int,
|
||||
stop: int,
|
||||
image_reader,
|
||||
text_reader: EmbeddingReader = None,
|
||||
) -> None:
|
||||
super(PriorEmbeddingDataset).__init__()
|
||||
|
||||
self.text_conditioned = text_conditioned
|
||||
|
||||
if not self.text_conditioned:
|
||||
self.text_reader = text_reader
|
||||
|
||||
self.image_reader = image_reader
|
||||
self.start = start
|
||||
self.stop = stop
|
||||
self.batch_size = batch_size
|
||||
|
||||
def __len__(self):
|
||||
return self.stop - self.start
|
||||
|
||||
def __iter__(self):
|
||||
# D.R.Y loader args
|
||||
loader_args = dict(
|
||||
batch_size=self.batch_size,
|
||||
start=self.start,
|
||||
end=self.stop,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
# if the data requested is text conditioned, only load images
|
||||
if self.text_conditioned:
|
||||
self.loader = self.image_reader(**loader_args)
|
||||
# otherwise, include text embeddings and bypass metadata
|
||||
else:
|
||||
self.loader = zip(
|
||||
self.image_reader(**loader_args), self.text_reader(**loader_args)
|
||||
)
|
||||
|
||||
# return the data loader in its formatted state
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
try:
|
||||
return self.get_sample()
|
||||
except StopIteration:
|
||||
raise StopIteration
|
||||
|
||||
def __str__(self):
|
||||
return f"<PriorEmbeddingDataset: start: {self.start}, stop: {self.stop}, len: {self.__len__()}>"
|
||||
|
||||
def get_sample(self):
|
||||
"""
|
||||
pre-proocess data from either reader into a common format
|
||||
"""
|
||||
if self.text_conditioned:
|
||||
image_embedding, caption = next(self.loader)
|
||||
|
||||
image_embedding = from_numpy(image_embedding)
|
||||
tokenized_caption = tokenize(caption["caption"].to_list(), truncate=True)
|
||||
|
||||
return image_embedding, tokenized_caption
|
||||
|
||||
else:
|
||||
(image_embedding, _), (text_embedding, _) = next(self.loader)
|
||||
|
||||
image_embedding = from_numpy(image_embedding)
|
||||
text_embedding = from_numpy(text_embedding)
|
||||
|
||||
return image_embedding, text_embedding
|
||||
|
||||
|
||||
# helper functions
|
||||
|
||||
|
||||
def distribute_to_rank(start, stop, rank, world_size):
|
||||
"""
|
||||
Distribute data to each rank given the world size.
|
||||
|
||||
Return:
|
||||
- New start and stop points for this rank.
|
||||
"""
|
||||
num_samples = int(stop - start)
|
||||
|
||||
per_rank = int(ceil((num_samples) / float(world_size)))
|
||||
|
||||
assert (
|
||||
per_rank > 0
|
||||
), f"Number of samples per rank must be larger than 0, (found: {per_rank})"
|
||||
|
||||
rank_start = start + rank * per_rank
|
||||
|
||||
rank_stop = min(rank_start + per_rank, stop)
|
||||
|
||||
new_length = rank_stop - rank_start
|
||||
|
||||
assert (
|
||||
new_length > 0
|
||||
), "Calculated start and stop points result in a length of zero for this rank."
|
||||
|
||||
return rank_start, rank_stop
|
||||
|
||||
|
||||
def get_reader(
|
||||
text_conditioned: bool, img_url: str, meta_url: str = None, txt_url: str = None
|
||||
):
|
||||
"""
|
||||
Create an EmbeddingReader object from the specified URLs
|
||||
|
||||
get_reader() will always expect a url to image embeddings.
|
||||
|
||||
If text-conditioned, it will also expect a meta_url for the captions.
|
||||
Otherwise, it will need txt_url for the matching text embeddings.
|
||||
|
||||
Returns an image_reader object if text-conditioned.
|
||||
Otherwise it returns both an image_reader and a text_reader
|
||||
"""
|
||||
|
||||
assert img_url is not None, "Must supply a image url"
|
||||
|
||||
if text_conditioned:
|
||||
assert meta_url is not None, "Must supply meta url if text-conditioned"
|
||||
|
||||
image_reader = EmbeddingReader(
|
||||
embeddings_folder=img_url,
|
||||
file_format="parquet_npy",
|
||||
# will assume the caption column exists and is the only one requested
|
||||
meta_columns=["caption"],
|
||||
metadata_folder=meta_url,
|
||||
)
|
||||
|
||||
return image_reader
|
||||
|
||||
# otherwise we will require text embeddings as well and return two readers
|
||||
assert (
|
||||
txt_url is not None
|
||||
), "Must supply text embedding url if not text-conditioning"
|
||||
|
||||
image_reader = EmbeddingReader(img_url, file_format="npy")
|
||||
text_reader = EmbeddingReader(txt_url, file_format="npy")
|
||||
|
||||
return image_reader, text_reader
|
||||
|
||||
|
||||
def make_splits(
|
||||
text_conditioned: bool,
|
||||
batch_size: int,
|
||||
num_data_points: int,
|
||||
train_split: float,
|
||||
eval_split: float,
|
||||
image_reader: EmbeddingReader,
|
||||
text_reader: EmbeddingReader = None,
|
||||
start=0,
|
||||
rank=0,
|
||||
world_size=1,
|
||||
):
|
||||
"""
|
||||
Split an embedding reader object as needed.
|
||||
|
||||
NOTE: make_splits() will infer the test set size from your train and eval.
|
||||
|
||||
Input:
|
||||
- text_conditioned: whether to prepare text-conditioned training data
|
||||
- batch_size: the batch size for a single gpu
|
||||
- num_data_points: the total number of data points you wish to train on
|
||||
- train_split: the percentage of data you wish to train on
|
||||
- eval_split: the percentage of data you wish to validate on
|
||||
- image_reader: the image_reader you wish to split
|
||||
- text_reader: the text_reader you want to split (if !text_conditioned)
|
||||
- start: the starting point within your dataset
|
||||
- rank: the rank of your worker
|
||||
- world_size: the total world size of your distributed training run
|
||||
|
||||
Returns:
|
||||
- PyTorch Dataloaders that yield tuples of (img, txt) data.
|
||||
"""
|
||||
|
||||
assert start < image_reader.count, "start position cannot exceed reader count."
|
||||
|
||||
# verify that the num_data_points does not exceed the max points
|
||||
if num_data_points > (image_reader.count - start):
|
||||
print(
|
||||
"Specified count is larger than what's available...defaulting to reader's count."
|
||||
)
|
||||
num_data_points = image_reader.count
|
||||
|
||||
# compute split points
|
||||
train_set_size = int(train_split * num_data_points)
|
||||
eval_set_size = int(eval_split * num_data_points)
|
||||
eval_start = train_set_size
|
||||
eval_stop = int(eval_start + eval_set_size)
|
||||
|
||||
assert (
|
||||
train_split + eval_split
|
||||
) < 1.0, "Specified train and eval split is too large to infer a test split."
|
||||
|
||||
# distribute to rank
|
||||
rank_train_start, rank_train_stop = distribute_to_rank(
|
||||
start, train_set_size, rank, world_size
|
||||
)
|
||||
rank_eval_start, rank_eval_stop = distribute_to_rank(
|
||||
train_set_size, eval_stop, rank, world_size
|
||||
)
|
||||
rank_test_start, rank_test_stop = distribute_to_rank(
|
||||
eval_stop, num_data_points, rank, world_size
|
||||
)
|
||||
|
||||
# wrap up splits into a dict
|
||||
train_split_args = dict(
|
||||
start=rank_train_start, stop=rank_train_stop, batch_size=batch_size
|
||||
)
|
||||
eval_split_args = dict(
|
||||
start=rank_eval_start, stop=rank_eval_stop, batch_size=batch_size
|
||||
)
|
||||
test_split_args = dict(
|
||||
start=rank_test_start, stop=rank_test_stop, batch_size=batch_size
|
||||
)
|
||||
|
||||
if text_conditioned:
|
||||
# add the text-conditioned args to a unified dict
|
||||
reader_args = dict(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
)
|
||||
|
||||
train_split_args = dict(**reader_args, **train_split_args)
|
||||
eval_split_args = dict(**reader_args, **eval_split_args)
|
||||
test_split_args = dict(**reader_args, **test_split_args)
|
||||
|
||||
train = PriorEmbeddingDataset(**train_split_args)
|
||||
val = PriorEmbeddingDataset(**eval_split_args)
|
||||
test = PriorEmbeddingDataset(**test_split_args)
|
||||
|
||||
else:
|
||||
# add the non-conditioned args to a unified dict
|
||||
reader_args = dict(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
text_reader=text_reader,
|
||||
)
|
||||
|
||||
train_split_args = dict(**reader_args, **train_split_args)
|
||||
eval_split_args = dict(**reader_args, **eval_split_args)
|
||||
test_split_args = dict(**reader_args, **test_split_args)
|
||||
|
||||
train = PriorEmbeddingDataset(**train_split_args)
|
||||
val = PriorEmbeddingDataset(**eval_split_args)
|
||||
test = PriorEmbeddingDataset(**test_split_args)
|
||||
|
||||
# true batch size is specifed in the PriorEmbeddingDataset
|
||||
train_loader = DataLoader(train, batch_size=None)
|
||||
eval_loader = DataLoader(val, batch_size=None)
|
||||
test_loader = DataLoader(test, batch_size=None)
|
||||
|
||||
return train_loader, eval_loader, test_loader
|
||||
@@ -1,17 +1,21 @@
|
||||
from torch.optim import AdamW, Adam
|
||||
|
||||
def separate_weight_decayable_params(params):
|
||||
no_wd_params = set([param for param in params if param.ndim < 2])
|
||||
wd_params = set(params) - no_wd_params
|
||||
wd_params, no_wd_params = [], []
|
||||
for param in params:
|
||||
param_list = no_wd_params if param.ndim < 2 else wd_params
|
||||
param_list.append(param)
|
||||
return wd_params, no_wd_params
|
||||
|
||||
def get_optimizer(
|
||||
params,
|
||||
lr = 1e-4,
|
||||
wd = 1e-2,
|
||||
betas = (0.9, 0.999),
|
||||
betas = (0.9, 0.99),
|
||||
eps = 1e-8,
|
||||
filter_by_requires_grad = False
|
||||
filter_by_requires_grad = False,
|
||||
group_wd_params = True,
|
||||
**kwargs
|
||||
):
|
||||
if filter_by_requires_grad:
|
||||
params = list(filter(lambda t: t.requires_grad, params))
|
||||
@@ -19,12 +23,12 @@ def get_optimizer(
|
||||
if wd == 0:
|
||||
return Adam(params, lr = lr, betas = betas, eps = eps)
|
||||
|
||||
params = set(params)
|
||||
wd_params, no_wd_params = separate_weight_decayable_params(params)
|
||||
if group_wd_params:
|
||||
wd_params, no_wd_params = separate_weight_decayable_params(params)
|
||||
|
||||
param_groups = [
|
||||
{'params': list(wd_params)},
|
||||
{'params': list(no_wd_params), 'weight_decay': 0},
|
||||
]
|
||||
params = [
|
||||
{'params': wd_params},
|
||||
{'params': no_wd_params, 'weight_decay': 0},
|
||||
]
|
||||
|
||||
return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas, eps = eps)
|
||||
return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
# to give users a quick easy start to training DALL-E without doing BPE
|
||||
|
||||
import torch
|
||||
import youtokentome as yttm
|
||||
|
||||
import html
|
||||
import os
|
||||
@@ -11,6 +10,8 @@ import regex as re
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
|
||||
from dalle2_pytorch.utils import import_or_print_error
|
||||
|
||||
# OpenAI simple tokenizer
|
||||
|
||||
@lru_cache()
|
||||
@@ -156,7 +157,9 @@ class YttmTokenizer:
|
||||
bpe_path = Path(bpe_path)
|
||||
assert bpe_path.exists(), f'BPE json path {str(bpe_path)} does not exist'
|
||||
|
||||
tokenizer = yttm.BPE(model = str(bpe_path))
|
||||
self.yttm = import_or_print_error('youtokentome', 'you need to install youtokentome by `pip install youtokentome`')
|
||||
|
||||
tokenizer = self.yttm.BPE(model = str(bpe_path))
|
||||
self.tokenizer = tokenizer
|
||||
self.vocab_size = tokenizer.vocab_size()
|
||||
|
||||
@@ -167,7 +170,7 @@ class YttmTokenizer:
|
||||
return self.tokenizer.decode(tokens, ignore_ids = pad_tokens.union({0}))
|
||||
|
||||
def encode(self, texts):
|
||||
encoded = self.tokenizer.encode(texts, output_type = yttm.OutputType.ID)
|
||||
encoded = self.tokenizer.encode(texts, output_type = self.yttm.OutputType.ID)
|
||||
return list(map(torch.tensor, encoded))
|
||||
|
||||
def tokenize(self, texts, context_length = 256, truncate_text = False):
|
||||
|
||||
@@ -1,17 +1,39 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
import importlib
|
||||
from itertools import zip_longest
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from dalle2_pytorch.utils import import_or_print_error
|
||||
|
||||
# constants
|
||||
|
||||
DEFAULT_DATA_PATH = './.tracker-data'
|
||||
|
||||
# helper functions
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
# load file functions
|
||||
|
||||
def load_wandb_file(run_path, file_path, **kwargs):
|
||||
wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function')
|
||||
file_reference = wandb.restore(file_path, run_path=run_path)
|
||||
return file_reference.name
|
||||
|
||||
def load_local_file(file_path, **kwargs):
|
||||
return file_path
|
||||
|
||||
# base class
|
||||
|
||||
class BaseTracker(nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self, data_path = DEFAULT_DATA_PATH):
|
||||
super().__init__()
|
||||
self.data_path = Path(data_path)
|
||||
self.data_path.mkdir(parents = True, exist_ok = True)
|
||||
|
||||
def init(self, config, **kwargs):
|
||||
raise NotImplementedError
|
||||
@@ -19,6 +41,58 @@ class BaseTracker(nn.Module):
|
||||
def log(self, log, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def log_images(self, images, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def save_state_dict(self, state_dict, relative_path, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def recall_state_dict(self, recall_source, *args, **kwargs):
|
||||
"""
|
||||
Loads a state dict from any source.
|
||||
Since a user may wish to load a model from a different source than their own tracker (i.e. tracking using wandb but recalling from disk),
|
||||
this should not be linked to any individual tracker.
|
||||
"""
|
||||
# TODO: Pull this into a dict or something similar so that we can add more sources without having a massive switch statement
|
||||
if recall_source == 'wandb':
|
||||
return torch.load(load_wandb_file(*args, **kwargs))
|
||||
elif recall_source == 'local':
|
||||
return torch.load(load_local_file(*args, **kwargs))
|
||||
else:
|
||||
raise ValueError('`recall_source` must be one of `wandb` or `local`')
|
||||
|
||||
def save_file(self, file_path, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def recall_file(self, recall_source, *args, **kwargs):
|
||||
if recall_source == 'wandb':
|
||||
return load_wandb_file(*args, **kwargs)
|
||||
elif recall_source == 'local':
|
||||
return load_local_file(*args, **kwargs)
|
||||
else:
|
||||
raise ValueError('`recall_source` must be one of `wandb` or `local`')
|
||||
|
||||
# Tracker that no-ops all calls except for recall
|
||||
|
||||
class DummyTracker(BaseTracker):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def init(self, config, **kwargs):
|
||||
pass
|
||||
|
||||
def log(self, log, **kwargs):
|
||||
pass
|
||||
|
||||
def log_images(self, images, **kwargs):
|
||||
pass
|
||||
|
||||
def save_state_dict(self, state_dict, relative_path, **kwargs):
|
||||
pass
|
||||
|
||||
def save_file(self, file_path, **kwargs):
|
||||
pass
|
||||
|
||||
# basic stdout class
|
||||
|
||||
class ConsoleTracker(BaseTracker):
|
||||
@@ -28,22 +102,51 @@ class ConsoleTracker(BaseTracker):
|
||||
def log(self, log, **kwargs):
|
||||
print(log)
|
||||
|
||||
def log_images(self, images, **kwargs): # noop for logging images
|
||||
pass
|
||||
|
||||
def save_state_dict(self, state_dict, relative_path, **kwargs):
|
||||
torch.save(state_dict, str(self.data_path / relative_path))
|
||||
|
||||
def save_file(self, file_path, **kwargs):
|
||||
# This is a no-op for local file systems since it is already saved locally
|
||||
pass
|
||||
|
||||
# basic wandb class
|
||||
|
||||
class WandbTracker(BaseTracker):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
try:
|
||||
import wandb
|
||||
except ImportError as e:
|
||||
print('`pip install wandb` to use the wandb experiment tracker')
|
||||
raise e
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb experiment tracker')
|
||||
os.environ["WANDB_SILENT"] = "true"
|
||||
self.wandb = wandb
|
||||
|
||||
def init(self, **config):
|
||||
self.wandb.init(**config)
|
||||
|
||||
def log(self, log, **kwargs):
|
||||
def log(self, log, verbose=False, **kwargs):
|
||||
if verbose:
|
||||
print(log)
|
||||
self.wandb.log(log, **kwargs)
|
||||
|
||||
def log_images(self, images, captions=[], image_section="images", **kwargs):
|
||||
"""
|
||||
Takes a tensor of images and a list of captions and logs them to wandb.
|
||||
"""
|
||||
wandb_images = [self.wandb.Image(image, caption=caption) for image, caption in zip_longest(images, captions)]
|
||||
self.log({ image_section: wandb_images }, **kwargs)
|
||||
|
||||
def save_state_dict(self, state_dict, relative_path, **kwargs):
|
||||
"""
|
||||
Saves a state_dict to disk and uploads it
|
||||
"""
|
||||
full_path = str(self.data_path / relative_path)
|
||||
torch.save(state_dict, full_path)
|
||||
self.wandb.save(full_path, base_path = str(self.data_path)) # Upload and keep relative to data_path
|
||||
|
||||
def save_file(self, file_path, base_path=None, **kwargs):
|
||||
"""
|
||||
Uploads a file from disk to wandb
|
||||
"""
|
||||
if base_path is None:
|
||||
base_path = self.data_path
|
||||
self.wandb.save(str(file_path), base_path = str(base_path))
|
||||
|
||||
303
dalle2_pytorch/train_configs.py
Normal file
303
dalle2_pytorch/train_configs.py
Normal file
@@ -0,0 +1,303 @@
|
||||
import json
|
||||
from torchvision import transforms as T
|
||||
from pydantic import BaseModel, validator, root_validator
|
||||
from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
|
||||
|
||||
from x_clip import CLIP as XCLIP
|
||||
from coca_pytorch import CoCa
|
||||
|
||||
from dalle2_pytorch.dalle2_pytorch import (
|
||||
CoCaAdapter,
|
||||
OpenAIClipAdapter,
|
||||
Unet,
|
||||
Decoder,
|
||||
DiffusionPrior,
|
||||
DiffusionPriorNetwork,
|
||||
XClipAdapter
|
||||
)
|
||||
|
||||
# helper functions
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
def ListOrTuple(inner_type):
|
||||
return Union[List[inner_type], Tuple[inner_type]]
|
||||
|
||||
def SingularOrIterable(inner_type):
|
||||
return Union[inner_type, ListOrTuple(inner_type)]
|
||||
|
||||
# general pydantic classes
|
||||
|
||||
class TrainSplitConfig(BaseModel):
|
||||
train: float = 0.75
|
||||
val: float = 0.15
|
||||
test: float = 0.1
|
||||
|
||||
@root_validator
|
||||
def validate_all(cls, fields):
|
||||
actual_sum = sum([*fields.values()])
|
||||
if actual_sum != 1.:
|
||||
raise ValueError(f'{fields.keys()} must sum to 1.0. Found: {actual_sum}')
|
||||
return fields
|
||||
|
||||
class TrackerConfig(BaseModel):
|
||||
tracker_type: str = 'console' # Decoder currently supports console and wandb
|
||||
data_path: str = './models' # The path where files will be saved locally
|
||||
init_config: Dict[str, Any] = None
|
||||
wandb_entity: str = '' # Only needs to be set if tracker_type is wandb
|
||||
wandb_project: str = ''
|
||||
verbose: bool = False # Whether to print console logging for non-console trackers
|
||||
|
||||
# diffusion prior pydantic classes
|
||||
|
||||
class AdapterConfig(BaseModel):
|
||||
make: str = "openai"
|
||||
model: str = "ViT-L/14"
|
||||
base_model_kwargs: Dict[str, Any] = None
|
||||
|
||||
def create(self):
|
||||
if self.make == "openai":
|
||||
return OpenAIClipAdapter(self.model)
|
||||
elif self.make == "x-clip":
|
||||
return XClipAdapter(XCLIP(**self.base_model_kwargs))
|
||||
elif self.make == "coca":
|
||||
return CoCaAdapter(CoCa(**self.base_model_kwargs))
|
||||
else:
|
||||
raise AttributeError("No adapter with that name is available.")
|
||||
|
||||
class DiffusionPriorNetworkConfig(BaseModel):
|
||||
dim: int
|
||||
depth: int
|
||||
num_timesteps: int = None
|
||||
num_time_embeds: int = 1
|
||||
num_image_embeds: int = 1
|
||||
num_text_embeds: int = 1
|
||||
dim_head: int = 64
|
||||
heads: int = 8
|
||||
ff_mult: int = 4
|
||||
norm_out: bool = True
|
||||
attn_dropout: float = 0.
|
||||
ff_dropout: float = 0.
|
||||
final_proj: bool = True
|
||||
normformer: bool = False
|
||||
rotary_emb: bool = True
|
||||
|
||||
def create(self):
|
||||
kwargs = self.dict()
|
||||
return DiffusionPriorNetwork(**kwargs)
|
||||
|
||||
class DiffusionPriorConfig(BaseModel):
|
||||
clip: AdapterConfig = None
|
||||
net: DiffusionPriorNetworkConfig
|
||||
image_embed_dim: int
|
||||
image_size: int
|
||||
image_channels: int = 3
|
||||
timesteps: int = 1000
|
||||
cond_drop_prob: float = 0.
|
||||
loss_type: str = 'l2'
|
||||
predict_x_start: bool = True
|
||||
beta_schedule: str = 'cosine'
|
||||
condition_on_text_encodings: bool = True
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
def create(self):
|
||||
kwargs = self.dict()
|
||||
|
||||
has_clip = exists(kwargs.pop('clip'))
|
||||
kwargs.pop('net')
|
||||
|
||||
clip = None
|
||||
if has_clip:
|
||||
clip = self.clip.create()
|
||||
|
||||
diffusion_prior_network = self.net.create()
|
||||
return DiffusionPrior(net = diffusion_prior_network, clip = clip, **kwargs)
|
||||
|
||||
class DiffusionPriorTrainConfig(BaseModel):
|
||||
epochs: int = 1
|
||||
lr: float = 1.1e-4
|
||||
wd: float = 6.02e-2
|
||||
max_grad_norm: float = 0.5
|
||||
use_ema: bool = True
|
||||
ema_beta: float = 0.99
|
||||
amp: bool = False
|
||||
save_every: int = 10000 # what steps to save on
|
||||
|
||||
class DiffusionPriorDataConfig(BaseModel):
|
||||
image_url: str # path to embeddings folder
|
||||
meta_url: str # path to metadata (captions) for images
|
||||
splits: TrainSplitConfig
|
||||
batch_size: int = 64
|
||||
|
||||
class DiffusionPriorLoadConfig(BaseModel):
|
||||
source: str = None
|
||||
resume: bool = False
|
||||
|
||||
class TrainDiffusionPriorConfig(BaseModel):
|
||||
prior: DiffusionPriorConfig
|
||||
data: DiffusionPriorDataConfig
|
||||
train: DiffusionPriorTrainConfig
|
||||
load: DiffusionPriorLoadConfig
|
||||
tracker: TrackerConfig
|
||||
|
||||
@classmethod
|
||||
def from_json_path(cls, json_path):
|
||||
with open(json_path) as f:
|
||||
config = json.load(f)
|
||||
return cls(**config)
|
||||
|
||||
# decoder pydantic classes
|
||||
|
||||
class UnetConfig(BaseModel):
|
||||
dim: int
|
||||
dim_mults: ListOrTuple(int)
|
||||
image_embed_dim: int = None
|
||||
cond_dim: int = None
|
||||
channels: int = 3
|
||||
attn_dim_head: int = 32
|
||||
attn_heads: int = 16
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
class DecoderConfig(BaseModel):
|
||||
unets: ListOrTuple(UnetConfig)
|
||||
image_size: int = None
|
||||
image_sizes: ListOrTuple(int) = None
|
||||
condition_on_text_encodings: bool = False
|
||||
clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
|
||||
channels: int = 3
|
||||
timesteps: int = 1000
|
||||
loss_type: str = 'l2'
|
||||
beta_schedule: ListOrTuple(str) = 'cosine'
|
||||
learned_variance: bool = True
|
||||
image_cond_drop_prob: float = 0.1
|
||||
text_cond_drop_prob: float = 0.5
|
||||
|
||||
def create(self):
|
||||
decoder_kwargs = self.dict()
|
||||
|
||||
unet_configs = decoder_kwargs.pop('unets')
|
||||
unets = [Unet(**config) for config in unet_configs]
|
||||
|
||||
has_clip = exists(decoder_kwargs.pop('clip'))
|
||||
clip = None
|
||||
if has_clip:
|
||||
clip = self.clip.create()
|
||||
|
||||
return Decoder(unets, clip=clip, **decoder_kwargs)
|
||||
|
||||
@validator('image_sizes')
|
||||
def check_image_sizes(cls, image_sizes, values):
|
||||
if exists(values.get('image_size')) ^ exists(image_sizes):
|
||||
return image_sizes
|
||||
raise ValueError('either image_size or image_sizes is required, but not both')
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
class DecoderDataConfig(BaseModel):
|
||||
webdataset_base_url: str # path to a webdataset with jpg images
|
||||
img_embeddings_url: Optional[str] # path to .npy files with embeddings
|
||||
text_embeddings_url: Optional[str] # path to .npy files with embeddings
|
||||
num_workers: int = 4
|
||||
batch_size: int = 64
|
||||
start_shard: int = 0
|
||||
end_shard: int = 9999999
|
||||
shard_width: int = 6
|
||||
index_width: int = 4
|
||||
splits: TrainSplitConfig
|
||||
shuffle_train: bool = True
|
||||
resample_train: bool = False
|
||||
preprocessing: Dict[str, Any] = {'ToTensor': True}
|
||||
|
||||
@property
|
||||
def img_preproc(self):
|
||||
def _get_transformation(transformation_name, **kwargs):
|
||||
if transformation_name == "RandomResizedCrop":
|
||||
return T.RandomResizedCrop(**kwargs)
|
||||
elif transformation_name == "RandomHorizontalFlip":
|
||||
return T.RandomHorizontalFlip()
|
||||
elif transformation_name == "ToTensor":
|
||||
return T.ToTensor()
|
||||
|
||||
transforms = []
|
||||
for transform_name, transform_kwargs_or_bool in self.preprocessing.items():
|
||||
transform_kwargs = {} if not isinstance(transform_kwargs_or_bool, dict) else transform_kwargs_or_bool
|
||||
transforms.append(_get_transformation(transform_name, **transform_kwargs))
|
||||
return T.Compose(transforms)
|
||||
|
||||
class DecoderTrainConfig(BaseModel):
|
||||
epochs: int = 20
|
||||
lr: SingularOrIterable(float) = 1e-4
|
||||
wd: SingularOrIterable(float) = 0.01
|
||||
max_grad_norm: SingularOrIterable(float) = 0.5
|
||||
save_every_n_samples: int = 100000
|
||||
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
|
||||
device: str = 'cuda:0'
|
||||
epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
|
||||
validation_samples: int = None # Same as above but for validation.
|
||||
use_ema: bool = True
|
||||
ema_beta: float = 0.999
|
||||
amp: bool = False
|
||||
save_all: bool = False # Whether to preserve all checkpoints
|
||||
save_latest: bool = True # Whether to always save the latest checkpoint
|
||||
save_best: bool = True # Whether to save the best checkpoint
|
||||
unet_training_mask: ListOrTuple(bool) = None # If None, use all unets
|
||||
|
||||
class DecoderEvaluateConfig(BaseModel):
|
||||
n_evaluation_samples: int = 1000
|
||||
FID: Dict[str, Any] = None
|
||||
IS: Dict[str, Any] = None
|
||||
KID: Dict[str, Any] = None
|
||||
LPIPS: Dict[str, Any] = None
|
||||
|
||||
class DecoderLoadConfig(BaseModel):
|
||||
source: str = None # Supports file and wandb
|
||||
run_path: str = '' # Used only if source is wandb
|
||||
file_path: str = '' # The local filepath if source is file. If source is wandb, the relative path to the model file in wandb.
|
||||
resume: bool = False # If using wandb, whether to resume the run
|
||||
|
||||
class TrainDecoderConfig(BaseModel):
|
||||
decoder: DecoderConfig
|
||||
data: DecoderDataConfig
|
||||
train: DecoderTrainConfig
|
||||
evaluate: DecoderEvaluateConfig
|
||||
tracker: TrackerConfig
|
||||
load: DecoderLoadConfig
|
||||
seed: int = 0
|
||||
|
||||
@classmethod
|
||||
def from_json_path(cls, json_path):
|
||||
with open(json_path) as f:
|
||||
config = json.load(f)
|
||||
return cls(**config)
|
||||
|
||||
@root_validator
|
||||
def check_has_embeddings(cls, values):
|
||||
# Makes sure that enough information is provided to get the embeddings specified for training
|
||||
data_config, decoder_config = values.get('data'), values.get('decoder')
|
||||
if data_config is None or decoder_config is None:
|
||||
# Then something else errored and we should just pass through
|
||||
return values
|
||||
using_text_embeddings = decoder_config.condition_on_text_encodings
|
||||
using_clip = exists(decoder_config.clip)
|
||||
img_emb_url = data_config.img_embeddings_url
|
||||
text_emb_url = data_config.text_embeddings_url
|
||||
if using_text_embeddings:
|
||||
# Then we need some way to get the embeddings
|
||||
assert using_clip or text_emb_url is not None, 'If condition_on_text_encodings is true, either clip or text_embeddings_url must be provided'
|
||||
if using_clip:
|
||||
if using_text_embeddings:
|
||||
assert text_emb_url is None or img_emb_url is None, 'Loaded clip, but also provided text_embeddings_url and img_embeddings_url. This is redundant. Remove the clip model or the embeddings'
|
||||
else:
|
||||
assert img_emb_url is None, 'Loaded clip, but also provided img_embeddings_url. This is redundant. Remove the clip model or the embeddings'
|
||||
if text_emb_url:
|
||||
assert using_text_embeddings, "Text embeddings are being loaded, but text embeddings are not being conditioned on. This will slow down the dataloader for no reason."
|
||||
return values
|
||||
@@ -1,5 +1,6 @@
|
||||
import time
|
||||
import copy
|
||||
from pathlib import Path
|
||||
from math import ceil
|
||||
from functools import partial, wraps
|
||||
from collections.abc import Iterable
|
||||
@@ -10,6 +11,12 @@ from torch.cuda.amp import autocast, GradScaler
|
||||
|
||||
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
|
||||
from dalle2_pytorch.optimizer import get_optimizer
|
||||
from dalle2_pytorch.version import __version__
|
||||
from packaging import version
|
||||
|
||||
from ema_pytorch import EMA
|
||||
|
||||
from accelerate import Accelerator
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -19,7 +26,9 @@ def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if callable(d) else d
|
||||
|
||||
def cast_tuple(val, length = 1):
|
||||
return val if isinstance(val, tuple) else ((val,) * length)
|
||||
@@ -128,111 +137,6 @@ def split_args_and_kwargs(*args, split_size = None, **kwargs):
|
||||
chunk_size_frac = chunk_size / batch_size
|
||||
yield chunk_size_frac, (chunked_args, chunked_kwargs)
|
||||
|
||||
# print helpers
|
||||
|
||||
def print_ribbon(s, symbol = '=', repeat = 40):
|
||||
flank = symbol * repeat
|
||||
return f'{flank} {s} {flank}'
|
||||
|
||||
# saving and loading functions
|
||||
|
||||
# for diffusion prior
|
||||
|
||||
def load_diffusion_model(dprior_path, device):
|
||||
dprior_path = Path(dprior_path)
|
||||
assert dprior_path.exists(), 'Dprior model file does not exist'
|
||||
loaded_obj = torch.load(str(dprior_path), map_location='cpu')
|
||||
|
||||
# Get hyperparameters of loaded model
|
||||
dpn_config = loaded_obj['hparams']['diffusion_prior_network']
|
||||
dp_config = loaded_obj['hparams']['diffusion_prior']
|
||||
image_embed_dim = loaded_obj['image_embed_dim']['image_embed_dim']
|
||||
|
||||
# Create DiffusionPriorNetwork and DiffusionPrior with loaded hyperparameters
|
||||
|
||||
# DiffusionPriorNetwork
|
||||
prior_network = DiffusionPriorNetwork( dim = image_embed_dim, **dpn_config).to(device)
|
||||
|
||||
# DiffusionPrior with text embeddings and image embeddings pre-computed
|
||||
diffusion_prior = DiffusionPrior(net = prior_network, **dp_config, image_embed_dim = image_embed_dim).to(device)
|
||||
|
||||
# Load state dict from saved model
|
||||
diffusion_prior.load_state_dict(loaded_obj['model'])
|
||||
|
||||
return diffusion_prior, loaded_obj
|
||||
|
||||
def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim):
|
||||
# Saving State Dict
|
||||
print_ribbon('Saving checkpoint')
|
||||
|
||||
state_dict = dict(model=model.state_dict(),
|
||||
optimizer=optimizer.state_dict(),
|
||||
scaler=scaler.state_dict(),
|
||||
hparams = config,
|
||||
image_embed_dim = {"image_embed_dim":image_embed_dim})
|
||||
torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
|
||||
|
||||
# exponential moving average wrapper
|
||||
|
||||
class EMA(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
beta = 0.9999,
|
||||
update_after_step = 1000,
|
||||
update_every = 10,
|
||||
):
|
||||
super().__init__()
|
||||
self.beta = beta
|
||||
self.online_model = model
|
||||
self.ema_model = copy.deepcopy(model)
|
||||
|
||||
self.update_every = update_every
|
||||
self.update_after_step = update_after_step // update_every # only start EMA after this step number, starting at 0
|
||||
|
||||
self.register_buffer('initted', torch.Tensor([False]))
|
||||
self.register_buffer('step', torch.tensor([0.]))
|
||||
|
||||
def restore_ema_model_device(self):
|
||||
device = self.initted.device
|
||||
self.ema_model.to(device)
|
||||
|
||||
def copy_params_from_model_to_ema(self):
|
||||
self.ema_model.state_dict(self.online_model.state_dict())
|
||||
|
||||
def update(self):
|
||||
self.step += 1
|
||||
|
||||
if (self.step % self.update_every) != 0:
|
||||
return
|
||||
|
||||
if self.step <= self.update_after_step:
|
||||
self.copy_params_from_model_to_ema()
|
||||
return
|
||||
|
||||
if not self.initted:
|
||||
self.copy_params_from_model_to_ema()
|
||||
self.initted.data.copy_(torch.Tensor([True]))
|
||||
|
||||
self.update_moving_average(self.ema_model, self.online_model)
|
||||
|
||||
def update_moving_average(self, ma_model, current_model):
|
||||
def calculate_ema(beta, old, new):
|
||||
if not exists(old):
|
||||
return new
|
||||
return old * beta + (1 - beta) * new
|
||||
|
||||
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
|
||||
old_weight, up_weight = ma_params.data, current_params.data
|
||||
ma_params.data = calculate_ema(self.beta, old_weight, up_weight)
|
||||
|
||||
for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
|
||||
new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer)
|
||||
ma_buffer.copy_(new_buffer_value)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.ema_model(*args, **kwargs)
|
||||
|
||||
# diffusion prior trainer
|
||||
|
||||
def prior_sample_in_chunks(fn):
|
||||
@@ -255,44 +159,190 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
eps = 1e-6,
|
||||
max_grad_norm = None,
|
||||
amp = False,
|
||||
group_wd_params = True,
|
||||
device = None,
|
||||
accelerator = None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(diffusion_prior, DiffusionPrior)
|
||||
assert not exists(accelerator) or isinstance(accelerator, Accelerator)
|
||||
assert exists(accelerator) or exists(device), "You must supply some method of obtaining a device."
|
||||
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
|
||||
|
||||
# assign some helpful member vars
|
||||
self.accelerator = accelerator
|
||||
self.device = accelerator.device if exists(accelerator) else device
|
||||
self.text_conditioned = diffusion_prior.condition_on_text_encodings
|
||||
|
||||
# save model
|
||||
|
||||
self.diffusion_prior = diffusion_prior
|
||||
|
||||
# exponential moving average
|
||||
|
||||
self.use_ema = use_ema
|
||||
if self.use_ema:
|
||||
self.ema_diffusion_prior = EMA(diffusion_prior, **ema_kwargs)
|
||||
|
||||
# optimizer and mixed precision stuff
|
||||
|
||||
self.amp = amp
|
||||
|
||||
self.scaler = GradScaler(enabled = amp)
|
||||
|
||||
self.optim_kwargs = dict(lr=lr, wd=wd, eps=eps, group_wd_params=group_wd_params)
|
||||
|
||||
self.optimizer = get_optimizer(
|
||||
diffusion_prior.parameters(),
|
||||
lr = lr,
|
||||
wd = wd,
|
||||
eps = eps,
|
||||
self.diffusion_prior.parameters(),
|
||||
**self.optim_kwargs,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# distribute the model if using HFA
|
||||
if exists(self.accelerator):
|
||||
self.diffusion_prior, self.optimizer = self.accelerator.prepare(self.diffusion_prior, self.optimizer)
|
||||
|
||||
# exponential moving average stuff
|
||||
|
||||
self.use_ema = use_ema
|
||||
|
||||
if self.use_ema:
|
||||
self.ema_diffusion_prior = EMA(self.unwrap_model(self.diffusion_prior), **ema_kwargs)
|
||||
|
||||
# gradient clipping if needed
|
||||
|
||||
self.max_grad_norm = max_grad_norm
|
||||
|
||||
self.register_buffer('step', torch.tensor([0.]))
|
||||
# track steps internally
|
||||
|
||||
self.register_buffer('step', torch.tensor([0]))
|
||||
|
||||
# accelerator wrappers
|
||||
|
||||
def print(self, msg):
|
||||
if exists(self.accelerator):
|
||||
self.accelerator.print(msg)
|
||||
else:
|
||||
print(msg)
|
||||
|
||||
def unwrap_model(self, model):
|
||||
if exists(self.accelerator):
|
||||
return self.accelerator.unwrap_model(model)
|
||||
else:
|
||||
return model
|
||||
|
||||
def wait_for_everyone(self):
|
||||
if exists(self.accelerator):
|
||||
self.accelerator.wait_for_everyone()
|
||||
|
||||
def is_main_process(self):
|
||||
if exists(self.accelerator):
|
||||
return self.accelerator.is_main_process
|
||||
else:
|
||||
return True
|
||||
|
||||
def clip_grad_norm_(self, *args):
|
||||
if exists(self.accelerator):
|
||||
return self.accelerator.clip_grad_norm_(*args)
|
||||
else:
|
||||
return torch.nn.utils.clip_grad_norm_(*args)
|
||||
|
||||
def backprop(self, x):
|
||||
if exists(self.accelerator):
|
||||
self.accelerator.backward(x)
|
||||
else:
|
||||
try:
|
||||
x.backward()
|
||||
except Exception as e:
|
||||
self.print(f"Caught error in backprop call: {e}")
|
||||
|
||||
# utility
|
||||
|
||||
def save(self, path, overwrite = True, **kwargs):
|
||||
# ensure we sync gradients before continuing
|
||||
self.wait_for_everyone()
|
||||
|
||||
# only save on the main process
|
||||
if self.is_main_process():
|
||||
self.print(f"Saving checkpoint at step: {self.step.item()}")
|
||||
path = Path(path)
|
||||
assert not (path.exists() and not overwrite)
|
||||
path.parent.mkdir(parents = True, exist_ok = True)
|
||||
|
||||
save_obj = dict(
|
||||
scaler = self.scaler.state_dict(),
|
||||
optimizer = self.optimizer.state_dict(),
|
||||
model = self.unwrap_model(self.diffusion_prior).state_dict(), # unwrap the model from distribution if applicable
|
||||
version = version.parse(__version__),
|
||||
step = self.step.item(),
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if self.use_ema:
|
||||
save_obj = {
|
||||
**save_obj,
|
||||
'ema': self.ema_diffusion_prior.state_dict(),
|
||||
'ema_model': self.ema_diffusion_prior.ema_model.state_dict() # save the ema model specifically for easy ema-only reload
|
||||
}
|
||||
|
||||
torch.save(save_obj, str(path))
|
||||
|
||||
def load(self, path, overwrite_lr = True, strict = True):
|
||||
"""
|
||||
Load a checkpoint of a diffusion prior trainer.
|
||||
|
||||
Will load the entire trainer, including the optimizer and EMA.
|
||||
|
||||
Params:
|
||||
- path (str): a path to the DiffusionPriorTrainer checkpoint file
|
||||
- overwrite_lr (bool): wether or not to overwrite the stored LR with the LR specified in the new trainer
|
||||
- strict (bool): kwarg for `torch.nn.Module.load_state_dict`, will force an exact checkpoint match
|
||||
|
||||
Returns:
|
||||
loaded_obj (dict): The loaded checkpoint dictionary
|
||||
"""
|
||||
|
||||
# all processes need to load checkpoint. no restriction here
|
||||
path = Path(path)
|
||||
assert path.exists()
|
||||
|
||||
loaded_obj = torch.load(str(path), map_location=self.device)
|
||||
|
||||
if version.parse(__version__) != loaded_obj['version']:
|
||||
print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {__version__}')
|
||||
|
||||
# unwrap the model when loading from checkpoint
|
||||
self.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
|
||||
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
|
||||
|
||||
self.scaler.load_state_dict(loaded_obj['scaler'])
|
||||
self.optimizer.load_state_dict(loaded_obj['optimizer'])
|
||||
|
||||
if overwrite_lr:
|
||||
new_lr = self.optim_kwargs["lr"]
|
||||
|
||||
self.print(f"Overriding LR to be {new_lr}")
|
||||
|
||||
for group in self.optimizer.param_groups:
|
||||
group["lr"] = new_lr
|
||||
|
||||
if self.use_ema:
|
||||
assert 'ema' in loaded_obj
|
||||
self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
|
||||
# below not be necessary, but I had a suspicion that this wasn't being loaded correctly
|
||||
self.ema_diffusion_prior.ema_model.load_state_dict(loaded_obj["ema_model"])
|
||||
|
||||
# sync and inform
|
||||
self.wait_for_everyone()
|
||||
self.print(f"Loaded model")
|
||||
|
||||
return loaded_obj
|
||||
|
||||
# model functionality
|
||||
|
||||
def update(self):
|
||||
# only continue with updates until all ranks finish
|
||||
self.wait_for_everyone()
|
||||
|
||||
if exists(self.max_grad_norm):
|
||||
self.scaler.unscale_(self.optimizer)
|
||||
nn.utils.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
|
||||
# utilize HFA clipping where applicable
|
||||
self.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
|
||||
|
||||
self.scaler.step(self.optimizer)
|
||||
self.scaler.update()
|
||||
@@ -307,17 +357,26 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
@cast_torch_tensor
|
||||
@prior_sample_in_chunks
|
||||
def p_sample_loop(self, *args, **kwargs):
|
||||
return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
|
||||
model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
|
||||
return model.p_sample_loop(*args, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
@cast_torch_tensor
|
||||
@prior_sample_in_chunks
|
||||
def sample(self, *args, **kwargs):
|
||||
return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
|
||||
model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
|
||||
return model.sample(*args, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_batch_size(self, *args, **kwargs):
|
||||
return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
|
||||
model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
|
||||
return model.sample_batch_size(*args, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
@cast_torch_tensor
|
||||
@prior_sample_in_chunks
|
||||
def embed_text(self, *args, **kwargs):
|
||||
return self.unwrap_model(self.diffusion_prior).clip.embed_text(*args, **kwargs)
|
||||
|
||||
@cast_torch_tensor
|
||||
def forward(
|
||||
@@ -335,8 +394,10 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
|
||||
total_loss += loss.item()
|
||||
|
||||
# backprop with accelerate if applicable
|
||||
|
||||
if self.training:
|
||||
self.scaler.scale(loss).backward()
|
||||
self.backprop(self.scaler.scale(loss))
|
||||
|
||||
return total_loss
|
||||
|
||||
@@ -362,20 +423,23 @@ class DecoderTrainer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
decoder,
|
||||
accelerator = None,
|
||||
use_ema = True,
|
||||
lr = 1e-4,
|
||||
wd = 1e-2,
|
||||
eps = 1e-8,
|
||||
max_grad_norm = 0.5,
|
||||
amp = False,
|
||||
group_wd_params = True,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(decoder, Decoder)
|
||||
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
|
||||
|
||||
self.decoder = decoder
|
||||
self.num_unets = len(self.decoder.unets)
|
||||
self.accelerator = default(accelerator, Accelerator)
|
||||
|
||||
self.num_unets = len(decoder.unets)
|
||||
|
||||
self.use_ema = use_ema
|
||||
self.ema_unets = nn.ModuleList([])
|
||||
@@ -387,56 +451,103 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))
|
||||
|
||||
for ind, (unet, unet_lr, unet_wd, unet_eps) in enumerate(zip(self.decoder.unets, lr, wd, eps)):
|
||||
assert all([unet_lr < 1e-3 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
|
||||
|
||||
optimizers = []
|
||||
|
||||
for unet, unet_lr, unet_wd, unet_eps in zip(decoder.unets, lr, wd, eps):
|
||||
optimizer = get_optimizer(
|
||||
unet.parameters(),
|
||||
lr = unet_lr,
|
||||
wd = unet_wd,
|
||||
eps = unet_eps,
|
||||
group_wd_params = group_wd_params,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
setattr(self, f'optim{ind}', optimizer) # cannot use pytorch ModuleList for some reason with optimizers
|
||||
optimizers.append(optimizer)
|
||||
|
||||
if self.use_ema:
|
||||
self.ema_unets.append(EMA(unet, **ema_kwargs))
|
||||
|
||||
scaler = GradScaler(enabled = amp)
|
||||
setattr(self, f'scaler{ind}', scaler)
|
||||
|
||||
# gradient clipping if needed
|
||||
|
||||
self.max_grad_norm = max_grad_norm
|
||||
|
||||
self.register_buffer('step', torch.tensor([0.]))
|
||||
|
||||
decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
|
||||
|
||||
self.decoder = decoder
|
||||
|
||||
for opt_ind, optimizer in zip(range(len(optimizers)), optimizers):
|
||||
setattr(self, f'optim{opt_ind}', optimizer)
|
||||
|
||||
def save(self, path, overwrite = True, **kwargs):
|
||||
path = Path(path)
|
||||
assert not (path.exists() and not overwrite)
|
||||
path.parent.mkdir(parents = True, exist_ok = True)
|
||||
|
||||
save_obj = dict(
|
||||
model = self.accelerator.unwrap_model(self.decoder).state_dict(),
|
||||
version = __version__,
|
||||
step = self.step.item(),
|
||||
**kwargs
|
||||
)
|
||||
|
||||
for ind in range(0, self.num_unets):
|
||||
optimizer_key = f'optim{ind}'
|
||||
optimizer = getattr(self, optimizer_key)
|
||||
save_obj = {**save_obj, optimizer_key: self.accelerator.unwrap_model(optimizer).state_dict()}
|
||||
|
||||
if self.use_ema:
|
||||
save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
|
||||
|
||||
self.accelerator.save(save_obj, str(path))
|
||||
|
||||
def load(self, path, only_model = False, strict = True):
|
||||
path = Path(path)
|
||||
assert path.exists()
|
||||
|
||||
loaded_obj = torch.load(str(path), map_location = 'cpu')
|
||||
|
||||
if version.parse(__version__) != version.parse(loaded_obj['version']):
|
||||
self.accelerator.print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}')
|
||||
|
||||
self.accelerator.unwrap_model(self.decoder).load_state_dict(loaded_obj['model'], strict = strict)
|
||||
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
|
||||
|
||||
if only_model:
|
||||
return loaded_obj
|
||||
|
||||
for ind in range(0, self.num_unets):
|
||||
optimizer_key = f'optim{ind}'
|
||||
optimizer = getattr(self, optimizer_key)
|
||||
|
||||
self.accelerator.unwrap_model(optimizer).load_state_dict(loaded_obj[optimizer_key])
|
||||
|
||||
if self.use_ema:
|
||||
assert 'ema' in loaded_obj
|
||||
self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
|
||||
|
||||
return loaded_obj
|
||||
|
||||
@property
|
||||
def unets(self):
|
||||
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
|
||||
|
||||
def scale(self, loss, *, unet_number):
|
||||
assert 1 <= unet_number <= self.num_unets
|
||||
index = unet_number - 1
|
||||
scaler = getattr(self, f'scaler{index}')
|
||||
return scaler.scale(loss)
|
||||
|
||||
def update(self, unet_number = None):
|
||||
if self.num_unets == 1:
|
||||
unet_number = default(unet_number, 1)
|
||||
|
||||
assert exists(unet_number) and 1 <= unet_number <= self.num_unets
|
||||
index = unet_number - 1
|
||||
unet = self.decoder.unets[index]
|
||||
|
||||
optimizer = getattr(self, f'optim{index}')
|
||||
scaler = getattr(self, f'scaler{index}')
|
||||
|
||||
if exists(self.max_grad_norm):
|
||||
scaler.unscale_(optimizer)
|
||||
nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
|
||||
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
self.accelerator.clip_grad_norm_(self.decoder.parameters(), self.max_grad_norm) # Automatically unscales gradients
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if self.use_ema:
|
||||
@@ -449,15 +560,17 @@ class DecoderTrainer(nn.Module):
|
||||
@cast_torch_tensor
|
||||
@decoder_sample_in_chunks
|
||||
def sample(self, *args, **kwargs):
|
||||
distributed = self.accelerator.num_processes > 1
|
||||
base_decoder = self.accelerator.unwrap_model(self.decoder)
|
||||
if kwargs.pop('use_non_ema', False) or not self.use_ema:
|
||||
return self.decoder.sample(*args, **kwargs)
|
||||
return base_decoder.sample(*args, **kwargs, distributed = distributed)
|
||||
|
||||
trainable_unets = self.decoder.unets
|
||||
self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
|
||||
trainable_unets = self.accelerator.unwrap_model(self.decoder).unets
|
||||
base_decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
|
||||
|
||||
output = self.decoder.sample(*args, **kwargs)
|
||||
output = base_decoder.sample(*args, **kwargs, distributed = distributed)
|
||||
|
||||
self.decoder.unets = trainable_unets # restore original training unets
|
||||
base_decoder.unets = trainable_unets # restore original training unets
|
||||
|
||||
# cast the ema_model unets back to original device
|
||||
for ema in self.ema_unets:
|
||||
@@ -465,6 +578,18 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
return output
|
||||
|
||||
@torch.no_grad()
|
||||
@cast_torch_tensor
|
||||
@prior_sample_in_chunks
|
||||
def embed_text(self, *args, **kwargs):
|
||||
return self.accelerator.unwrap_model(self.decoder).clip.embed_text(*args, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
@cast_torch_tensor
|
||||
@prior_sample_in_chunks
|
||||
def embed_image(self, *args, **kwargs):
|
||||
return self.accelerator.unwrap_model(self.decoder).clip.embed_image(*args, **kwargs)
|
||||
|
||||
@cast_torch_tensor
|
||||
def forward(
|
||||
self,
|
||||
@@ -479,13 +604,14 @@ class DecoderTrainer(nn.Module):
|
||||
total_loss = 0.
|
||||
|
||||
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
|
||||
with autocast(enabled = self.amp):
|
||||
# with autocast(enabled = self.amp):
|
||||
with self.accelerator.autocast():
|
||||
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
|
||||
loss = loss * chunk_size_frac
|
||||
|
||||
total_loss += loss.item()
|
||||
|
||||
if self.training:
|
||||
self.scale(loss, unet_number = unet_number).backward()
|
||||
self.accelerator.backward(loss)
|
||||
|
||||
return total_loss
|
||||
|
||||
30
dalle2_pytorch/utils.py
Normal file
30
dalle2_pytorch/utils.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import time
|
||||
import importlib
|
||||
|
||||
# time helpers
|
||||
|
||||
class Timer:
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.last_time = time.time()
|
||||
|
||||
def elapsed(self):
|
||||
return time.time() - self.last_time
|
||||
|
||||
# print helpers
|
||||
|
||||
def print_ribbon(s, symbol = '=', repeat = 40):
|
||||
flank = symbol * repeat
|
||||
return f'{flank} {s} {flank}'
|
||||
|
||||
# import helpers
|
||||
|
||||
def import_or_print_error(pkg_name, err_str = None):
|
||||
try:
|
||||
return importlib.import_module(pkg_name)
|
||||
except ModuleNotFoundError as e:
|
||||
if exists(err_str):
|
||||
print(err_str)
|
||||
exit()
|
||||
1
dalle2_pytorch/version.py
Normal file
1
dalle2_pytorch/version.py
Normal file
@@ -0,0 +1 @@
|
||||
__version__ = '0.12.0'
|
||||
@@ -68,8 +68,8 @@ def group_dict_by_key(cond, d):
|
||||
return_val[ind][key] = d[key]
|
||||
return (*return_val,)
|
||||
|
||||
def string_begins_with(prefix, str):
|
||||
return str.startswith(prefix)
|
||||
def string_begins_with(prefix, string_input):
|
||||
return string_input.startswith(prefix)
|
||||
|
||||
def group_by_key_prefix(prefix, d):
|
||||
return group_dict_by_key(partial(string_begins_with, prefix), d)
|
||||
|
||||
@@ -16,10 +16,11 @@ from torchvision.utils import make_grid, save_image
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from dalle2_pytorch.train import EMA
|
||||
from dalle2_pytorch.vqgan_vae import VQGanVAE
|
||||
from dalle2_pytorch.optimizer import get_optimizer
|
||||
|
||||
from ema_pytorch import EMA
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
@@ -97,7 +98,7 @@ class VQGanVAETrainer(nn.Module):
|
||||
valid_frac = 0.05,
|
||||
random_split_seed = 42,
|
||||
ema_beta = 0.995,
|
||||
ema_update_after_step = 2000,
|
||||
ema_update_after_step = 500,
|
||||
ema_update_every = 10,
|
||||
apply_grad_penalty_every = 4,
|
||||
amp = False
|
||||
|
||||
11
setup.py
11
setup.py
@@ -1,4 +1,5 @@
|
||||
from setuptools import setup, find_packages
|
||||
exec(open('dalle2_pytorch/version.py').read())
|
||||
|
||||
setup(
|
||||
name = 'dalle2-pytorch',
|
||||
@@ -10,7 +11,7 @@ setup(
|
||||
'dream = dalle2_pytorch.cli:dream'
|
||||
],
|
||||
},
|
||||
version = '0.3.2',
|
||||
version = __version__,
|
||||
license='MIT',
|
||||
description = 'DALL-E 2',
|
||||
author = 'Phil Wang',
|
||||
@@ -23,15 +24,19 @@ setup(
|
||||
'text to image'
|
||||
],
|
||||
install_requires=[
|
||||
'accelerate',
|
||||
'click',
|
||||
'clip-anytorch',
|
||||
'coca-pytorch>=0.0.5',
|
||||
'ema-pytorch>=0.0.7',
|
||||
'einops>=0.4',
|
||||
'einops-exts>=0.0.3',
|
||||
'embedding-reader',
|
||||
'kornia>=0.5.4',
|
||||
'numpy',
|
||||
'packaging',
|
||||
'pillow',
|
||||
'pydantic',
|
||||
'resize-right>=0.0.2',
|
||||
'rotary-embedding-torch',
|
||||
'torch>=1.10',
|
||||
@@ -39,9 +44,9 @@ setup(
|
||||
'tqdm',
|
||||
'vector-quantize-pytorch',
|
||||
'x-clip>=0.4.4',
|
||||
'youtokentome',
|
||||
'webdataset>=0.2.5',
|
||||
'fsspec>=2022.1.0'
|
||||
'fsspec>=2022.1.0',
|
||||
'torchmetrics[image]>=0.8.0'
|
||||
],
|
||||
classifiers=[
|
||||
'Development Status :: 4 - Beta',
|
||||
|
||||
638
train_decoder.py
Normal file
638
train_decoder.py
Normal file
@@ -0,0 +1,638 @@
|
||||
from pathlib import Path
|
||||
|
||||
from dalle2_pytorch.trainer import DecoderTrainer
|
||||
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
|
||||
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker, DummyTracker
|
||||
from dalle2_pytorch.train_configs import TrainDecoderConfig
|
||||
from dalle2_pytorch.utils import Timer, print_ribbon
|
||||
from dalle2_pytorch.dalle2_pytorch import resize_image_to
|
||||
from clip import tokenize
|
||||
|
||||
import torchvision
|
||||
import torch
|
||||
from torchmetrics.image.fid import FrechetInceptionDistance
|
||||
from torchmetrics.image.inception import InceptionScore
|
||||
from torchmetrics.image.kid import KernelInceptionDistance
|
||||
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
|
||||
from accelerate import Accelerator, DistributedDataParallelKwargs
|
||||
from accelerate.utils import dataclasses as accelerate_dataclasses
|
||||
import webdataset as wds
|
||||
import click
|
||||
|
||||
# constants
|
||||
|
||||
TRAIN_CALC_LOSS_EVERY_ITERS = 10
|
||||
VALID_CALC_LOSS_EVERY_ITERS = 10
|
||||
|
||||
# helpers functions
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
# main functions
|
||||
|
||||
def create_dataloaders(
|
||||
available_shards,
|
||||
webdataset_base_url,
|
||||
img_embeddings_url=None,
|
||||
text_embeddings_url=None,
|
||||
shard_width=6,
|
||||
num_workers=4,
|
||||
batch_size=32,
|
||||
n_sample_images=6,
|
||||
shuffle_train=True,
|
||||
resample_train=False,
|
||||
img_preproc = None,
|
||||
index_width=4,
|
||||
train_prop = 0.75,
|
||||
val_prop = 0.15,
|
||||
test_prop = 0.10,
|
||||
seed = 0,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
Randomly splits the available shards into train, val, and test sets and returns a dataloader for each
|
||||
"""
|
||||
assert train_prop + test_prop + val_prop == 1
|
||||
num_train = round(train_prop*len(available_shards))
|
||||
num_test = round(test_prop*len(available_shards))
|
||||
num_val = len(available_shards) - num_train - num_test
|
||||
assert num_train + num_test + num_val == len(available_shards), f"{num_train} + {num_test} + {num_val} = {num_train + num_test + num_val} != {len(available_shards)}"
|
||||
train_split, test_split, val_split = torch.utils.data.random_split(available_shards, [num_train, num_test, num_val], generator=torch.Generator().manual_seed(seed))
|
||||
|
||||
# The shard number in the webdataset file names has a fixed width. We zero pad the shard numbers so they correspond to a filename.
|
||||
train_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in train_split]
|
||||
test_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in test_split]
|
||||
val_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in val_split]
|
||||
|
||||
create_dataloader = lambda tar_urls, shuffle=False, resample=False, for_sampling=False: create_image_embedding_dataloader(
|
||||
tar_url=tar_urls,
|
||||
num_workers=num_workers,
|
||||
batch_size=batch_size if not for_sampling else n_sample_images,
|
||||
img_embeddings_url=img_embeddings_url,
|
||||
text_embeddings_url=text_embeddings_url,
|
||||
index_width=index_width,
|
||||
shuffle_num = None,
|
||||
extra_keys= ["txt"],
|
||||
shuffle_shards = shuffle,
|
||||
resample_shards = resample,
|
||||
img_preproc=img_preproc,
|
||||
handler=wds.handlers.warn_and_continue
|
||||
)
|
||||
|
||||
train_dataloader = create_dataloader(train_urls, shuffle=shuffle_train, resample=resample_train)
|
||||
train_sampling_dataloader = create_dataloader(train_urls, shuffle=False, for_sampling=True)
|
||||
val_dataloader = create_dataloader(val_urls, shuffle=False)
|
||||
test_dataloader = create_dataloader(test_urls, shuffle=False)
|
||||
test_sampling_dataloader = create_dataloader(test_urls, shuffle=False, for_sampling=True)
|
||||
return {
|
||||
"train": train_dataloader,
|
||||
"train_sampling": train_sampling_dataloader,
|
||||
"val": val_dataloader,
|
||||
"test": test_dataloader,
|
||||
"test_sampling": test_sampling_dataloader
|
||||
}
|
||||
|
||||
def get_dataset_keys(dataloader):
|
||||
"""
|
||||
It is sometimes neccesary to get the keys the dataloader is returning. Since the dataset is burried in the dataloader, we need to do a process to recover it.
|
||||
"""
|
||||
# If the dataloader is actually a WebLoader, we need to extract the real dataloader
|
||||
if isinstance(dataloader, wds.WebLoader):
|
||||
dataloader = dataloader.pipeline[0]
|
||||
return dataloader.dataset.key_map
|
||||
|
||||
def get_example_data(dataloader, device, n=5):
|
||||
"""
|
||||
Samples the dataloader and returns a zipped list of examples
|
||||
"""
|
||||
images = []
|
||||
img_embeddings = []
|
||||
text_embeddings = []
|
||||
captions = []
|
||||
for img, emb, txt in dataloader:
|
||||
img_emb, text_emb = emb.get('img'), emb.get('text')
|
||||
if img_emb is not None:
|
||||
img_emb = img_emb.to(device=device, dtype=torch.float)
|
||||
img_embeddings.extend(list(img_emb))
|
||||
else:
|
||||
# Then we add None img.shape[0] times
|
||||
img_embeddings.extend([None]*img.shape[0])
|
||||
if text_emb is not None:
|
||||
text_emb = text_emb.to(device=device, dtype=torch.float)
|
||||
text_embeddings.extend(list(text_emb))
|
||||
else:
|
||||
# Then we add None img.shape[0] times
|
||||
text_embeddings.extend([None]*img.shape[0])
|
||||
img = img.to(device=device, dtype=torch.float)
|
||||
images.extend(list(img))
|
||||
captions.extend(list(txt))
|
||||
if len(images) >= n:
|
||||
break
|
||||
return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))
|
||||
|
||||
def generate_samples(trainer, example_data, condition_on_text_encodings=False, text_prepend=""):
|
||||
"""
|
||||
Takes example data and generates images from the embeddings
|
||||
Returns three lists: real images, generated images, and captions
|
||||
"""
|
||||
real_images, img_embeddings, text_embeddings, txts = zip(*example_data)
|
||||
sample_params = {}
|
||||
if img_embeddings[0] is None:
|
||||
# Generate image embeddings from clip
|
||||
imgs_tensor = torch.stack(real_images)
|
||||
img_embeddings, *_ = trainer.embed_image(imgs_tensor)
|
||||
sample_params["image_embed"] = img_embeddings
|
||||
else:
|
||||
# Then we are using precomputed image embeddings
|
||||
img_embeddings = torch.stack(img_embeddings)
|
||||
sample_params["image_embed"] = img_embeddings
|
||||
if condition_on_text_encodings:
|
||||
if text_embeddings[0] is None:
|
||||
# Generate text embeddings from text
|
||||
tokenized_texts = tokenize(txts, truncate=True)
|
||||
sample_params["text"] = tokenized_texts
|
||||
else:
|
||||
# Then we are using precomputed text embeddings
|
||||
text_embeddings = torch.stack(text_embeddings)
|
||||
sample_params["text_encodings"] = text_embeddings
|
||||
samples = trainer.sample(**sample_params)
|
||||
generated_images = list(samples)
|
||||
captions = [text_prepend + txt for txt in txts]
|
||||
return real_images, generated_images, captions
|
||||
|
||||
def generate_grid_samples(trainer, examples, condition_on_text_encodings=False, text_prepend=""):
|
||||
"""
|
||||
Generates samples and uses torchvision to put them in a side by side grid for easy viewing
|
||||
"""
|
||||
real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings, text_prepend)
|
||||
|
||||
real_image_size = real_images[0].shape[-1]
|
||||
generated_image_size = generated_images[0].shape[-1]
|
||||
|
||||
# training images may be larger than the generated one
|
||||
if real_image_size > generated_image_size:
|
||||
real_images = [resize_image_to(image, generated_image_size) for image in real_images]
|
||||
|
||||
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
|
||||
return grid_images, captions
|
||||
|
||||
def evaluate_trainer(trainer, dataloader, device, condition_on_text_encodings=False, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
|
||||
"""
|
||||
Computes evaluation metrics for the decoder
|
||||
"""
|
||||
metrics = {}
|
||||
# Prepare the data
|
||||
examples = get_example_data(dataloader, device, n_evaluation_samples)
|
||||
if len(examples) == 0:
|
||||
print("No data to evaluate. Check that your dataloader has shards.")
|
||||
return metrics
|
||||
real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings)
|
||||
real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
|
||||
generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
|
||||
# Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
|
||||
int_real_images = real_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
|
||||
int_generated_images = generated_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
|
||||
|
||||
def null_sync(t, *args, **kwargs):
|
||||
return [t]
|
||||
|
||||
if exists(FID):
|
||||
fid = FrechetInceptionDistance(**FID, dist_sync_fn=null_sync)
|
||||
fid.to(device=device)
|
||||
fid.update(int_real_images, real=True)
|
||||
fid.update(int_generated_images, real=False)
|
||||
metrics["FID"] = fid.compute().item()
|
||||
if exists(IS):
|
||||
inception = InceptionScore(**IS, dist_sync_fn=null_sync)
|
||||
inception.to(device=device)
|
||||
inception.update(int_real_images)
|
||||
is_mean, is_std = inception.compute()
|
||||
metrics["IS_mean"] = is_mean.item()
|
||||
metrics["IS_std"] = is_std.item()
|
||||
if exists(KID):
|
||||
kernel_inception = KernelInceptionDistance(**KID, dist_sync_fn=null_sync)
|
||||
kernel_inception.to(device=device)
|
||||
kernel_inception.update(int_real_images, real=True)
|
||||
kernel_inception.update(int_generated_images, real=False)
|
||||
kid_mean, kid_std = kernel_inception.compute()
|
||||
metrics["KID_mean"] = kid_mean.item()
|
||||
metrics["KID_std"] = kid_std.item()
|
||||
if exists(LPIPS):
|
||||
# Convert from [0, 1] to [-1, 1]
|
||||
renorm_real_images = real_images.mul(2).sub(1)
|
||||
renorm_generated_images = generated_images.mul(2).sub(1)
|
||||
lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS, dist_sync_fn=null_sync)
|
||||
lpips.to(device=device)
|
||||
lpips.update(renorm_real_images, renorm_generated_images)
|
||||
metrics["LPIPS"] = lpips.compute().item()
|
||||
|
||||
if trainer.accelerator.num_processes > 1:
|
||||
# Then we should sync the metrics
|
||||
metrics_order = sorted(metrics.keys())
|
||||
metrics_tensor = torch.zeros(1, len(metrics), device=device, dtype=torch.float)
|
||||
for i, metric_name in enumerate(metrics_order):
|
||||
metrics_tensor[0, i] = metrics[metric_name]
|
||||
metrics_tensor = trainer.accelerator.gather(metrics_tensor)
|
||||
metrics_tensor = metrics_tensor.mean(dim=0)
|
||||
for i, metric_name in enumerate(metrics_order):
|
||||
metrics[metric_name] = metrics_tensor[i].item()
|
||||
return metrics
|
||||
|
||||
def save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, relative_paths):
|
||||
"""
|
||||
Logs the model with an appropriate method depending on the tracker
|
||||
"""
|
||||
if isinstance(relative_paths, str):
|
||||
relative_paths = [relative_paths]
|
||||
for relative_path in relative_paths:
|
||||
local_path = str(tracker.data_path / relative_path)
|
||||
trainer.save(local_path, epoch=epoch, sample=sample, next_task=next_task, validation_losses=validation_losses)
|
||||
tracker.save_file(local_path)
|
||||
|
||||
def recall_trainer(tracker, trainer, recall_source=None, **load_config):
|
||||
"""
|
||||
Loads the model with an appropriate method depending on the tracker
|
||||
"""
|
||||
trainer.accelerator.print(print_ribbon(f"Loading model from {recall_source}"))
|
||||
local_filepath = tracker.recall_file(recall_source, **load_config)
|
||||
state_dict = trainer.load(local_filepath)
|
||||
return state_dict.get("epoch", 0), state_dict.get("validation_losses", []), state_dict.get("next_task", "train"), state_dict.get("sample", 0)
|
||||
|
||||
def train(
|
||||
dataloaders,
|
||||
decoder,
|
||||
accelerator,
|
||||
tracker,
|
||||
inference_device,
|
||||
load_config=None,
|
||||
evaluate_config=None,
|
||||
epoch_samples = None, # If the training dataset is resampling, we have to manually stop an epoch
|
||||
validation_samples = None,
|
||||
epochs = 20,
|
||||
n_sample_images = 5,
|
||||
save_every_n_samples = 100000,
|
||||
save_all=False,
|
||||
save_latest=True,
|
||||
save_best=True,
|
||||
unet_training_mask=None,
|
||||
condition_on_text_encodings=False,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
Trains a decoder on a dataset.
|
||||
"""
|
||||
is_master = accelerator.process_index == 0
|
||||
|
||||
trainer = DecoderTrainer(
|
||||
decoder=decoder,
|
||||
accelerator=accelerator,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# Set up starting model and parameters based on a recalled state dict
|
||||
start_epoch = 0
|
||||
validation_losses = []
|
||||
next_task = 'train'
|
||||
sample = 0
|
||||
samples_seen = 0
|
||||
val_sample = 0
|
||||
step = lambda: int(trainer.step.item())
|
||||
|
||||
if exists(load_config) and exists(load_config.source):
|
||||
start_epoch, validation_losses, next_task, recalled_sample = recall_trainer(tracker, trainer, recall_source=load_config.source, **load_config.dict())
|
||||
if next_task == 'train':
|
||||
sample = recalled_sample
|
||||
if next_task == 'val':
|
||||
val_sample = recalled_sample
|
||||
accelerator.print(f"Loaded model from {load_config.source} on epoch {start_epoch} with minimum validation loss {min(validation_losses) if len(validation_losses) > 0 else 'N/A'}")
|
||||
accelerator.print(f"Starting training from task {next_task} at sample {sample} and validation sample {val_sample}")
|
||||
trainer.to(device=inference_device)
|
||||
|
||||
if not exists(unet_training_mask):
|
||||
# Then the unet mask should be true for all unets in the decoder
|
||||
unet_training_mask = [True] * trainer.num_unets
|
||||
assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"
|
||||
|
||||
accelerator.print(print_ribbon("Generating Example Data", repeat=40))
|
||||
accelerator.print("This can take a while to load the shard lists...")
|
||||
if is_master:
|
||||
train_example_data = get_example_data(dataloaders["train_sampling"], inference_device, n_sample_images)
|
||||
accelerator.print("Generated training examples")
|
||||
test_example_data = get_example_data(dataloaders["test_sampling"], inference_device, n_sample_images)
|
||||
accelerator.print("Generated testing examples")
|
||||
|
||||
send_to_device = lambda arr: [x.to(device=inference_device, dtype=torch.float) for x in arr]
|
||||
|
||||
sample_length_tensor = torch.zeros(1, dtype=torch.int, device=inference_device)
|
||||
unet_losses_tensor = torch.zeros(TRAIN_CALC_LOSS_EVERY_ITERS, trainer.num_unets, dtype=torch.float, device=inference_device)
|
||||
for epoch in range(start_epoch, epochs):
|
||||
accelerator.print(print_ribbon(f"Starting epoch {epoch}", repeat=40))
|
||||
|
||||
timer = Timer()
|
||||
last_sample = sample
|
||||
last_snapshot = sample
|
||||
|
||||
if next_task == 'train':
|
||||
for i, (img, emb, txt) in enumerate(dataloaders["train"]):
|
||||
# We want to count the total number of samples across all processes
|
||||
sample_length_tensor[0] = len(img)
|
||||
all_samples = accelerator.gather(sample_length_tensor) # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
|
||||
total_samples = all_samples.sum().item()
|
||||
sample += total_samples
|
||||
samples_seen += total_samples
|
||||
img_emb = emb.get('img')
|
||||
has_img_embedding = img_emb is not None
|
||||
if has_img_embedding:
|
||||
img_emb, = send_to_device((img_emb,))
|
||||
text_emb = emb.get('text')
|
||||
has_text_embedding = text_emb is not None
|
||||
if has_text_embedding:
|
||||
text_emb, = send_to_device((text_emb,))
|
||||
img, = send_to_device((img,))
|
||||
|
||||
trainer.train()
|
||||
for unet in range(1, trainer.num_unets+1):
|
||||
# Check if this is a unet we are training
|
||||
if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
|
||||
continue
|
||||
|
||||
forward_params = {}
|
||||
if has_img_embedding:
|
||||
forward_params['image_embed'] = img_emb
|
||||
else:
|
||||
# Forward pass automatically generates embedding
|
||||
pass
|
||||
if condition_on_text_encodings:
|
||||
if has_text_embedding:
|
||||
forward_params['text_encodings'] = text_emb
|
||||
else:
|
||||
# Then we need to pass the text instead
|
||||
tokenized_texts = tokenize(txt, truncate=True)
|
||||
forward_params['text'] = tokenized_texts
|
||||
loss = trainer.forward(img, **forward_params, unet_number=unet)
|
||||
trainer.update(unet_number=unet)
|
||||
unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet-1] = loss
|
||||
|
||||
samples_per_sec = (sample - last_sample) / timer.elapsed()
|
||||
timer.reset()
|
||||
last_sample = sample
|
||||
|
||||
if i % TRAIN_CALC_LOSS_EVERY_ITERS == 0:
|
||||
# We want to average losses across all processes
|
||||
unet_all_losses = accelerator.gather(unet_losses_tensor)
|
||||
mask = unet_all_losses != 0
|
||||
unet_average_loss = (unet_all_losses * mask).sum(dim=0) / mask.sum(dim=0)
|
||||
loss_map = { f"Unet {index} Training Loss": loss.item() for index, loss in enumerate(unet_average_loss) if loss != 0 }
|
||||
|
||||
# gather decay rate on each UNet
|
||||
ema_decay_list = {f"Unet {index} EMA Decay": ema_unet.get_current_decay() for index, ema_unet in enumerate(trainer.ema_unets)}
|
||||
|
||||
log_data = {
|
||||
"Epoch": epoch,
|
||||
"Sample": sample,
|
||||
"Step": i,
|
||||
"Samples per second": samples_per_sec,
|
||||
"Samples Seen": samples_seen,
|
||||
**ema_decay_list,
|
||||
**loss_map
|
||||
}
|
||||
|
||||
if is_master:
|
||||
tracker.log(log_data, step=step(), verbose=True)
|
||||
|
||||
if is_master and last_snapshot + save_every_n_samples < sample: # This will miss by some amount every time, but it's not a big deal... I hope
|
||||
# It is difficult to gather this kind of info on the accelerator, so we have to do it on the master
|
||||
print("Saving snapshot")
|
||||
last_snapshot = sample
|
||||
# We need to know where the model should be saved
|
||||
save_paths = []
|
||||
if save_latest:
|
||||
save_paths.append("latest.pth")
|
||||
if save_all:
|
||||
save_paths.append(f"checkpoints/epoch_{epoch}_step_{step()}.pth")
|
||||
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, save_paths)
|
||||
if exists(n_sample_images) and n_sample_images > 0:
|
||||
trainer.eval()
|
||||
train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
|
||||
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
|
||||
|
||||
if epoch_samples is not None and sample >= epoch_samples:
|
||||
break
|
||||
next_task = 'val'
|
||||
sample = 0
|
||||
|
||||
all_average_val_losses = None
|
||||
if next_task == 'val':
|
||||
trainer.eval()
|
||||
accelerator.print(print_ribbon(f"Starting Validation {epoch}", repeat=40))
|
||||
last_val_sample = val_sample
|
||||
val_sample_length_tensor = torch.zeros(1, dtype=torch.int, device=inference_device)
|
||||
average_val_loss_tensor = torch.zeros(1, trainer.num_unets, dtype=torch.float, device=inference_device)
|
||||
timer = Timer()
|
||||
accelerator.wait_for_everyone()
|
||||
i = 0
|
||||
for i, (img, emb, txt) in enumerate(dataloaders["val"]):
|
||||
val_sample_length_tensor[0] = len(img)
|
||||
all_samples = accelerator.gather(val_sample_length_tensor)
|
||||
total_samples = all_samples.sum().item()
|
||||
val_sample += total_samples
|
||||
img_emb = emb.get('img')
|
||||
has_img_embedding = img_emb is not None
|
||||
if has_img_embedding:
|
||||
img_emb, = send_to_device((img_emb,))
|
||||
text_emb = emb.get('text')
|
||||
has_text_embedding = text_emb is not None
|
||||
if has_text_embedding:
|
||||
text_emb, = send_to_device((text_emb,))
|
||||
img, = send_to_device((img,))
|
||||
|
||||
for unet in range(1, len(decoder.unets)+1):
|
||||
if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
|
||||
# No need to evaluate an unchanging unet
|
||||
continue
|
||||
|
||||
forward_params = {}
|
||||
if has_img_embedding:
|
||||
forward_params['image_embed'] = img_emb.float()
|
||||
else:
|
||||
# Forward pass automatically generates embedding
|
||||
pass
|
||||
if condition_on_text_encodings:
|
||||
if has_text_embedding:
|
||||
forward_params['text_encodings'] = text_emb.float()
|
||||
else:
|
||||
# Then we need to pass the text instead
|
||||
tokenized_texts = tokenize(txt, truncate=True)
|
||||
forward_params['text'] = tokenized_texts
|
||||
loss = trainer.forward(img.float(), **forward_params, unet_number=unet)
|
||||
average_val_loss_tensor[0, unet-1] += loss
|
||||
|
||||
if i % VALID_CALC_LOSS_EVERY_ITERS == 0:
|
||||
samples_per_sec = (val_sample - last_val_sample) / timer.elapsed()
|
||||
timer.reset()
|
||||
last_val_sample = val_sample
|
||||
accelerator.print(f"Epoch {epoch}/{epochs} Val Step {i} - Sample {val_sample} - {samples_per_sec:.2f} samples/sec")
|
||||
accelerator.print(f"Loss: {(average_val_loss_tensor / (i+1))}")
|
||||
accelerator.print("")
|
||||
|
||||
if validation_samples is not None and val_sample >= validation_samples:
|
||||
break
|
||||
print(f"Rank {accelerator.state.process_index} finished validation after {i} steps")
|
||||
accelerator.wait_for_everyone()
|
||||
average_val_loss_tensor /= i+1
|
||||
# Gather all the average loss tensors
|
||||
all_average_val_losses = accelerator.gather(average_val_loss_tensor)
|
||||
if is_master:
|
||||
unet_average_val_loss = all_average_val_losses.mean(dim=0)
|
||||
val_loss_map = { f"Unet {index} Validation Loss": loss.item() for index, loss in enumerate(unet_average_val_loss) if loss != 0 }
|
||||
tracker.log(val_loss_map, step=step(), verbose=True)
|
||||
next_task = 'eval'
|
||||
|
||||
if next_task == 'eval':
|
||||
if exists(evaluate_config):
|
||||
accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
|
||||
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings)
|
||||
if is_master:
|
||||
tracker.log(evaluation, step=step(), verbose=True)
|
||||
next_task = 'sample'
|
||||
val_sample = 0
|
||||
|
||||
if next_task == 'sample':
|
||||
if is_master:
|
||||
# Generate examples and save the model if we are the master
|
||||
# Generate sample images
|
||||
print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
|
||||
test_images, test_captions = generate_grid_samples(trainer, test_example_data, condition_on_text_encodings, "Test: ")
|
||||
train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
|
||||
tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step())
|
||||
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
|
||||
|
||||
print(print_ribbon(f"Starting Saving {epoch}", repeat=40))
|
||||
# Get the same paths
|
||||
save_paths = []
|
||||
if save_latest:
|
||||
save_paths.append("latest.pth")
|
||||
if all_average_val_losses is not None:
|
||||
average_loss = all_average_val_losses.mean(dim=0).item()
|
||||
if save_best and (len(validation_losses) == 0 or average_loss < min(validation_losses)):
|
||||
save_paths.append("best.pth")
|
||||
validation_losses.append(average_loss)
|
||||
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, save_paths)
|
||||
next_task = 'train'
|
||||
|
||||
def create_tracker(accelerator, config, config_path, tracker_type=None, data_path=None):
|
||||
"""
|
||||
Creates a tracker of the specified type and initializes special features based on the full config
|
||||
"""
|
||||
tracker_config = config.tracker
|
||||
accelerator_config = {
|
||||
"Distributed": accelerator.distributed_type != accelerate_dataclasses.DistributedType.NO,
|
||||
"DistributedType": accelerator.distributed_type,
|
||||
"NumProcesses": accelerator.num_processes,
|
||||
"MixedPrecision": accelerator.mixed_precision
|
||||
}
|
||||
init_config = { "config": {**config.dict(), **accelerator_config} }
|
||||
data_path = data_path or tracker_config.data_path
|
||||
tracker_type = tracker_type or tracker_config.tracker_type
|
||||
|
||||
if tracker_type == "dummy":
|
||||
tracker = DummyTracker(data_path)
|
||||
tracker.init(**init_config)
|
||||
elif tracker_type == "console":
|
||||
tracker = ConsoleTracker(data_path)
|
||||
tracker.init(**init_config)
|
||||
elif tracker_type == "wandb":
|
||||
# We need to initialize the resume state here
|
||||
load_config = config.load
|
||||
if load_config.source == "wandb" and load_config.resume:
|
||||
# Then we are resuming the run load_config["run_path"]
|
||||
run_id = load_config.run_path.split("/")[-1]
|
||||
init_config["id"] = run_id
|
||||
init_config["resume"] = "must"
|
||||
|
||||
init_config["entity"] = tracker_config.wandb_entity
|
||||
init_config["project"] = tracker_config.wandb_project
|
||||
tracker = WandbTracker(data_path)
|
||||
tracker.init(**init_config)
|
||||
tracker.save_file(str(config_path.absolute()), str(config_path.parent.absolute()))
|
||||
else:
|
||||
raise ValueError(f"Tracker type {tracker_type} not supported by decoder trainer")
|
||||
return tracker
|
||||
|
||||
def initialize_training(config, config_path):
|
||||
# Make sure if we are not loading, distributed models are initialized to the same values
|
||||
torch.manual_seed(config.seed)
|
||||
|
||||
# Set up accelerator for configurable distributed training
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
|
||||
|
||||
# Set up data
|
||||
all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
|
||||
world_size = accelerator.num_processes
|
||||
rank = accelerator.process_index
|
||||
shards_per_process = len(all_shards) // world_size
|
||||
assert shards_per_process > 0, "Not enough shards to split evenly"
|
||||
my_shards = all_shards[rank * shards_per_process: (rank + 1) * shards_per_process]
|
||||
dataloaders = create_dataloaders (
|
||||
available_shards=my_shards,
|
||||
img_preproc = config.data.img_preproc,
|
||||
train_prop = config.data.splits.train,
|
||||
val_prop = config.data.splits.val,
|
||||
test_prop = config.data.splits.test,
|
||||
n_sample_images=config.train.n_sample_images,
|
||||
**config.data.dict(),
|
||||
rank = rank,
|
||||
seed = config.seed,
|
||||
)
|
||||
|
||||
# Create the decoder model and print basic info
|
||||
decoder = config.decoder.create()
|
||||
num_parameters = sum(p.numel() for p in decoder.parameters())
|
||||
|
||||
# Create and initialize the tracker if we are the master
|
||||
tracker = create_tracker(accelerator, config, config_path) if rank == 0 else create_tracker(accelerator, config, config_path, tracker_type="dummy")
|
||||
|
||||
has_img_embeddings = config.data.img_embeddings_url is not None
|
||||
has_text_embeddings = config.data.text_embeddings_url is not None
|
||||
conditioning_on_text = config.decoder.condition_on_text_encodings
|
||||
has_clip_model = config.decoder.clip is not None
|
||||
data_source_string = ""
|
||||
if has_img_embeddings:
|
||||
data_source_string += "precomputed image embeddings"
|
||||
elif has_clip_model:
|
||||
data_source_string += "clip image embeddings generation"
|
||||
else:
|
||||
raise ValueError("No image embeddings source specified")
|
||||
if conditioning_on_text:
|
||||
if has_text_embeddings:
|
||||
data_source_string += " and precomputed text embeddings"
|
||||
elif has_clip_model:
|
||||
data_source_string += " and clip text encoding generation"
|
||||
else:
|
||||
raise ValueError("No text embeddings source specified")
|
||||
|
||||
accelerator.print(print_ribbon("Loaded Config", repeat=40))
|
||||
accelerator.print(f"Running training with {accelerator.num_processes} processes and {accelerator.distributed_type} distributed training")
|
||||
accelerator.print(f"Training using {data_source_string}. {'conditioned on text' if conditioning_on_text else 'not conditioned on text'}")
|
||||
accelerator.print(f"Number of parameters: {num_parameters}")
|
||||
train(dataloaders, decoder, accelerator,
|
||||
tracker=tracker,
|
||||
inference_device=accelerator.device,
|
||||
load_config=config.load,
|
||||
evaluate_config=config.evaluate,
|
||||
condition_on_text_encodings=config.decoder.condition_on_text_encodings,
|
||||
**config.train.dict(),
|
||||
)
|
||||
|
||||
# Create a simple click command line interface to load the config and start the training
|
||||
@click.command()
|
||||
@click.option("--config_file", default="./train_decoder_config.json", help="Path to config file")
|
||||
def main(config_file):
|
||||
config_file_path = Path(config_file)
|
||||
config = TrainDecoderConfig.from_json_path(str(config_file_path))
|
||||
initialize_training(config, config_path=config_file_path)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,85 +1,135 @@
|
||||
from pathlib import Path
|
||||
# TODO: add start, num_data_points, eval_every and group to config
|
||||
# TODO: switch back to repo's wandb
|
||||
|
||||
START = 0
|
||||
NUM_DATA_POINTS = 250e6
|
||||
EVAL_EVERY = 1000
|
||||
GROUP = "distributed"
|
||||
|
||||
import os
|
||||
import click
|
||||
import math
|
||||
import time
|
||||
import numpy as np
|
||||
import wandb
|
||||
|
||||
import torch
|
||||
import clip
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from dalle2_pytorch.dataloaders import make_splits
|
||||
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
|
||||
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon
|
||||
import numpy as np
|
||||
|
||||
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
|
||||
from accelerate import Accelerator
|
||||
|
||||
from embedding_reader import EmbeddingReader
|
||||
from dalle2_pytorch.dataloaders import get_reader, make_splits
|
||||
from dalle2_pytorch.utils import Timer
|
||||
from dalle2_pytorch.train_configs import (
|
||||
DiffusionPriorTrainConfig,
|
||||
TrainDiffusionPriorConfig,
|
||||
)
|
||||
from dalle2_pytorch.trackers import BaseTracker, WandbTracker
|
||||
from dalle2_pytorch import DiffusionPriorTrainer
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
# constants
|
||||
# helpers
|
||||
|
||||
REPORT_METRICS_EVERY = 250 # for cosine similarity and other metric reporting during training
|
||||
|
||||
tracker = WandbTracker()
|
||||
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
|
||||
|
||||
# helpers functions
|
||||
|
||||
def exists(val):
|
||||
val is not None
|
||||
return val is not None
|
||||
|
||||
class Timer:
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.last_time = time.time()
|
||||
def make_model(
|
||||
prior_config, train_config, device: str = None, accelerator: Accelerator = None
|
||||
):
|
||||
# create model from config
|
||||
diffusion_prior = prior_config.create()
|
||||
|
||||
def elapsed(self):
|
||||
return time.time() - self.last_time
|
||||
# instantiate the trainer
|
||||
trainer = DiffusionPriorTrainer(
|
||||
diffusion_prior=diffusion_prior,
|
||||
lr=train_config.lr,
|
||||
wd=train_config.wd,
|
||||
max_grad_norm=train_config.max_grad_norm,
|
||||
amp=train_config.amp,
|
||||
use_ema=train_config.use_ema,
|
||||
device=device,
|
||||
accelerator=accelerator,
|
||||
)
|
||||
|
||||
# functions
|
||||
return trainer
|
||||
|
||||
def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation"):
|
||||
model.eval()
|
||||
|
||||
# eval functions
|
||||
|
||||
|
||||
def eval_model(
|
||||
trainer: DiffusionPriorTrainer,
|
||||
dataloader: DataLoader,
|
||||
text_conditioned: bool,
|
||||
loss_type: str,
|
||||
tracker_context: str,
|
||||
tracker: BaseTracker = None,
|
||||
use_ema: bool = True,
|
||||
):
|
||||
trainer.eval()
|
||||
if trainer.is_main_process():
|
||||
click.secho(f"Measuring performance on {tracker_context}", fg="green", blink=True)
|
||||
|
||||
with torch.no_grad():
|
||||
total_loss = 0.
|
||||
total_samples = 0.
|
||||
total_loss = 0.0
|
||||
total_samples = 0.0
|
||||
|
||||
for image_embeddings, text_data in tqdm(dataloader):
|
||||
for image_embeddings, text_data in dataloader:
|
||||
image_embeddings = image_embeddings.to(trainer.device)
|
||||
text_data = text_data.to(trainer.device)
|
||||
|
||||
batches = image_embeddings.shape[0]
|
||||
|
||||
input_args = dict(image_embed=image_embeddings)
|
||||
|
||||
if text_conditioned:
|
||||
input_args = dict(**input_args, text = text_data)
|
||||
input_args = dict(**input_args, text=text_data)
|
||||
else:
|
||||
input_args = dict(**input_args, text_embed=text_data)
|
||||
|
||||
loss = model(**input_args)
|
||||
if use_ema:
|
||||
loss = trainer.ema_diffusion_prior(**input_args)
|
||||
else:
|
||||
loss = trainer(**input_args)
|
||||
|
||||
total_loss += loss * batches
|
||||
total_samples += batches
|
||||
|
||||
avg_loss = (total_loss / total_samples)
|
||||
avg_loss = total_loss / total_samples
|
||||
|
||||
tracker.log({f'{phase} {loss_type}': avg_loss})
|
||||
stats = {f"{tracker_context}-{loss_type}": avg_loss}
|
||||
trainer.print(stats)
|
||||
|
||||
def report_cosine_sims(diffusion_prior, dataloader, text_conditioned):
|
||||
diffusion_prior.eval()
|
||||
if exists(tracker):
|
||||
tracker.log(stats, step=trainer.step.item() + 1)
|
||||
|
||||
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
|
||||
|
||||
for test_image_embeddings, text_data in tqdm(dataloader):
|
||||
def report_cosine_sims(
|
||||
trainer: DiffusionPriorTrainer,
|
||||
dataloader: DataLoader,
|
||||
text_conditioned: bool,
|
||||
tracker: BaseTracker,
|
||||
tracker_context: str = "validation",
|
||||
):
|
||||
trainer.eval()
|
||||
if trainer.is_main_process():
|
||||
click.secho("Measuring Cosine-Similarity", fg="green", blink=True)
|
||||
|
||||
for test_image_embeddings, text_data in dataloader:
|
||||
test_image_embeddings = test_image_embeddings.to(trainer.device)
|
||||
text_data = text_data.to(trainer.device)
|
||||
|
||||
# we are text conditioned, we produce an embedding from the tokenized text
|
||||
if text_conditioned:
|
||||
text_embedding, text_encodings, text_mask = diffusion_prior.clip.embed_text(
|
||||
text_data)
|
||||
text_cond = dict(text_embed=text_embedding,
|
||||
text_encodings=text_encodings, mask=text_mask)
|
||||
text_embedding, text_encodings, text_mask = trainer.embed_text(text_data)
|
||||
text_cond = dict(
|
||||
text_embed=text_embedding, text_encodings=text_encodings, mask=text_mask
|
||||
)
|
||||
else:
|
||||
text_embedding = text_data
|
||||
text_cond = dict(text_embed=text_embedding)
|
||||
@@ -90,8 +140,9 @@ def report_cosine_sims(diffusion_prior, dataloader, text_conditioned):
|
||||
# roll the text to simulate "unrelated" captions
|
||||
rolled_idx = torch.roll(torch.arange(text_embedding.shape[0]), 1)
|
||||
text_embed_shuffled = text_embed_shuffled[rolled_idx]
|
||||
text_embed_shuffled = text_embed_shuffled / \
|
||||
text_embed_shuffled.norm(dim=1, keepdim=True)
|
||||
text_embed_shuffled = text_embed_shuffled / text_embed_shuffled.norm(
|
||||
dim=1, keepdim=True
|
||||
)
|
||||
|
||||
if text_conditioned:
|
||||
text_encodings_shuffled = text_encodings[rolled_idx]
|
||||
@@ -100,276 +151,276 @@ def report_cosine_sims(diffusion_prior, dataloader, text_conditioned):
|
||||
text_encodings_shuffled = None
|
||||
text_mask_shuffled = None
|
||||
|
||||
text_cond_shuffled = dict(text_embed=text_embed_shuffled,
|
||||
text_encodings=text_encodings_shuffled, mask=text_mask_shuffled)
|
||||
text_cond_shuffled = dict(
|
||||
text_embed=text_embed_shuffled,
|
||||
text_encodings=text_encodings_shuffled,
|
||||
mask=text_mask_shuffled,
|
||||
)
|
||||
|
||||
# prepare the text embedding
|
||||
text_embed = text_embedding / text_embedding.norm(dim=1, keepdim=True)
|
||||
|
||||
# prepare image embeddings
|
||||
test_image_embeddings = test_image_embeddings / \
|
||||
test_image_embeddings.norm(dim=1, keepdim=True)
|
||||
test_image_embeddings = test_image_embeddings / test_image_embeddings.norm(
|
||||
dim=1, keepdim=True
|
||||
)
|
||||
|
||||
# predict on the unshuffled text embeddings
|
||||
predicted_image_embeddings = diffusion_prior.p_sample_loop(
|
||||
test_image_embeddings.shape, text_cond)
|
||||
predicted_image_embeddings = predicted_image_embeddings / \
|
||||
predicted_image_embeddings.norm(dim=1, keepdim=True)
|
||||
predicted_image_embeddings = trainer.p_sample_loop(
|
||||
test_image_embeddings.shape, text_cond
|
||||
)
|
||||
|
||||
predicted_image_embeddings = (
|
||||
predicted_image_embeddings
|
||||
/ predicted_image_embeddings.norm(dim=1, keepdim=True)
|
||||
)
|
||||
|
||||
# predict on the shuffled embeddings
|
||||
predicted_unrelated_embeddings = diffusion_prior.p_sample_loop(
|
||||
test_image_embeddings.shape, text_cond_shuffled)
|
||||
predicted_unrelated_embeddings = predicted_unrelated_embeddings / \
|
||||
predicted_unrelated_embeddings.norm(dim=1, keepdim=True)
|
||||
predicted_unrelated_embeddings = trainer.p_sample_loop(
|
||||
test_image_embeddings.shape, text_cond_shuffled
|
||||
)
|
||||
|
||||
predicted_unrelated_embeddings = (
|
||||
predicted_unrelated_embeddings
|
||||
/ predicted_unrelated_embeddings.norm(dim=1, keepdim=True)
|
||||
)
|
||||
|
||||
# calculate similarities
|
||||
original_similarity = cos(
|
||||
text_embed, test_image_embeddings).cpu().numpy()
|
||||
predicted_similarity = cos(
|
||||
text_embed, predicted_image_embeddings).cpu().numpy()
|
||||
unrelated_similarity = cos(
|
||||
text_embed, predicted_unrelated_embeddings).cpu().numpy()
|
||||
predicted_img_similarity = cos(
|
||||
test_image_embeddings, predicted_image_embeddings).cpu().numpy()
|
||||
tracker.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity),
|
||||
"CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity),
|
||||
"CosineSimilarity(orig_image_embed,predicted_image_embed)":np.mean(predicted_img_similarity),
|
||||
"CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(unrelated_similarity),
|
||||
"Cosine similarity difference":np.mean(predicted_similarity - original_similarity)})
|
||||
original_similarity = cos(text_embed, test_image_embeddings).cpu().numpy()
|
||||
predicted_similarity = cos(text_embed, predicted_image_embeddings).cpu().numpy()
|
||||
unrelated_similarity = (
|
||||
cos(text_embed, predicted_unrelated_embeddings).cpu().numpy()
|
||||
)
|
||||
predicted_img_similarity = (
|
||||
cos(test_image_embeddings, predicted_image_embeddings).cpu().numpy()
|
||||
)
|
||||
|
||||
stats = {
|
||||
f"{tracker_context}/baseline similarity": np.mean(original_similarity),
|
||||
f"{tracker_context}/similarity with text": np.mean(predicted_similarity),
|
||||
f"{tracker_context}/similarity with original image": np.mean(
|
||||
predicted_img_similarity
|
||||
),
|
||||
f"{tracker_context}/similarity with unrelated caption": np.mean(unrelated_similarity),
|
||||
f"{tracker_context}/difference from baseline similarity": np.mean(
|
||||
predicted_similarity - original_similarity
|
||||
),
|
||||
}
|
||||
|
||||
for k, v in stats.items():
|
||||
trainer.print(f"{tracker_context}/{k}: {v}")
|
||||
|
||||
if exists(tracker):
|
||||
tracker.log(stats, step=trainer.step.item() + 1)
|
||||
|
||||
|
||||
# training script
|
||||
|
||||
|
||||
def train(
|
||||
trainer: DiffusionPriorTrainer,
|
||||
train_loader: DataLoader,
|
||||
eval_loader: DataLoader,
|
||||
test_loader: DataLoader,
|
||||
config: DiffusionPriorTrainConfig,
|
||||
):
|
||||
# distributed tracking with wandb
|
||||
if trainer.accelerator.num_processes > 1:
|
||||
os.environ["WANDB_START_METHOD"] = "thread"
|
||||
|
||||
tracker = wandb.init(
|
||||
name=f"RANK:{trainer.device}",
|
||||
entity=config.tracker.wandb_entity,
|
||||
project=config.tracker.wandb_project,
|
||||
config=config.dict(),
|
||||
group=GROUP,
|
||||
)
|
||||
|
||||
# sync after tracker init
|
||||
trainer.wait_for_everyone()
|
||||
|
||||
# init a timer
|
||||
timer = Timer()
|
||||
|
||||
# do training
|
||||
for img, txt in train_loader:
|
||||
trainer.train()
|
||||
current_step = trainer.step.item() + 1
|
||||
|
||||
# place data on device
|
||||
img = img.to(trainer.device)
|
||||
txt = txt.to(trainer.device)
|
||||
|
||||
# pass to model
|
||||
loss = trainer(text=txt, image_embed=img)
|
||||
|
||||
# display & log loss (will only print from main process)
|
||||
trainer.print(f"Step {current_step}: Loss {loss}")
|
||||
|
||||
# perform backprop & apply EMA updates
|
||||
trainer.update()
|
||||
|
||||
# track samples/sec/rank
|
||||
samples_per_sec = img.shape[0] / timer.elapsed()
|
||||
|
||||
# samples seen
|
||||
samples_seen = (
|
||||
config.data.batch_size * trainer.accelerator.num_processes * current_step
|
||||
)
|
||||
|
||||
# ema decay
|
||||
ema_decay = trainer.ema_diffusion_prior.get_current_decay()
|
||||
|
||||
# Log on all processes for debugging
|
||||
tracker.log(
|
||||
{
|
||||
"tracking/samples-sec": samples_per_sec,
|
||||
"tracking/samples-seen": samples_seen,
|
||||
"tracking/ema-decay": ema_decay,
|
||||
"metrics/training-loss": loss,
|
||||
},
|
||||
step=current_step,
|
||||
)
|
||||
|
||||
# Metric Tracking & Checkpointing (outside of timer's scope)
|
||||
if current_step % EVAL_EVERY == 0:
|
||||
eval_model(
|
||||
trainer=trainer,
|
||||
dataloader=eval_loader,
|
||||
text_conditioned=config.prior.condition_on_text_encodings,
|
||||
loss_type=config.prior.loss_type,
|
||||
tracker_context="metrics/online-model-validation",
|
||||
tracker=tracker,
|
||||
use_ema=False,
|
||||
)
|
||||
|
||||
eval_model(
|
||||
trainer=trainer,
|
||||
dataloader=eval_loader,
|
||||
text_conditioned=config.prior.condition_on_text_encodings,
|
||||
loss_type=config.prior.loss_type,
|
||||
tracker_context="metrics/ema-model-validation",
|
||||
tracker=tracker,
|
||||
use_ema=True,
|
||||
)
|
||||
|
||||
report_cosine_sims(
|
||||
trainer=trainer,
|
||||
dataloader=eval_loader,
|
||||
text_conditioned=config.prior.condition_on_text_encodings,
|
||||
tracker=tracker,
|
||||
tracker_context="metrics",
|
||||
)
|
||||
|
||||
if current_step % config.train.save_every == 0:
|
||||
trainer.save(f"{config.tracker.data_path}/chkpt_step_{current_step}.pth")
|
||||
|
||||
# reset timer for next round
|
||||
timer.reset()
|
||||
|
||||
# evaluate on test data
|
||||
|
||||
eval_model(
|
||||
trainer=trainer,
|
||||
dataloader=test_loader,
|
||||
text_conditioned=config.prior.condition_on_text_encodings,
|
||||
loss_type=config.prior.loss_type,
|
||||
tracker_context="test",
|
||||
tracker=tracker,
|
||||
)
|
||||
|
||||
report_cosine_sims(
|
||||
trainer,
|
||||
test_loader,
|
||||
config.prior.condition_on_text_encodings,
|
||||
tracker,
|
||||
tracker_context="test",
|
||||
)
|
||||
|
||||
|
||||
def initialize_training(config, accelerator=None):
|
||||
"""
|
||||
Parse the configuration file, and prepare everything necessary for training
|
||||
"""
|
||||
|
||||
# get a device
|
||||
|
||||
if accelerator:
|
||||
device = accelerator.device
|
||||
click.secho(f"Accelerating on: {device}", fg="yellow")
|
||||
else:
|
||||
if torch.cuda.is_available():
|
||||
click.secho("GPU detected, defaulting to cuda:0", fg="yellow")
|
||||
device = "cuda:0"
|
||||
else:
|
||||
click.secho("No GPU detected...using cpu", fg="yellow")
|
||||
device = "cpu"
|
||||
|
||||
# make the trainer (will automatically distribute if possible & configured)
|
||||
|
||||
trainer = make_model(config.prior, config.train, device, accelerator).to(device)
|
||||
|
||||
# reload from chcekpoint
|
||||
|
||||
if config.load.resume == True:
|
||||
click.secho(f"Loading checkpoint: {config.load.source}", fg="cyan")
|
||||
trainer.load(config.load.source)
|
||||
|
||||
# fetch and prepare data
|
||||
|
||||
if trainer.is_main_process():
|
||||
click.secho("Grabbing data from source", fg="blue", blink=True)
|
||||
|
||||
img_reader = get_reader(
|
||||
text_conditioned=trainer.text_conditioned,
|
||||
img_url=config.data.image_url,
|
||||
meta_url=config.data.meta_url,
|
||||
)
|
||||
|
||||
train_loader, eval_loader, test_loader = make_splits(
|
||||
text_conditioned=trainer.text_conditioned,
|
||||
batch_size=config.data.batch_size,
|
||||
num_data_points=NUM_DATA_POINTS,
|
||||
train_split=config.data.splits.train,
|
||||
eval_split=config.data.splits.val,
|
||||
image_reader=img_reader,
|
||||
rank=accelerator.state.process_index if exists(accelerator) else 0,
|
||||
world_size=accelerator.state.num_processes if exists(accelerator) else 1,
|
||||
start=START,
|
||||
)
|
||||
|
||||
# wait for everyone to load data before continuing
|
||||
trainer.wait_for_everyone()
|
||||
|
||||
# start training
|
||||
train(
|
||||
trainer=trainer,
|
||||
train_loader=train_loader,
|
||||
eval_loader=eval_loader,
|
||||
test_loader=test_loader,
|
||||
config=config,
|
||||
)
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option("--wandb-entity", default="laion")
|
||||
@click.option("--wandb-project", default="diffusion-prior")
|
||||
@click.option("--wandb-dataset", default="LAION-5B")
|
||||
@click.option("--wandb-arch", default="DiffusionPrior")
|
||||
@click.option("--image-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
|
||||
@click.option("--text-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
|
||||
@click.option("--meta-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/laion2B-en-metadata/")
|
||||
@click.option("--learning-rate", default=1.1e-4)
|
||||
@click.option("--weight-decay", default=6.02e-2)
|
||||
@click.option("--dropout", default=5e-2)
|
||||
@click.option("--max-grad-norm", default=0.5)
|
||||
@click.option("--num-data-points", default=250e6)
|
||||
@click.option("--batch-size", default=320)
|
||||
@click.option("--num-epochs", default=5)
|
||||
@click.option("--image-embed-dim", default=768)
|
||||
@click.option("--train-percent", default=0.9)
|
||||
@click.option("--val-percent", default=1e-7)
|
||||
@click.option("--test-percent", default=0.0999999)
|
||||
@click.option("--dpn-depth", default=12)
|
||||
@click.option("--dpn-dim-head", default=64)
|
||||
@click.option("--dpn-heads", default=12)
|
||||
@click.option("--dp-condition-on-text-encodings", default=True)
|
||||
@click.option("--dp-timesteps", default=1000)
|
||||
@click.option("--dp-normformer", default=True)
|
||||
@click.option("--dp-cond-drop-prob", default=0.1)
|
||||
@click.option("--dp-loss-type", default="l2")
|
||||
@click.option("--clip", default="ViT-L/14")
|
||||
@click.option("--amp", default=False)
|
||||
@click.option("--save-interval", default=120)
|
||||
@click.option("--save-path", default="./diffusion_prior_checkpoints")
|
||||
@click.option("--pretrained-model-path", default=None)
|
||||
@click.option("--gpu-device", default=0)
|
||||
def train(
|
||||
wandb_entity,
|
||||
wandb_project,
|
||||
wandb_dataset,
|
||||
wandb_arch,
|
||||
image_embed_url,
|
||||
text_embed_url,
|
||||
meta_url,
|
||||
learning_rate,
|
||||
weight_decay,
|
||||
dropout,
|
||||
max_grad_norm,
|
||||
num_data_points,
|
||||
batch_size,
|
||||
num_epochs,
|
||||
image_embed_dim,
|
||||
train_percent,
|
||||
val_percent,
|
||||
test_percent,
|
||||
dpn_depth,
|
||||
dpn_dim_head,
|
||||
dpn_heads,
|
||||
dp_condition_on_text_encodings,
|
||||
dp_timesteps,
|
||||
dp_normformer,
|
||||
dp_cond_drop_prob,
|
||||
dp_loss_type,
|
||||
clip,
|
||||
amp,
|
||||
save_interval,
|
||||
save_path,
|
||||
pretrained_model_path,
|
||||
gpu_device
|
||||
):
|
||||
config = {
|
||||
"learning_rate": learning_rate,
|
||||
"architecture": wandb_arch,
|
||||
"dataset": wandb_dataset,
|
||||
"weight_decay": weight_decay,
|
||||
"max_gradient_clipping_norm": max_grad_norm,
|
||||
"batch_size": batch_size,
|
||||
"epochs": num_epochs,
|
||||
"diffusion_prior_network": {
|
||||
"depth": dpn_depth,
|
||||
"dim_head": dpn_dim_head,
|
||||
"heads": dpn_heads,
|
||||
"normformer": dp_normformer
|
||||
},
|
||||
"diffusion_prior": {
|
||||
"condition_on_text_encodings": dp_condition_on_text_encodings,
|
||||
"timesteps": dp_timesteps,
|
||||
"cond_drop_prob": dp_cond_drop_prob,
|
||||
"loss_type": dp_loss_type,
|
||||
"clip": clip
|
||||
}
|
||||
}
|
||||
|
||||
# Check if DPRIOR_PATH exists(saved model path)
|
||||
|
||||
DPRIOR_PATH = pretrained_model_path
|
||||
RESUME = exists(DPRIOR_PATH)
|
||||
|
||||
if not RESUME:
|
||||
tracker.init(
|
||||
entity = wandb_entity,
|
||||
project = wandb_project,
|
||||
config = config
|
||||
)
|
||||
|
||||
# Obtain the utilized device.
|
||||
|
||||
has_cuda = torch.cuda.is_available()
|
||||
if has_cuda:
|
||||
device = torch.device(f"cuda:{gpu_device}")
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
# Training loop
|
||||
# diffusion prior network
|
||||
|
||||
prior_network = DiffusionPriorNetwork(
|
||||
dim = image_embed_dim,
|
||||
depth = dpn_depth,
|
||||
dim_head = dpn_dim_head,
|
||||
heads = dpn_heads,
|
||||
attn_dropout = dropout,
|
||||
ff_dropout = dropout,
|
||||
normformer = dp_normformer
|
||||
)
|
||||
|
||||
# Load clip model if text-conditioning
|
||||
if dp_condition_on_text_encodings:
|
||||
clip_adapter = OpenAIClipAdapter(clip)
|
||||
@click.option("--hfa", default=True)
|
||||
@click.option("--config_path", default="configs/prior.json")
|
||||
def main(hfa, config_path):
|
||||
# start HFA if requested
|
||||
if hfa:
|
||||
accelerator = Accelerator()
|
||||
else:
|
||||
clip_adapter = None
|
||||
|
||||
# diffusion prior with text embeddings and image embeddings pre-computed
|
||||
accelerator = None
|
||||
|
||||
diffusion_prior = DiffusionPrior(
|
||||
net = prior_network,
|
||||
clip = clip_adapter,
|
||||
image_embed_dim = image_embed_dim,
|
||||
timesteps = dp_timesteps,
|
||||
cond_drop_prob = dp_cond_drop_prob,
|
||||
loss_type = dp_loss_type,
|
||||
condition_on_text_encodings = dp_condition_on_text_encodings
|
||||
)
|
||||
# load the configuration file on main process
|
||||
if not exists(accelerator) or accelerator.is_main_process:
|
||||
click.secho(f"Loading configuration from {config_path}", fg="green")
|
||||
|
||||
# Load pre-trained model from DPRIOR_PATH
|
||||
config = TrainDiffusionPriorConfig.from_json_path(config_path)
|
||||
|
||||
if RESUME:
|
||||
diffusion_prior, loaded_obj = load_diffusion_model(DPRIOR_PATH, device)
|
||||
tracker.init(entity = wandb_entity, project = wandb_project, config = config)
|
||||
|
||||
# diffusion prior trainer
|
||||
|
||||
trainer = DiffusionPriorTrainer(
|
||||
diffusion_prior = diffusion_prior,
|
||||
lr = learning_rate,
|
||||
wd = weight_decay,
|
||||
max_grad_norm = max_grad_norm,
|
||||
amp = amp,
|
||||
).to(device)
|
||||
|
||||
# load optimizer and scaler
|
||||
|
||||
if RESUME:
|
||||
trainer.optimizer.load_state_dict(loaded_obj['optimizer'])
|
||||
trainer.scaler.load_state_dict(loaded_obj['scaler'])
|
||||
|
||||
# Create save_path if it doesn't exist
|
||||
|
||||
Path(save_path).mkdir(exist_ok = True, parents = True)
|
||||
|
||||
# Utilize wrapper to abstract away loader logic
|
||||
print_ribbon("Downloading Embeddings")
|
||||
loader_args = dict(text_conditioned=dp_condition_on_text_encodings, batch_size=batch_size, num_data_points=num_data_points,
|
||||
train_split=train_percent, eval_split=val_percent, device=device, img_url=image_embed_url)
|
||||
|
||||
if dp_condition_on_text_encodings:
|
||||
loader_args = dict(**loader_args, meta_url=meta_url)
|
||||
else:
|
||||
loader_args = dict(**loader_args, txt_url=text_embed_url)
|
||||
|
||||
train_loader, eval_loader, test_loader = make_splits(**loader_args)
|
||||
|
||||
### Training code ###
|
||||
|
||||
step = 1
|
||||
timer = Timer()
|
||||
epochs = num_epochs
|
||||
|
||||
for _ in range(epochs):
|
||||
|
||||
for image, text in tqdm(train_loader):
|
||||
|
||||
diffusion_prior.train()
|
||||
|
||||
input_args = dict(image_embed=image)
|
||||
if dp_condition_on_text_encodings:
|
||||
input_args = dict(**input_args, text = text)
|
||||
else:
|
||||
input_args = dict(**input_args, text_embed=text)
|
||||
|
||||
loss = trainer(**input_args)
|
||||
|
||||
# Samples per second
|
||||
|
||||
samples_per_sec = batch_size * step / timer.elapsed()
|
||||
|
||||
# Save checkpoint every save_interval minutes
|
||||
if(int(timer.elapsed()) >= 60 * save_interval):
|
||||
timer.reset()
|
||||
|
||||
save_diffusion_model(
|
||||
save_path,
|
||||
diffusion_prior,
|
||||
trainer.optimizer,
|
||||
trainer.scaler,
|
||||
config,
|
||||
image_embed_dim)
|
||||
|
||||
# Log to wandb
|
||||
tracker.log({"Training loss": loss,
|
||||
"Steps": step,
|
||||
"Samples per second": samples_per_sec})
|
||||
# Log cosineSim(text_embed,predicted_image_embed) - cosineSim(text_embed,image_embed)
|
||||
# Use NUM_TEST_EMBEDDINGS samples from the test set each time
|
||||
# Get embeddings from the most recently saved model
|
||||
if(step % REPORT_METRICS_EVERY) == 0:
|
||||
report_cosine_sims(diffusion_prior, eval_loader, dp_condition_on_text_encodings)
|
||||
### Evaluate model(validation run) ###
|
||||
eval_model(diffusion_prior, eval_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Validation")
|
||||
|
||||
step += 1
|
||||
trainer.update()
|
||||
|
||||
### Test run ###
|
||||
eval_model(diffusion_prior, test_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Test")
|
||||
# send config to get processed
|
||||
initialize_training(config, accelerator)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user