Compare commits

..

75 Commits

Author SHA1 Message Date
Phil Wang
c8422ffd5d fix EMA updating buffers with non-float tensors 2022-06-22 07:16:39 -07:00
Conight
2aadc23c7c Fix train decoder config example (#160) 2022-06-21 22:17:06 -07:00
Phil Wang
c098f57e09 EMA for vqgan vae comes from ema_pytorch now 2022-06-20 15:29:08 -07:00
Phil Wang
0021535c26 move ema to external repo 2022-06-20 11:48:32 -07:00
Phil Wang
56883910fb cleanup 2022-06-20 11:14:55 -07:00
Phil Wang
893f270012 project management 2022-06-20 10:00:22 -07:00
Phil Wang
f545ce18f4 be able to turn off p2 loss reweighting for upsamplers 2022-06-20 09:43:31 -07:00
Phil Wang
fc7abf624d in paper, blur sigma was 0.6 2022-06-20 09:05:08 -07:00
Phil Wang
67f0740777 small cleanup 2022-06-20 08:59:51 -07:00
Phil Wang
138079ca83 allow for setting beta schedules of unets differently in the decoder, as what was used in the paper was cosine, cosine, linear 2022-06-20 08:56:37 -07:00
zion
f5a906f5d3 prior train script bug fixes (#153) 2022-06-19 15:55:15 -07:00
Phil Wang
0215237fc6 update status 2022-06-19 09:42:24 -07:00
Phil Wang
461b91c5c1 also merge distributed training code for decoder, thanks to @Veldrovive 2022-06-19 09:26:44 -07:00
Aidan Dempster
58892135d9 Distributed Training of the Decoder (#121)
* Converted decoder trainer to use accelerate

* Fixed issue where metric evaluation would hang on distributed mode

* Implemented functional saving
Loading still fails due to some issue with the optimizer

* Fixed issue with loading decoders

* Fixed issue with tracker config

* Fixed issue with amp
Updated logging to be more logical

* Saving checkpoint now saves position in training as well
Fixed an issue with running out of gpu space due to loading weights into the gpu twice

* Fixed ema for distributed training

* Fixed isue where get_pkg_version was reintroduced

* Changed decoder trainer to upload config as a file

Fixed issue where loading best would error
2022-06-19 09:25:54 -07:00
Phil Wang
e37072a48c 0.10.0 2022-06-19 08:50:53 -07:00
Phil Wang
41ca896413 depend on huggingface accelerate, move appreciation thread up for visibility 2022-06-19 08:50:35 -07:00
zion
fe19b508ca Distributed Training of the Prior (#112)
* distributed prior trainer

better EMA support

update load and save methods of prior

* update prior training script

add test evalution & ema validation

add more tracking metrics

small cleanup
2022-06-19 08:46:14 -07:00
Phil Wang
6651eafa93 one more residual, after seeing good results on unconditional generation locally 2022-06-16 11:18:02 -07:00
Phil Wang
e6bb75e5ab fix missing residual for highest resolution of the unet 2022-06-15 20:09:43 -07:00
Giorgos Zachariadis
b4c3e5b854 changed str in order to avoid confusions and collisions with Python (#147) 2022-06-15 13:41:16 -07:00
Phil Wang
b7f9607258 make memory efficient unet design from imagen toggle-able 2022-06-15 13:40:26 -07:00
Phil Wang
2219348a6e adopt similar unet architecture as imagen 2022-06-15 12:18:21 -07:00
Phil Wang
9eea9b9862 add p2 loss reweighting for decoder training as an option 2022-06-14 10:58:57 -07:00
Phil Wang
5d958713c0 fix classifier free guidance for image hiddens summed to time hiddens, thanks to @xvjiarui for finding this bug 2022-06-13 21:01:50 -07:00
Phil Wang
0f31980362 cleanup 2022-06-07 17:31:38 -07:00
Phil Wang
bee5bf3815 fix for https://github.com/lucidrains/DALLE2-pytorch/issues/143 2022-06-07 09:03:48 -07:00
Phil Wang
350a3d6045 0.6.16 2022-06-06 08:45:46 -07:00
Kashif Rasul
1a81670718 fix quadratic_beta_schedule (#141) 2022-06-06 08:45:14 -07:00
Phil Wang
934c9728dc some cleanup 2022-06-04 16:54:15 -07:00
Phil Wang
ce4b0107c1 0.6.13 2022-06-04 13:26:57 -07:00
zion
64c2f9c4eb implement ema warmup from @crowsonkb (#140) 2022-06-04 13:26:34 -07:00
Phil Wang
22cc613278 ema fix from @nousr 2022-06-03 19:44:36 -07:00
zion
83517849e5 ema module fixes (#139) 2022-06-03 19:43:51 -07:00
Phil Wang
708809ed6c lower beta2 for adam down to 0.99, based on https://openreview.net/forum?id=2LdBqxc1Yv 2022-06-03 10:26:28 -07:00
Phil Wang
9cc475f6e7 fix update_every within EMA 2022-06-03 10:21:05 -07:00
Phil Wang
ffd342e9d0 allow for an option to constrain the variance interpolation fraction coming out from the unet for learned variance, if it is turned on 2022-06-03 09:34:57 -07:00
Phil Wang
f8bfd3493a make destructuring datum length agnostic when validating in training decoder script, for @YUHANG-Ma 2022-06-02 13:54:57 -07:00
Phil Wang
9025345e29 take a stab at fixing generate_grid_samples when real images have a greater image size than generated 2022-06-02 11:33:15 -07:00
Phil Wang
8cc278447e just cast to right types for blur sigma and kernel size augs 2022-06-02 11:21:58 -07:00
Phil Wang
38cd62010c allow for random blur sigma and kernel size augmentations on low res conditioning (need to reread paper to see if the augmentation value needs to be fed into the unet for conditioning as well) 2022-06-02 11:11:25 -07:00
Ryan Russell
1cc288af39 Improve Readability (#133)
Signed-off-by: Ryan Russell <git@ryanrussell.org>
2022-06-01 13:28:02 -07:00
Phil Wang
a851168633 make youtokentome optional package, due to reported installation difficulties 2022-06-01 09:25:35 -07:00
Phil Wang
1ffeecd0ca lower default ema beta value 2022-05-31 11:55:21 -07:00
Phil Wang
3df899f7a4 patch 2022-05-31 09:03:43 -07:00
Aidan Dempster
09534119a1 Fixed non deterministic optimizer creation (#130) 2022-05-31 09:03:20 -07:00
Phil Wang
6f8b90d4d7 add packaging package 2022-05-30 11:45:00 -07:00
Phil Wang
b588286288 fix version 2022-05-30 11:06:34 -07:00
Phil Wang
b693e0be03 default number of resnet blocks per layer in unet to 2 (in imagen it was 3 for base 64x64) 2022-05-30 10:06:48 -07:00
Phil Wang
a0bed30a84 additional conditioning on image embedding by summing to time embeddings (for FiLM like conditioning in subsequent layers), from passage found in paper by @mhh0318 2022-05-30 09:26:51 -07:00
zion
387c5bf774 quick patch for new prior loader (#123) 2022-05-29 16:25:53 -07:00
Phil Wang
a13d2d89c5 0.5.7 2022-05-29 07:40:25 -07:00
zion
44d4b1bba9 overhaul prior dataloader (#122)
add readme for loader
2022-05-29 07:39:59 -07:00
Phil Wang
f12a7589c5 commit to trying out grid attention 2022-05-26 12:56:10 -07:00
Phil Wang
b8af2210df make sure diffusion prior can be instantiated from pydantic class without clip 2022-05-26 08:47:30 -07:00
Phil Wang
f4fe6c570d allow for full customization of number of resnet blocks per down or upsampling layers in unet, as in imagen 2022-05-26 08:33:31 -07:00
Phil Wang
645e207441 credit assignment 2022-05-26 08:16:03 -07:00
Phil Wang
00743b3a0b update 2022-05-26 08:12:25 -07:00
Phil Wang
01589aff6a cite maxvit properly 2022-05-26 07:12:25 -07:00
Phil Wang
7ecfd76cc0 fix evaluation config splat in training decoder script 2022-05-26 07:11:31 -07:00
Phil Wang
6161b61c55 0.5.4 2022-05-25 09:32:17 -07:00
zion
1ed0f9d80b use deterministic optimizer params (#116) 2022-05-25 09:31:43 -07:00
Phil Wang
f326a95e26 0.5.3 2022-05-25 09:07:28 -07:00
zion
d7a0a2ce4b add more support for configuring prior (#113) 2022-05-25 09:06:50 -07:00
Phil Wang
f23fab7ef7 switch over to scale shift conditioning, as it seems like Imagen and Glide used it and it may be important 2022-05-24 21:46:12 -07:00
Phil Wang
857b9fbf1e allow for one to stop grouping out weight decayable parameters, to debug optimizer state dict problem 2022-05-24 21:42:32 -07:00
Phil Wang
8864fd0aa7 bring in the dynamic thresholding technique from the Imagen paper, which purportedly improves classifier free guidance for the cascading ddpm 2022-05-24 18:15:14 -07:00
Phil Wang
72bf159331 update 2022-05-24 08:25:40 -07:00
Phil Wang
e5e47cfecb link to aidan's test run 2022-05-23 12:41:46 -07:00
Phil Wang
fa533962bd just use an assert to make sure clip image channels is never different than the channels of the diffusion prior and decoder, if clip is given 2022-05-22 22:43:14 -07:00
Phil Wang
276abf337b fix and cleanup image size determination logic in decoder 2022-05-22 22:28:45 -07:00
Phil Wang
ae42d03006 allow for saving of additional fields on save method in trainers, and return loaded objects from the load method 2022-05-22 22:14:25 -07:00
Phil Wang
4d346e98d9 allow for config driven creation of clip-less diffusion prior 2022-05-22 20:36:20 -07:00
Phil Wang
2b1fd1ad2e product management 2022-05-22 19:23:40 -07:00
zion
82a2ef37d9 Update README.md (#109)
block in a section that links to available pre-trained models for those who are interested
2022-05-22 19:22:30 -07:00
Phil Wang
5c397c9d66 move neural network creations off the configuration file into the pydantic classes 2022-05-22 19:18:18 -07:00
23 changed files with 1780 additions and 1071 deletions

106
README.md
View File

@@ -12,7 +12,7 @@ This model is SOTA for text-to-image for now.
Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication with the <a href="https://laion.ai/">LAION</a> community | <a href="https://www.youtube.com/watch?v=AIOE1l1W0Tw">Yannic Interview</a>
There was enough interest for a <a href="https://github.com/lucidrains/dalle2-jax">Jax version</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
As of 5/23/22, it is no longer SOTA. SOTA will be <a href="https://github.com/lucidrains/imagen-pytorch">here</a>. Jax versions as well as text-to-video project will be shifted towards the Imagen architecture, as it is way simpler.
## Status
@@ -24,6 +24,30 @@ There was enough interest for a <a href="https://github.com/lucidrains/dalle2-ja
*ongoing at 21k steps*
- <a href="https://twitter.com/Buntworthy/status/1529475416775434240?t=0GEge3Kr9I36cjcUVCQUTg">Justin Pinkney</a> successfully trained the diffusion prior in the repository for his CLIP to Stylegan2 text-to-image application
## Pre-Trained Models
- LAION is training prior models. Checkpoints are available on <a href="https://huggingface.co/zenglishuci/conditioned-prior">🤗huggingface</a> and the training statistics are available on <a href="https://wandb.ai/nousr_laion/conditioned-prior/reports/LAION-DALLE2-PyTorch-Prior--VmlldzoyMDI2OTIx">🐝WANDB</a>.
- Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/jkrtg0so?workspace=user-veldrovive">In-progress test run</a> 🚧
- Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/3d5rytsa?workspace=">Another test run with sparse attention</a>
- DALL-E 2 🚧
## Appreciation
This library would not have gotten to this working state without the help of
- <a href="https://github.com/nousr">Zion</a> for the distributed training code for the diffusion prior
- <a href="https://github.com/Veldrovive">Aidan</a> for the distributed training code for the decoder as well as the dataloaders
- <a href="https://github.com/krish240574">Kumar</a> for working on the initial diffusion training script
- <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management
- <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying of critical bugs
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice
- <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship
- <a href="https://huggingface.co">🤗 Huggingface</a> and in particular <a href="https://github.com/sgugger">Sylvain</a> for the <a href="https://github.com/huggingface/accelerate">Accelerate</a> library
... and many others. Thank you! 🙏
## Install
```bash
@@ -936,7 +960,7 @@ from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embed
# Create a dataloader directly.
dataloader = create_image_embedding_dataloader(
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses braket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses bracket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
num_workers=4,
batch_size=32,
@@ -993,33 +1017,6 @@ The most significant parameters for the script are as follows:
- `clip`, default = `None` # Signals the prior to use pre-computed embeddings
#### Loading and Saving the DiffusionPrior model
Two methods are provided, load_diffusion_model and save_diffusion_model, the names being self-explanatory.
```python
from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model
```
##### Loading
load_diffusion_model(dprior_path, device)
dprior_path : path to saved model(.pth)
device : the cuda device you're running on
##### Saving
save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim)
save_path : path to save at
model : object of Diffusion_Prior
optimizer : optimizer object - see train_diffusion_prior.py for how to create one.
e.g: optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
scaler : a GradScaler object.
e.g: scaler = GradScaler(enabled=amp)
config : config object created in train_diffusion_prior.py - see file for example.
image_embed_dim - the dimension of the image_embedding
e.g: 768
## CLI (wip)
```bash
@@ -1034,18 +1031,6 @@ Once built, images will be saved to the same directory the command is invoked
<a href="https://github.com/lucidrains/stylegan2-pytorch">template</a>
## Appreciation
This library would not have gotten to this working state without the help of
- <a href="https://github.com/nousr">Zion</a> and <a href="https://github.com/krish240574">Kumar</a> for the diffusion training script
- <a href="https://github.com/Veldrovive">Aidan</a> for the decoder training script and dataloaders
- <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management
- <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying of critical bugs
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice
... and many others. Thank you! 🙏
## Todo
- [x] finish off gaussian diffusion class for latent embedding - allow for prediction of epsilon
@@ -1079,19 +1064,15 @@ This library would not have gotten to this working state without the help of
- [x] use pydantic for config drive training
- [x] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
- [x] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
- [x] allow for creation of diffusion prior model off pydantic config classes - consider the same for tracker configs
- [x] bring in skip-layer excitations (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training (doesnt work well)
- [x] test out grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697 (keeping, seems to be fine)
- [x] allow for unet to be able to condition non-cross attention style as well
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
- [ ] train on a toy task, offer in colab
- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
- [ ] speed up inference, read up on papers (ddim or diffusion-gan, etc)
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
- [ ] decoder needs one day worth of refactor for tech debt
- [ ] allow for unet to be able to condition non-cross attention style as well
- [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89
- [ ] build infilling
## Citations
@@ -1134,8 +1115,9 @@ This library would not have gotten to this working state without the help of
```bibtex
@inproceedings{Tu2022MaxViTMV,
title = {MaxViT: Multi-Axis Vision Transformer},
author = {Zhe-Wei Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
year = {2022}
author = {Zhengzhong Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
year = {2022},
url = {https://arxiv.org/abs/2204.01697}
}
```
@@ -1189,4 +1171,22 @@ This library would not have gotten to this working state without the help of
}
```
```bibtex
@misc{Saharia2022,
title = {Imagen: unprecedented photorealism × deep level of language understanding},
author = {Chitwan Saharia*, William Chan*, Saurabh Saxena†, Lala Li†, Jay Whang†, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S. Sara Mahdavi, Rapha Gontijo Lopes, Tim Salimans, Jonathan Ho†, David Fleet†, Mohammad Norouzi*},
year = {2022}
}
```
```bibtex
@article{Choi2022PerceptionPT,
title = {Perception Prioritized Training of Diffusion Models},
author = {Jooyoung Choi and Jungbeom Lee and Chaehun Shin and Sungwon Kim and Hyunwoo J. Kim and Sung-Hoon Yoon},
journal = {ArXiv},
year = {2022},
volume = {abs/2204.00227}
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>

View File

@@ -6,9 +6,10 @@ For more complex configuration, we provide the option of using a configuration f
The decoder trainer has 7 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.example.json).
**<ins>Unets</ins>:**
**<ins>Unet</ins>:**
This is a single unet config, which belongs as an array nested under the decoder config as a list of `unets`
Each member of this array defines a single unet that will be added to the decoder.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `dim` | Yes | N/A | The starting channels of the unet. |
@@ -22,6 +23,7 @@ Any parameter from the `Unet` constructor can also be given here.
Defines the configuration options for the decoder model. The unets defined above will automatically be inserted.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `unets` | Yes | N/A | A list of unets, using the configuration above |
| `image_sizes` | Yes | N/A | The resolution of the image after each upsampling step. The length of this array should be the number of unets defined. |
| `image_size` | Yes | N/A | Not used. Can be any number. |
| `timesteps` | No | `1000` | The number of diffusion timesteps used for generation. |
@@ -81,7 +83,7 @@ Defines which evaluation metrics will be used to test the model.
Each metric can be enabled by setting its configuration. The configuration keys for each metric are defined by the torchmetrics constructors which will be linked.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `n_evalation_samples` | No | `1000` | The number of samples to generate to test the model. |
| `n_evaluation_samples` | No | `1000` | The number of samples to generate to test the model. |
| `FID` | No | `None` | Setting to an object enables the [Frechet Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/frechet_inception_distance.html) metric.
| `IS` | No | `None` | Setting to an object enables the [Inception Score](https://torchmetrics.readthedocs.io/en/stable/image/inception_score.html) metric.
| `KID` | No | `None` | Setting to an object enables the [Kernel Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/kernel_inception_distance.html) metric. |

View File

@@ -1,21 +1,21 @@
{
"unets": [
{
"dim": 128,
"image_embed_dim": 768,
"cond_dim": 64,
"channels": 3,
"dim_mults": [1, 2, 4, 8],
"attn_dim_head": 32,
"attn_heads": 16
}
],
"decoder": {
"unets": [
{
"dim": 128,
"image_embed_dim": 768,
"cond_dim": 64,
"channels": 3,
"dim_mults": [1, 2, 4, 8],
"attn_dim_head": 32,
"attn_heads": 16
}
],
"image_sizes": [64],
"channels": 3,
"timesteps": 1000,
"loss_type": "l2",
"beta_schedule": "cosine",
"beta_schedule": ["cosine"],
"learned_variance": true
},
"data": {

View File

@@ -0,0 +1,70 @@
{
"prior": {
"clip": {
"make": "x-clip",
"model": "ViT-L/14",
"base_model_kwargs": {
"dim_text": 768,
"dim_image": 768,
"dim_latent": 768
}
},
"net": {
"dim": 768,
"depth": 12,
"num_timesteps": 1000,
"num_time_embeds": 1,
"num_image_embeds": 1,
"num_text_embeds": 1,
"dim_head": 64,
"heads": 12,
"ff_mult": 4,
"norm_out": true,
"attn_dropout": 0.0,
"ff_dropout": 0.0,
"final_proj": true,
"normformer": true,
"rotary_emb": true
},
"image_embed_dim": 768,
"image_size": 224,
"image_channels": 3,
"timesteps": 1000,
"cond_drop_prob": 0.1,
"loss_type": "l2",
"predict_x_start": true,
"beta_schedule": "cosine",
"condition_on_text_encodings": true
},
"data": {
"image_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/",
"text_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/",
"meta_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/laion2B-en-metadata/",
"batch_size": 256,
"splits": {
"train": 0.9,
"val": 1e-7,
"test": 0.0999999
}
},
"train": {
"epochs": 1,
"lr": 1.1e-4,
"wd": 6.02e-2,
"max_grad_norm": 0.5,
"use_ema": true,
"amp": false,
"save_every": 10000
},
"load": {
"source": null,
"resume": false
},
"tracker": {
"tracker_type": "wandb",
"data_path": "./prior_checkpoints",
"wandb_entity": "laion",
"wandb_project": "diffusion-prior",
"verbose": true
}
}

View File

@@ -1,3 +1,4 @@
from dalle2_pytorch.version import __version__
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer

View File

@@ -1,6 +1,6 @@
import math
import random
from tqdm import tqdm
from inspect import isfunction
from functools import partial, wraps
from contextlib import contextmanager
from collections import namedtuple
@@ -11,7 +11,7 @@ import torch.nn.functional as F
from torch import nn, einsum
import torchvision.transforms as T
from einops import rearrange, repeat
from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange
from einops_exts import rearrange_many, repeat_many, check_shape
from einops_exts.torch import EinopsToAndFrom
@@ -56,7 +56,7 @@ def maybe(fn):
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
return d() if callable(d) else d
def cast_tuple(val, length = 1):
if isinstance(val, list):
@@ -313,11 +313,6 @@ def extract(a, t, x_shape):
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
def noise_like(shape, device, repeat=False):
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
noise = lambda: torch.randn(shape, device=device)
return repeat_noise() if repeat else noise()
def meanflat(x):
return x.mean(dim = tuple(range(1, len(x.shape))))
@@ -372,7 +367,7 @@ def quadratic_beta_schedule(timesteps):
scale = 1000 / timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
return torch.linspace(beta_start**2, beta_end**2, timesteps, dtype = torch.float64) ** 2
return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype = torch.float64) ** 2
def sigmoid_beta_schedule(timesteps):
@@ -383,8 +378,8 @@ def sigmoid_beta_schedule(timesteps):
return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
class BaseGaussianDiffusion(nn.Module):
def __init__(self, *, beta_schedule, timesteps, loss_type):
class NoiseScheduler(nn.Module):
def __init__(self, *, beta_schedule, timesteps, loss_type, p2_loss_weight_gamma = 0., p2_loss_weight_k = 1):
super().__init__()
if beta_schedule == "cosine":
@@ -449,6 +444,11 @@ class BaseGaussianDiffusion(nn.Module):
register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
# p2 loss reweighting
self.has_p2_loss_reweighting = p2_loss_weight_gamma > 0.
register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma)
def q_posterior(self, x_start, x_t, t):
posterior_mean = (
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
@@ -472,11 +472,10 @@ class BaseGaussianDiffusion(nn.Module):
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
)
def sample(self, *args, **kwargs):
raise NotImplementedError
def forward(self, *args, **kwargs):
raise NotImplementedError
def p2_reweigh_loss(self, loss, times):
if not self.has_p2_loss_reweighting:
return loss
return loss * extract(self.p2_loss_weight, times, loss.shape)
# diffusion prior
@@ -687,8 +686,7 @@ class Attention(nn.Module):
# attention
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
attn = sim.softmax(dim = -1)
attn = sim.softmax(dim = -1, dtype = torch.float32)
attn = self.dropout(attn)
# aggregate values
@@ -862,7 +860,7 @@ class DiffusionPriorNetwork(nn.Module):
return pred_image_embed
class DiffusionPrior(BaseGaussianDiffusion):
class DiffusionPrior(nn.Module):
def __init__(
self,
net,
@@ -883,13 +881,17 @@ class DiffusionPrior(BaseGaussianDiffusion):
image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
clip_adapter_overrides = dict()
):
super().__init__(
super().__init__()
self.noise_scheduler = NoiseScheduler(
beta_schedule = beta_schedule,
timesteps = timesteps,
loss_type = loss_type
)
if exists(clip):
assert image_channels == clip.image_channels, f'channels of image ({image_channels}) should be equal to the channels that CLIP accepts ({clip.image_channels})'
if isinstance(clip, CLIP):
clip = XClipAdapter(clip, **clip_adapter_overrides)
elif isinstance(clip, CoCa):
@@ -921,6 +923,13 @@ class DiffusionPrior(BaseGaussianDiffusion):
self.training_clamp_l2norm = training_clamp_l2norm
self.init_image_embed_l2norm = init_image_embed_l2norm
# device tracker
self.register_buffer('_dummy', torch.tensor([True]), persistent = False)
@property
def device(self):
return self._dummy.device
def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.):
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
@@ -931,7 +940,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
# not 100% sure of this above line - for any spectators, let me know in the github issues (or through a pull request) if you know how to correctly do this
# i'll be rereading https://arxiv.org/abs/2111.14822, where i think a similar approach is taken
else:
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
x_recon = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
if clip_denoised and not self.predict_x_start:
x_recon.clamp_(-1., 1.)
@@ -939,21 +948,21 @@ class DiffusionPrior(BaseGaussianDiffusion):
if self.predict_x_start and self.sampling_clamp_l2norm:
x_recon = l2norm(x_recon) * self.image_embed_scale
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t)
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, x, t, text_cond = None, clip_denoised = True, repeat_noise = False, cond_scale = 1.):
def p_sample(self, x, t, text_cond = None, clip_denoised = True, cond_scale = 1.):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
noise = noise_like(x.shape, device, repeat_noise)
noise = torch.randn_like(x)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop(self, shape, text_cond, cond_scale = 1.):
device = self.betas.device
device = self.device
b = shape[0]
image_embed = torch.randn(shape, device=device)
@@ -961,7 +970,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
if self.init_image_embed_l2norm:
image_embed = l2norm(image_embed) * self.image_embed_scale
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps):
times = torch.full((b,), i, device = device, dtype = torch.long)
image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)
@@ -970,7 +979,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
def p_losses(self, image_embed, times, text_cond, noise = None):
noise = default(noise, lambda: torch.randn_like(image_embed))
image_embed_noisy = self.q_sample(x_start = image_embed, t = times, noise = noise)
image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise)
pred = self.net(
image_embed_noisy,
@@ -984,7 +993,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
target = noise if not self.predict_x_start else image_embed
loss = self.loss_fn(pred, target)
loss = self.noise_scheduler.loss_fn(pred, target)
return loss
@torch.no_grad()
@@ -995,7 +1004,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
img = torch.randn(shape, device = device)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = self.noise_scheduler.num_timesteps):
img = self.p_sample(img, torch.full((batch_size,), i, device = device, dtype = torch.long), text_cond = text_cond, cond_scale = cond_scale)
return img
@@ -1067,7 +1076,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
# timestep conditioning from ddpm
batch, device = image_embed.shape[0], image_embed.device
times = torch.randint(0, self.num_timesteps, (batch,), device = device, dtype = torch.long)
times = torch.randint(0, self.noise_scheduler.num_timesteps, (batch,), device = device, dtype = torch.long)
# scale image embed (Katherine)
@@ -1082,8 +1091,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
def Upsample(dim):
return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
def Downsample(dim):
return nn.Conv2d(dim, dim, 4, 2, 1)
def Downsample(dim, *, dim_out = None):
dim_out = default(dim_out, dim)
return nn.Conv2d(dim, dim_out, 4, 2, 1)
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
@@ -1105,13 +1115,20 @@ class Block(nn.Module):
groups = 8
):
super().__init__()
self.block = nn.Sequential(
nn.Conv2d(dim, dim_out, 3, padding = 1),
nn.GroupNorm(groups, dim_out),
nn.SiLU()
)
def forward(self, x):
return self.block(x)
self.project = nn.Conv2d(dim, dim_out, 3, padding = 1)
self.norm = nn.GroupNorm(groups, dim_out)
self.act = nn.SiLU()
def forward(self, x, scale_shift = None):
x = self.project(x)
x = self.norm(x)
if exists(scale_shift):
scale, shift = scale_shift
x = x * (scale + 1) + shift
x = self.act(x)
return x
class ResnetBlock(nn.Module):
def __init__(
@@ -1130,7 +1147,7 @@ class ResnetBlock(nn.Module):
if exists(time_cond_dim):
self.time_mlp = nn.Sequential(
nn.SiLU(),
nn.Linear(time_cond_dim, dim_out)
nn.Linear(time_cond_dim, dim_out * 2)
)
self.cross_attn = None
@@ -1150,11 +1167,14 @@ class ResnetBlock(nn.Module):
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
def forward(self, x, cond = None, time_emb = None):
h = self.block1(x)
scale_shift = None
if exists(self.time_mlp) and exists(time_emb):
time_emb = self.time_mlp(time_emb)
h = rearrange(time_emb, 'b c -> b c 1 1') + h
time_emb = rearrange(time_emb, 'b c -> b c 1 1')
scale_shift = time_emb.chunk(2, dim = 1)
h = self.block1(x, scale_shift = scale_shift)
if exists(self.cross_attn):
assert exists(cond)
@@ -1221,8 +1241,7 @@ class CrossAttention(nn.Module):
mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, max_neg_value)
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
attn = sim.softmax(dim = -1)
attn = sim.softmax(dim = -1, dtype = torch.float32)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
@@ -1331,12 +1350,15 @@ class Unet(nn.Module):
cond_on_text_encodings = False,
max_text_len = 256,
cond_on_image_embeds = False,
add_image_embeds_to_time = True, # alerted by @mhh0318 to a phrase in the paper - "Specifically, we modify the architecture described in Nichol et al. (2021) by projecting and adding CLIP embeddings to the existing timestep embedding"
init_dim = None,
init_conv_kernel_size = 7,
resnet_groups = 8,
num_resnet_blocks = 2,
init_cross_embed_kernel_sizes = (3, 7, 15),
cross_embed_downsample = False,
cross_embed_downsample_kernel_sizes = (2, 4),
memory_efficient = False,
**kwargs
):
super().__init__()
@@ -1356,7 +1378,7 @@ class Unet(nn.Module):
self.channels_out = default(channels_out, channels)
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
init_dim = default(init_dim, dim // 3 * 2)
init_dim = default(init_dim, dim)
self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1)
@@ -1383,11 +1405,16 @@ class Unet(nn.Module):
nn.Linear(time_cond_dim, time_cond_dim)
)
self.image_to_cond = nn.Sequential(
self.image_to_tokens = nn.Sequential(
nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
Rearrange('b (n d) -> b n d', n = num_image_tokens)
) if cond_on_image_embeds and image_embed_dim != cond_dim else nn.Identity()
self.to_image_hiddens = nn.Sequential(
nn.Linear(image_embed_dim, time_cond_dim),
nn.GELU()
) if cond_on_image_embeds and add_image_embeds_to_time else None
self.norm_cond = nn.LayerNorm(cond_dim)
self.norm_mid_cond = nn.LayerNorm(cond_dim)
@@ -1408,6 +1435,7 @@ class Unet(nn.Module):
# for classifier free guidance
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
self.null_image_hiddens = nn.Parameter(torch.randn(1, time_cond_dim))
self.max_text_len = max_text_len
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
@@ -1419,6 +1447,7 @@ class Unet(nn.Module):
# resnet block klass
resnet_groups = cast_tuple(resnet_groups, len(in_out))
num_resnet_blocks = cast_tuple(num_resnet_blocks, len(in_out))
assert len(resnet_groups) == len(in_out)
@@ -1434,16 +1463,17 @@ class Unet(nn.Module):
self.ups = nn.ModuleList([])
num_resolutions = len(in_out)
for ind, ((dim_in, dim_out), groups) in enumerate(zip(in_out, resnet_groups)):
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks)):
is_first = ind == 0
is_last = ind >= (num_resolutions - 1)
layer_cond_dim = cond_dim if not is_first else None
self.downs.append(nn.ModuleList([
ResnetBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
downsample_klass(dim_in, dim_out = dim_out) if memory_efficient else None,
ResnetBlock(dim_out if memory_efficient else dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
downsample_klass(dim_out) if not is_last else nn.Identity()
nn.ModuleList([ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
downsample_klass(dim_out) if not is_last and not memory_efficient else None
]))
mid_dim = dims[-1]
@@ -1452,19 +1482,19 @@ class Unet(nn.Module):
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
for ind, ((dim_in, dim_out), groups) in enumerate(zip(reversed(in_out[1:]), reversed(resnet_groups))):
is_last = ind >= (num_resolutions - 2)
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(reversed(in_out), reversed(resnet_groups), reversed(num_resnet_blocks))):
is_last = ind >= (len(in_out) - 1)
layer_cond_dim = cond_dim if not is_last else None
self.ups.append(nn.ModuleList([
ResnetBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
Upsample(dim_in)
nn.ModuleList([ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
Upsample(dim_in) if not is_last or memory_efficient else nn.Identity()
]))
self.final_conv = nn.Sequential(
ResnetBlock(dim, dim, groups = resnet_groups[0]),
ResnetBlock(dim * 2, dim, groups = resnet_groups[0]),
nn.Conv2d(dim, self.channels_out, 1)
)
@@ -1536,6 +1566,7 @@ class Unet(nn.Module):
# initial convolution
x = self.init_conv(x)
r = x.clone() # final residual
# time conditioning
@@ -1549,7 +1580,23 @@ class Unet(nn.Module):
image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
text_keep_mask = prob_mask_like((batch_size,), 1 - text_cond_drop_prob, device = device)
image_keep_mask, text_keep_mask = rearrange_many((image_keep_mask, text_keep_mask), 'b -> b 1 1')
text_keep_mask = rearrange(text_keep_mask, 'b -> b 1 1')
# image embedding to be summed to time embedding
# discovered by @mhh0318 in the paper
if exists(image_embed) and exists(self.to_image_hiddens):
image_hiddens = self.to_image_hiddens(image_embed)
image_keep_mask_hidden = rearrange(image_keep_mask, 'b -> b 1')
null_image_hiddens = self.null_image_hiddens.to(image_hiddens.dtype)
image_hiddens = torch.where(
image_keep_mask_hidden,
image_hiddens,
null_image_hiddens
)
t = t + image_hiddens
# mask out image embedding depending on condition dropout
# for classifier free guidance
@@ -1557,11 +1604,12 @@ class Unet(nn.Module):
image_tokens = None
if self.cond_on_image_embeds:
image_tokens = self.image_to_cond(image_embed)
image_keep_mask_embed = rearrange(image_keep_mask, 'b -> b 1 1')
image_tokens = self.image_to_tokens(image_embed)
null_image_embed = self.null_image_embed.to(image_tokens.dtype) # for some reason pytorch AMP not working
image_tokens = torch.where(
image_keep_mask,
image_keep_mask_embed,
image_tokens,
null_image_embed
)
@@ -1616,12 +1664,20 @@ class Unet(nn.Module):
hiddens = []
for block1, sparse_attn, block2, downsample in self.downs:
x = block1(x, c, t)
for pre_downsample, init_block, sparse_attn, resnet_blocks, post_downsample in self.downs:
if exists(pre_downsample):
x = pre_downsample(x)
x = init_block(x, c, t)
x = sparse_attn(x)
x = block2(x, c, t)
for resnet_block in resnet_blocks:
x = resnet_block(x, c, t)
hiddens.append(x)
x = downsample(x)
if exists(post_downsample):
x = post_downsample(x)
x = self.mid_block1(x, mid_c, t)
@@ -1630,20 +1686,24 @@ class Unet(nn.Module):
x = self.mid_block2(x, mid_c, t)
for block1, sparse_attn, block2, upsample in self.ups:
x = torch.cat((x, hiddens.pop()), dim=1)
x = block1(x, c, t)
for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
x = torch.cat((x, hiddens.pop()), dim = 1)
x = init_block(x, c, t)
x = sparse_attn(x)
x = block2(x, c, t)
for resnet_block in resnet_blocks:
x = resnet_block(x, c, t)
x = upsample(x)
x = torch.cat((x, r), dim = 1)
return self.final_conv(x)
class LowresConditioner(nn.Module):
def __init__(
self,
downsample_first = True,
blur_sigma = 0.1,
blur_sigma = 0.6,
blur_kernel_size = 3,
):
super().__init__()
@@ -1667,13 +1727,25 @@ class LowresConditioner(nn.Module):
# when training, blur the low resolution conditional image
blur_sigma = default(blur_sigma, self.blur_sigma)
blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)
# allow for drawing a random sigma between lo and hi float values
if isinstance(blur_sigma, tuple):
blur_sigma = tuple(map(float, blur_sigma))
blur_sigma = random.uniform(*blur_sigma)
# allow for drawing a random kernel size between lo and hi int values
if isinstance(blur_kernel_size, tuple):
blur_kernel_size = tuple(map(int, blur_kernel_size))
kernel_size_lo, kernel_size_hi = blur_kernel_size
blur_kernel_size = random.randrange(kernel_size_lo, kernel_size_hi + 1)
cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
cond_fmap = resize_image_to(cond_fmap, target_image_size)
return cond_fmap
class Decoder(BaseGaussianDiffusion):
class Decoder(nn.Module):
def __init__(
self,
unet,
@@ -1686,36 +1758,44 @@ class Decoder(BaseGaussianDiffusion):
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5,
loss_type = 'l2',
beta_schedule = 'cosine',
beta_schedule = None,
predict_x_start = False,
predict_x_start_for_latent_diffusion = False,
image_sizes = None, # for cascading ddpm, image size at each stage
random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
blur_sigma = 0.1, # cascading ddpm - blur sigma
blur_sigma = (0.1, 0.2), # cascading ddpm - blur sigma
blur_kernel_size = 3, # cascading ddpm - blur kernel size
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
clip_denoised = True,
clip_x_start = True,
clip_adapter_overrides = dict(),
learned_variance = True,
learned_variance_constrain_frac = False,
vb_loss_weight = 0.001,
unconditional = False,
auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
use_dynamic_thres = False, # from the Imagen paper
dynamic_thres_percentile = 0.9,
p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
p2_loss_weight_k = 1
):
super().__init__(
beta_schedule = beta_schedule,
timesteps = timesteps,
loss_type = loss_type
)
super().__init__()
self.unconditional = unconditional
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
assert self.unconditional or (exists(clip) ^ exists(image_size)), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
# text conditioning
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
self.condition_on_text_encodings = condition_on_text_encodings
# clip
self.clip = None
if exists(clip):
assert not unconditional, 'clip must not be given if doing unconditional image training'
assert channels == clip.image_channels, f'channels of image ({channels}) should be equal to the channels that CLIP accepts ({clip.image_channels})'
if isinstance(clip, CLIP):
clip = XClipAdapter(clip, **clip_adapter_overrides)
elif isinstance(clip, CoCa):
@@ -1725,24 +1805,34 @@ class Decoder(BaseGaussianDiffusion):
assert isinstance(clip, BaseClipAdapter)
self.clip = clip
self.clip_image_size = clip.image_size
self.channels = clip.image_channels
else:
self.clip_image_size = image_size
self.channels = channels
self.condition_on_text_encodings = condition_on_text_encodings
# determine image size, with image_size and image_sizes taking precedence
if exists(image_size) or exists(image_sizes):
assert exists(image_size) ^ exists(image_sizes), 'only one of image_size or image_sizes must be given'
image_size = default(image_size, lambda: image_sizes[-1])
elif exists(clip):
image_size = clip.image_size
else:
raise Error('either image_size, image_sizes, or clip must be given to decoder')
# channels
self.channels = channels
# automatically take care of ensuring that first unet is unconditional
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
unets = cast_tuple(unet)
num_unets = len(unets)
vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))
# whether to use learned variance, defaults to True for the first unet in the cascade, as in paper
learned_variance = pad_tuple_to_length(cast_tuple(learned_variance), len(unets), fillvalue = False)
self.learned_variance = learned_variance
self.learned_variance_constrain_frac = learned_variance_constrain_frac # whether to constrain the output of the network (the interpolation fraction) from 0 to 1
self.vb_loss_weight = vb_loss_weight
# construct unets and vaes
@@ -1771,9 +1861,30 @@ class Decoder(BaseGaussianDiffusion):
self.unets.append(one_unet)
self.vaes.append(one_vae.copy_for_eval())
# create noise schedulers per unet
if not exists(beta_schedule):
beta_schedule = ('cosine', *(('cosine',) * max(num_unets - 2, 0)), *(('linear',) * int(num_unets > 1)))
beta_schedule = cast_tuple(beta_schedule, num_unets)
p2_loss_weight_gamma = cast_tuple(p2_loss_weight_gamma, num_unets)
self.noise_schedulers = nn.ModuleList([])
for unet_beta_schedule, unet_p2_loss_weight_gamma in zip(beta_schedule, p2_loss_weight_gamma):
noise_scheduler = NoiseScheduler(
beta_schedule = unet_beta_schedule,
timesteps = timesteps,
loss_type = loss_type,
p2_loss_weight_gamma = unet_p2_loss_weight_gamma,
p2_loss_weight_k = p2_loss_weight_k
)
self.noise_schedulers.append(noise_scheduler)
# unet image sizes
image_sizes = default(image_sizes, (self.clip_image_size,))
image_sizes = default(image_sizes, (image_size,))
image_sizes = tuple(sorted(set(image_sizes)))
assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
@@ -1810,10 +1921,24 @@ class Decoder(BaseGaussianDiffusion):
self.clip_denoised = clip_denoised
self.clip_x_start = clip_x_start
# dynamic thresholding settings, if clipping denoised during sampling
self.use_dynamic_thres = use_dynamic_thres
self.dynamic_thres_percentile = dynamic_thres_percentile
# normalize and unnormalize image functions
self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
# device tracker
self.register_buffer('_dummy', torch.Tensor([True]), persistent = False)
@property
def device(self):
return self._dummy.device
def get_unet(self, unet_number):
assert 0 < unet_number <= len(self.unets)
index = unet_number - 1
@@ -1837,7 +1962,7 @@ class Decoder(BaseGaussianDiffusion):
for unet, device in zip(self.unets, devices):
unet.to(device)
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img))
@@ -1848,38 +1973,55 @@ class Decoder(BaseGaussianDiffusion):
if predict_x_start:
x_recon = pred
else:
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
x_recon = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
if clip_denoised:
x_recon.clamp_(-1., 1.)
# s is the threshold amount
# static thresholding would just be s = 1
s = 1.
if self.use_dynamic_thres:
s = torch.quantile(
rearrange(x_recon, 'b ... -> b (...)').abs(),
self.dynamic_thres_percentile,
dim = -1
)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
s.clamp_(min = 1.)
s = s.view(-1, *((1,) * (x_recon.ndim - 1)))
# clip by threshold, depending on whether static or dynamic
x_recon = x_recon.clamp(-s, s) / s
model_mean, posterior_variance, posterior_log_variance = noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t)
if learned_variance:
# if learned variance, posterio variance and posterior log variance are predicted by the network
# by an interpolation of the max and min log beta values
# eq 15 - https://arxiv.org/abs/2102.09672
min_log = extract(self.posterior_log_variance_clipped, t, x.shape)
max_log = extract(torch.log(self.betas), t, x.shape)
min_log = extract(noise_scheduler.posterior_log_variance_clipped, t, x.shape)
max_log = extract(torch.log(noise_scheduler.betas), t, x.shape)
var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)
if self.learned_variance_constrain_frac:
var_interp_frac = var_interp_frac.sigmoid()
posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
posterior_variance = posterior_log_variance.exp()
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True, repeat_noise = False):
def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, learned_variance = learned_variance)
noise = noise_like(x.shape, device, repeat_noise)
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance)
noise = torch.randn_like(x)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
device = self.betas.device
def p_sample_loop(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
device = self.device
b = shape[0]
img = torch.randn(shape, device = device)
@@ -1887,7 +2029,7 @@ class Decoder(BaseGaussianDiffusion):
if not is_latent_diffusion:
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
for i in tqdm(reversed(range(0, noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = noise_scheduler.num_timesteps):
img = self.p_sample(
unet,
img,
@@ -1898,6 +2040,7 @@ class Decoder(BaseGaussianDiffusion):
cond_scale = cond_scale,
lowres_cond_img = lowres_cond_img,
predict_x_start = predict_x_start,
noise_scheduler = noise_scheduler,
learned_variance = learned_variance,
clip_denoised = clip_denoised
)
@@ -1905,7 +2048,7 @@ class Decoder(BaseGaussianDiffusion):
unnormalize_img = self.unnormalize_img(img)
return unnormalize_img
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
noise = default(noise, lambda: torch.randn_like(x_start))
# normalize to [-1, 1]
@@ -1916,7 +2059,7 @@ class Decoder(BaseGaussianDiffusion):
# get x_t
x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)
x_noisy = noise_scheduler.q_sample(x_start = x_start, t = times, noise = noise)
model_output = unet(
x_noisy,
@@ -1936,7 +2079,12 @@ class Decoder(BaseGaussianDiffusion):
target = noise if not predict_x_start else x_start
loss = self.loss_fn(pred, target)
loss = noise_scheduler.loss_fn(pred, target, reduction = 'none')
loss = reduce(loss, 'b ... -> b (...)', 'mean')
loss = noise_scheduler.p2_reweigh_loss(loss, times)
loss = loss.mean()
if not learned_variance:
# return simple loss if not using learned variance
@@ -1949,8 +2097,8 @@ class Decoder(BaseGaussianDiffusion):
# if learning the variance, also include the extra weight kl loss
true_mean, _, true_log_variance_clipped = self.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
true_mean, _, true_log_variance_clipped = noise_scheduler.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
# kl loss with detached model predicted mean, for stability reasons as in paper
@@ -1982,7 +2130,8 @@ class Decoder(BaseGaussianDiffusion):
text_encodings = None,
batch_size = 1,
cond_scale = 1.,
stop_at_unet_number = None
stop_at_unet_number = None,
distributed = False,
):
assert self.unconditional or exists(image_embed), 'image embed must be present on sampling from decoder unless if trained unconditionally'
@@ -1999,9 +2148,9 @@ class Decoder(BaseGaussianDiffusion):
img = None
is_cuda = next(self.parameters()).is_cuda
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance)):
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers)):
context = self.one_unet_in_gpu(unet = unet) if is_cuda else null_context()
context = self.one_unet_in_gpu(unet = unet) if is_cuda and not distributed else null_context()
with context:
lowres_cond_img = None
@@ -2027,7 +2176,8 @@ class Decoder(BaseGaussianDiffusion):
learned_variance = learned_variance,
clip_denoised = not is_latent_diffusion,
lowres_cond_img = lowres_cond_img,
is_latent_diffusion = is_latent_diffusion
is_latent_diffusion = is_latent_diffusion,
noise_scheduler = noise_scheduler
)
img = vae.decode(img)
@@ -2053,6 +2203,7 @@ class Decoder(BaseGaussianDiffusion):
unet = self.get_unet(unet_number)
vae = self.vaes[unet_index]
noise_scheduler = self.noise_schedulers[unet_index]
target_image_size = self.image_sizes[unet_index]
predict_x_start = self.predict_x_start[unet_index]
random_crop_size = self.random_crop_sizes[unet_index]
@@ -2062,7 +2213,7 @@ class Decoder(BaseGaussianDiffusion):
check_shape(image, 'b c h w', c = self.channels)
assert h >= target_image_size and w >= target_image_size
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
times = torch.randint(0, noise_scheduler.num_timesteps, (b,), device = device, dtype = torch.long)
if not exists(image_embed) and not self.unconditional:
assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
@@ -2093,7 +2244,7 @@ class Decoder(BaseGaussianDiffusion):
image = vae.encode(image)
lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion)
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
# main class

View File

@@ -4,7 +4,7 @@ In order to make loading data simple and efficient, we include some general data
### Decoder: Image Embedding Dataset
When training the decoder (and up samplers if training together) in isolation, you will need to load images and corresponding image embeddings. This dataset can read two similar types of datasets. First, it can read a [webdataset](https://github.com/webdataset/webdataset) that contains `.jpg` and `.npy` files in the `.tar`s that contain the images and associated image embeddings respectively. Alternatively, you can also specify a source for the embeddings outside of the webdataset. In this case, the path to the embeddings should contain `.npy` files with the same shard numbers as the webdataset and there should be a correspondence between the filename of the `.jpg` and the index of the embedding in the `.npy`. So, for example, `0001.tar` from the webdataset with image `00010509.jpg` (the first 4 digits are the shard number and the last 4 are the index) in it should be paralleled by a `img_emb_0001.npy` which contains a NumPy array with the embedding at index 509.
Generating a dataset of this type:
Generating a dataset of this type:
1. Use [img2dataset](https://github.com/rom1504/img2dataset) to generate a webdataset.
2. Use [clip-retrieval](https://github.com/rom1504/clip-retrieval) to convert the images to embeddings.
3. Use [embedding-dataset-reordering](https://github.com/Veldrovive/embedding-dataset-reordering) to reorder the embeddings into the expected format.
@@ -15,7 +15,7 @@ from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embed
# Create a dataloader directly.
dataloader = create_image_embedding_dataloader(
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses braket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses bracket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
num_workers=4,
batch_size=32,
@@ -39,3 +39,37 @@ dataset = ImageEmbeddingDataset(
)
```
### Diffusion Prior: Prior Embedding Dataset
When training the prior it is much more efficient to work with pre-computed embeddings. The `PriorEmbeddingDataset` class enables you to leverage the same script (with minimal modification) for both embedding-only and text-conditioned prior training. This saves you from having to worry about a lot of the boilerplate code.
To utilize the `PriorEmbeddingDataset`, all you need to do is make a single call to `get_reader()` which will create `EmbeddingReader` object(s) for you. Afterwards, you can utilize `make_splits()` to cleanly create DataLoader objects from for your training run.
If you are training in a distributed manner, `make_splits()` accepts `rank` and `world_size` arguments to properly distribute to each process. The defaults for these values are `rank=0` and `world_size=1`, so single-process training can safely ignore these parameters.
Usage:
```python
from dalle2_pytorch.dataloaders import get_reader, make_splits
# grab embeddings from some specified location
IMG_URL = "data/img_emb/"
META_URL = "data/meta/"
reader = get_reader(text_conditioned=True, img_url=IMG_URL, meta_url=META_URL)
# some config for training
TRAIN_ARGS = {
"world_size": 3,
"text_conditioned": True,
"start": 0,
"num_data_points": 10000,
"batch_size": 2,
"train_split": 0.5,
"eval_split": 0.25,
"image_reader": reader,
}
# specifying a rank will handle allocation internally
rank0_train, rank0_eval, rank0_test = make_splits(rank=0, **TRAIN_ARGS)
rank1_train, rank1_eval, rank1_test = make_splits(rank=1, **TRAIN_ARGS)
rank2_train, rank2_eval, rank2_test = make_splits(rank=2, **TRAIN_ARGS)
```

View File

@@ -1,2 +1,2 @@
from dalle2_pytorch.dataloaders.decoder_loader import ImageEmbeddingDataset, create_image_embedding_dataloader
from dalle2_pytorch.dataloaders.embedding_wrapper import make_splits
from dalle2_pytorch.dataloaders.prior_loader import make_splits, get_reader, PriorEmbeddingDataset

View File

@@ -164,9 +164,6 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
# There may be webdataset shards that do not have a embedding shard associated with it. If we do not skip these, they would cause issues.
self.append(skip_unassociated_shards(embeddings_url=embedding_folder_url, handler=handler))
self.append(wds.split_by_node)
self.append(wds.split_by_worker)
self.append(wds.tarfile_to_samples(handler=handler))
self.append(wds.decode("pilrgb", handler=handler))
if embedding_folder_url is not None:

View File

@@ -1,180 +0,0 @@
from torch.utils.data import IterableDataset
from torch import from_numpy
from clip import tokenize
from embedding_reader import EmbeddingReader
class PriorEmbeddingLoader(IterableDataset):
def __init__(
self,
text_conditioned: bool,
batch_size: int,
start: int,
stop: int,
image_reader,
text_reader: EmbeddingReader = None,
device: str = "cpu",
) -> None:
super(PriorEmbeddingLoader).__init__()
self.text_conditioned = text_conditioned
if not self.text_conditioned:
self.text_reader = text_reader
self.image_reader = image_reader
self.batch_size = batch_size
self.start = start
self.stop = stop
self.device = device
def __iter__(self):
self.n = 0
loader_args = dict(
batch_size=self.batch_size,
start=self.start,
end=self.stop,
show_progress=False,
)
if self.text_conditioned:
self.loader = self.image_reader(**loader_args)
else:
self.loader = zip(
self.image_reader(**loader_args), self.text_reader(**loader_args)
)
return self
def __next__(self):
try:
return self.get_sample()
except StopIteration:
raise StopIteration
def get_sample(self):
"""
pre-proocess data from either reader into a common format
"""
self.n += 1
if self.text_conditioned:
image_embedding, caption = next(self.loader)
image_embedding = from_numpy(image_embedding).to(self.device)
tokenized_caption = tokenize(
caption["caption"].to_list(), truncate=True
).to(self.device)
return image_embedding, tokenized_caption
else:
(image_embedding, _), (text_embedding, _) = next(self.loader)
image_embedding = from_numpy(image_embedding).to(self.device)
text_embedding = from_numpy(text_embedding).to(self.device)
return image_embedding, text_embedding
def make_splits(
text_conditioned: bool,
batch_size: int,
num_data_points: int,
train_split: float,
eval_split: float,
device: str,
img_url: str,
meta_url: str = None,
txt_url: str = None,
):
assert img_url is not None, "Must supply some image embeddings"
if text_conditioned:
assert meta_url is not None, "Must supply metadata url if text-conditioning"
image_reader = EmbeddingReader(
embeddings_folder=img_url,
file_format="parquet_npy",
meta_columns=["caption"],
metadata_folder=meta_url,
)
# compute split points
if num_data_points > image_reader.count:
print("Specified point count is larger than the number of points available...defaulting to max length of reader.")
num_data_points = image_reader.count
train_set_size = int(train_split * num_data_points)
eval_set_size = int(eval_split * num_data_points)
eval_stop = int(train_set_size + eval_set_size)
train_loader = PriorEmbeddingLoader(
text_conditioned=text_conditioned,
image_reader=image_reader,
batch_size=batch_size,
start=0,
stop=train_set_size,
device=device,
)
eval_loader = PriorEmbeddingLoader(
text_conditioned=text_conditioned,
image_reader=image_reader,
batch_size=batch_size,
start=train_set_size,
stop=eval_stop,
device=device,
)
test_loader = PriorEmbeddingLoader(
text_conditioned=text_conditioned,
image_reader=image_reader,
batch_size=batch_size,
start=eval_stop,
stop=int(num_data_points),
device=device,
)
else:
assert (
txt_url is not None
), "Must supply text embedding url if not text-conditioning"
image_reader = EmbeddingReader(img_url, file_format="npy")
text_reader = EmbeddingReader(txt_url, file_format="npy")
# compute split points
if num_data_points > image_reader.count:
print("Specified point count is larger than the number of points available...defaulting to max length of reader.")
num_data_points = image_reader.count
train_set_size = int(train_split * num_data_points)
eval_set_size = int(eval_split * num_data_points)
eval_stop = int(train_set_size + eval_set_size)
train_loader = PriorEmbeddingLoader(
text_conditioned=text_conditioned,
image_reader=image_reader,
text_reader=text_reader,
batch_size=batch_size,
start=0,
stop=train_set_size,
device=device,
)
eval_loader = PriorEmbeddingLoader(
text_conditioned=text_conditioned,
image_reader=image_reader,
text_reader=text_reader,
batch_size=batch_size,
start=train_set_size,
stop=eval_stop,
device=device,
)
test_loader = PriorEmbeddingLoader(
text_conditioned=text_conditioned,
image_reader=image_reader,
text_reader=text_reader,
batch_size=batch_size,
start=eval_stop,
stop=int(num_data_points),
device=device,
)
return train_loader, eval_loader, test_loader

View File

@@ -0,0 +1,273 @@
from math import ceil
from clip import tokenize
from embedding_reader import EmbeddingReader
from torch import from_numpy
from torch.utils.data import IterableDataset, DataLoader
class PriorEmbeddingDataset(IterableDataset):
"""
PriorEmbeddingDataset is a wrapper of EmbeddingReader.
It enables one to simplify the logic necessary to yield samples from
the different EmbeddingReader configurations available.
"""
def __init__(
self,
text_conditioned: bool,
batch_size: int,
start: int,
stop: int,
image_reader,
text_reader: EmbeddingReader = None,
) -> None:
super(PriorEmbeddingDataset).__init__()
self.text_conditioned = text_conditioned
if not self.text_conditioned:
self.text_reader = text_reader
self.image_reader = image_reader
self.start = start
self.stop = stop
self.batch_size = batch_size
def __len__(self):
return self.stop - self.start
def __iter__(self):
# D.R.Y loader args
loader_args = dict(
batch_size=self.batch_size,
start=self.start,
end=self.stop,
show_progress=False,
)
# if the data requested is text conditioned, only load images
if self.text_conditioned:
self.loader = self.image_reader(**loader_args)
# otherwise, include text embeddings and bypass metadata
else:
self.loader = zip(
self.image_reader(**loader_args), self.text_reader(**loader_args)
)
# return the data loader in its formatted state
return self
def __next__(self):
try:
return self.get_sample()
except StopIteration:
raise StopIteration
def __str__(self):
return f"<PriorEmbeddingDataset: start: {self.start}, stop: {self.stop}, len: {self.__len__()}>"
def get_sample(self):
"""
pre-proocess data from either reader into a common format
"""
if self.text_conditioned:
image_embedding, caption = next(self.loader)
image_embedding = from_numpy(image_embedding)
tokenized_caption = tokenize(caption["caption"].to_list(), truncate=True)
return image_embedding, tokenized_caption
else:
(image_embedding, _), (text_embedding, _) = next(self.loader)
image_embedding = from_numpy(image_embedding)
text_embedding = from_numpy(text_embedding)
return image_embedding, text_embedding
# helper functions
def distribute_to_rank(start, stop, rank, world_size):
"""
Distribute data to each rank given the world size.
Return:
- New start and stop points for this rank.
"""
num_samples = int(stop - start)
per_rank = int(ceil((num_samples) / float(world_size)))
assert (
per_rank > 0
), f"Number of samples per rank must be larger than 0, (found: {per_rank})"
rank_start = start + rank * per_rank
rank_stop = min(rank_start + per_rank, stop)
new_length = rank_stop - rank_start
assert (
new_length > 0
), "Calculated start and stop points result in a length of zero for this rank."
return rank_start, rank_stop
def get_reader(
text_conditioned: bool, img_url: str, meta_url: str = None, txt_url: str = None
):
"""
Create an EmbeddingReader object from the specified URLs
get_reader() will always expect a url to image embeddings.
If text-conditioned, it will also expect a meta_url for the captions.
Otherwise, it will need txt_url for the matching text embeddings.
Returns an image_reader object if text-conditioned.
Otherwise it returns both an image_reader and a text_reader
"""
assert img_url is not None, "Must supply a image url"
if text_conditioned:
assert meta_url is not None, "Must supply meta url if text-conditioned"
image_reader = EmbeddingReader(
embeddings_folder=img_url,
file_format="parquet_npy",
# will assume the caption column exists and is the only one requested
meta_columns=["caption"],
metadata_folder=meta_url,
)
return image_reader
# otherwise we will require text embeddings as well and return two readers
assert (
txt_url is not None
), "Must supply text embedding url if not text-conditioning"
image_reader = EmbeddingReader(img_url, file_format="npy")
text_reader = EmbeddingReader(txt_url, file_format="npy")
return image_reader, text_reader
def make_splits(
text_conditioned: bool,
batch_size: int,
num_data_points: int,
train_split: float,
eval_split: float,
image_reader: EmbeddingReader,
text_reader: EmbeddingReader = None,
start=0,
rank=0,
world_size=1,
):
"""
Split an embedding reader object as needed.
NOTE: make_splits() will infer the test set size from your train and eval.
Input:
- text_conditioned: whether to prepare text-conditioned training data
- batch_size: the batch size for a single gpu
- num_data_points: the total number of data points you wish to train on
- train_split: the percentage of data you wish to train on
- eval_split: the percentage of data you wish to validate on
- image_reader: the image_reader you wish to split
- text_reader: the text_reader you want to split (if !text_conditioned)
- start: the starting point within your dataset
- rank: the rank of your worker
- world_size: the total world size of your distributed training run
Returns:
- PyTorch Dataloaders that yield tuples of (img, txt) data.
"""
assert start < image_reader.count, "start position cannot exceed reader count."
# verify that the num_data_points does not exceed the max points
if num_data_points > (image_reader.count - start):
print(
"Specified count is larger than what's available...defaulting to reader's count."
)
num_data_points = image_reader.count
# compute split points
train_set_size = int(train_split * num_data_points)
eval_set_size = int(eval_split * num_data_points)
eval_start = train_set_size
eval_stop = int(eval_start + eval_set_size)
assert (
train_split + eval_split
) < 1.0, "Specified train and eval split is too large to infer a test split."
# distribute to rank
rank_train_start, rank_train_stop = distribute_to_rank(
start, train_set_size, rank, world_size
)
rank_eval_start, rank_eval_stop = distribute_to_rank(
train_set_size, eval_stop, rank, world_size
)
rank_test_start, rank_test_stop = distribute_to_rank(
eval_stop, num_data_points, rank, world_size
)
# wrap up splits into a dict
train_split_args = dict(
start=rank_train_start, stop=rank_train_stop, batch_size=batch_size
)
eval_split_args = dict(
start=rank_eval_start, stop=rank_eval_stop, batch_size=batch_size
)
test_split_args = dict(
start=rank_test_start, stop=rank_test_stop, batch_size=batch_size
)
if text_conditioned:
# add the text-conditioned args to a unified dict
reader_args = dict(
text_conditioned=text_conditioned,
image_reader=image_reader,
)
train_split_args = dict(**reader_args, **train_split_args)
eval_split_args = dict(**reader_args, **eval_split_args)
test_split_args = dict(**reader_args, **test_split_args)
train = PriorEmbeddingDataset(**train_split_args)
val = PriorEmbeddingDataset(**eval_split_args)
test = PriorEmbeddingDataset(**test_split_args)
else:
# add the non-conditioned args to a unified dict
reader_args = dict(
text_conditioned=text_conditioned,
image_reader=image_reader,
text_reader=text_reader,
)
train_split_args = dict(**reader_args, **train_split_args)
eval_split_args = dict(**reader_args, **eval_split_args)
test_split_args = dict(**reader_args, **test_split_args)
train = PriorEmbeddingDataset(**train_split_args)
val = PriorEmbeddingDataset(**eval_split_args)
test = PriorEmbeddingDataset(**test_split_args)
# true batch size is specifed in the PriorEmbeddingDataset
train_loader = DataLoader(train, batch_size=None)
eval_loader = DataLoader(val, batch_size=None)
test_loader = DataLoader(test, batch_size=None)
return train_loader, eval_loader, test_loader

View File

@@ -1,17 +1,20 @@
from torch.optim import AdamW, Adam
def separate_weight_decayable_params(params):
no_wd_params = set([param for param in params if param.ndim < 2])
wd_params = set(params) - no_wd_params
wd_params, no_wd_params = [], []
for param in params:
param_list = no_wd_params if param.ndim < 2 else wd_params
param_list.append(param)
return wd_params, no_wd_params
def get_optimizer(
params,
lr = 1e-4,
wd = 1e-2,
betas = (0.9, 0.999),
betas = (0.9, 0.99),
eps = 1e-8,
filter_by_requires_grad = False,
group_wd_params = True,
**kwargs
):
if filter_by_requires_grad:
@@ -20,12 +23,12 @@ def get_optimizer(
if wd == 0:
return Adam(params, lr = lr, betas = betas, eps = eps)
params = set(params)
wd_params, no_wd_params = separate_weight_decayable_params(params)
if group_wd_params:
wd_params, no_wd_params = separate_weight_decayable_params(params)
param_groups = [
{'params': list(wd_params)},
{'params': list(no_wd_params), 'weight_decay': 0},
]
params = [
{'params': wd_params},
{'params': no_wd_params, 'weight_decay': 0},
]
return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas, eps = eps)
return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)

View File

@@ -2,7 +2,6 @@
# to give users a quick easy start to training DALL-E without doing BPE
import torch
import youtokentome as yttm
import html
import os
@@ -11,6 +10,8 @@ import regex as re
from functools import lru_cache
from pathlib import Path
from dalle2_pytorch.utils import import_or_print_error
# OpenAI simple tokenizer
@lru_cache()
@@ -156,7 +157,9 @@ class YttmTokenizer:
bpe_path = Path(bpe_path)
assert bpe_path.exists(), f'BPE json path {str(bpe_path)} does not exist'
tokenizer = yttm.BPE(model = str(bpe_path))
self.yttm = import_or_print_error('youtokentome', 'you need to install youtokentome by `pip install youtokentome`')
tokenizer = self.yttm.BPE(model = str(bpe_path))
self.tokenizer = tokenizer
self.vocab_size = tokenizer.vocab_size()
@@ -167,7 +170,7 @@ class YttmTokenizer:
return self.tokenizer.decode(tokens, ignore_ids = pad_tokens.union({0}))
def encode(self, texts):
encoded = self.tokenizer.encode(texts, output_type = yttm.OutputType.ID)
encoded = self.tokenizer.encode(texts, output_type = self.yttm.OutputType.ID)
return list(map(torch.tensor, encoded))
def tokenize(self, texts, context_length = 256, truncate_text = False):

View File

@@ -6,6 +6,8 @@ from itertools import zip_longest
import torch
from torch import nn
from dalle2_pytorch.utils import import_or_print_error
# constants
DEFAULT_DATA_PATH = './.tracker-data'
@@ -15,23 +17,15 @@ DEFAULT_DATA_PATH = './.tracker-data'
def exists(val):
return val is not None
def import_or_print_error(pkg_name, err_str = None):
try:
return importlib.import_module(pkg_name)
except ModuleNotFoundError as e:
if exists(err_str):
print(err_str)
exit()
# load file functions
# load state dict functions
def load_wandb_state_dict(run_path, file_path, **kwargs):
def load_wandb_file(run_path, file_path, **kwargs):
wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function')
file_reference = wandb.restore(file_path, run_path=run_path)
return torch.load(file_reference.name)
return file_reference.name
def load_local_state_dict(file_path, **kwargs):
return torch.load(file_path)
def load_local_file(file_path, **kwargs):
return file_path
# base class
@@ -61,12 +55,43 @@ class BaseTracker(nn.Module):
"""
# TODO: Pull this into a dict or something similar so that we can add more sources without having a massive switch statement
if recall_source == 'wandb':
return load_wandb_state_dict(*args, **kwargs)
return torch.load(load_wandb_file(*args, **kwargs))
elif recall_source == 'local':
return load_local_state_dict(*args, **kwargs)
return torch.load(load_local_file(*args, **kwargs))
else:
raise ValueError('`recall_source` must be one of `wandb` or `local`')
def save_file(self, file_path, **kwargs):
raise NotImplementedError
def recall_file(self, recall_source, *args, **kwargs):
if recall_source == 'wandb':
return load_wandb_file(*args, **kwargs)
elif recall_source == 'local':
return load_local_file(*args, **kwargs)
else:
raise ValueError('`recall_source` must be one of `wandb` or `local`')
# Tracker that no-ops all calls except for recall
class DummyTracker(BaseTracker):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def init(self, config, **kwargs):
pass
def log(self, log, **kwargs):
pass
def log_images(self, images, **kwargs):
pass
def save_state_dict(self, state_dict, relative_path, **kwargs):
pass
def save_file(self, file_path, **kwargs):
pass
# basic stdout class
@@ -82,6 +107,10 @@ class ConsoleTracker(BaseTracker):
def save_state_dict(self, state_dict, relative_path, **kwargs):
torch.save(state_dict, str(self.data_path / relative_path))
def save_file(self, file_path, **kwargs):
# This is a no-op for local file systems since it is already saved locally
pass
# basic wandb class
@@ -113,3 +142,11 @@ class WandbTracker(BaseTracker):
full_path = str(self.data_path / relative_path)
torch.save(state_dict, full_path)
self.wandb.save(full_path, base_path = str(self.data_path)) # Upload and keep relative to data_path
def save_file(self, file_path, base_path=None, **kwargs):
"""
Uploads a file from disk to wandb
"""
if base_path is None:
base_path = self.data_path
self.wandb.save(str(file_path), base_path = str(base_path))

View File

@@ -3,15 +3,160 @@ from torchvision import transforms as T
from pydantic import BaseModel, validator, root_validator
from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
from x_clip import CLIP as XCLIP
from coca_pytorch import CoCa
from dalle2_pytorch.dalle2_pytorch import (
CoCaAdapter,
OpenAIClipAdapter,
Unet,
Decoder,
DiffusionPrior,
DiffusionPriorNetwork,
XClipAdapter,
)
# helper functions
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def ListOrTuple(inner_type):
return Union[List[inner_type], Tuple[inner_type]]
def SingularOrIterable(inner_type):
return Union[inner_type, ListOrTuple(inner_type)]
# general pydantic classes
class TrainSplitConfig(BaseModel):
train: float = 0.75
val: float = 0.15
test: float = 0.1
@root_validator
def validate_all(cls, fields):
actual_sum = sum([*fields.values()])
if actual_sum != 1.:
raise ValueError(f'{fields.keys()} must sum to 1.0. Found: {actual_sum}')
return fields
class TrackerConfig(BaseModel):
tracker_type: str = 'console' # Decoder currently supports console and wandb
data_path: str = './models' # The path where files will be saved locally
init_config: Dict[str, Any] = None
wandb_entity: str = '' # Only needs to be set if tracker_type is wandb
wandb_project: str = ''
verbose: bool = False # Whether to print console logging for non-console trackers
# diffusion prior pydantic classes
class AdapterConfig(BaseModel):
make: str = "openai"
model: str = "ViT-L/14"
base_model_kwargs: Dict[str, Any] = None
def create(self):
if self.make == "openai":
return OpenAIClipAdapter(self.model)
elif self.make == "x-clip":
return XClipAdapter(XCLIP(**self.base_model_kwargs))
elif self.make == "coca":
return CoCaAdapter(CoCa(**self.base_model_kwargs))
else:
raise AttributeError("No adapter with that name is available.")
class DiffusionPriorNetworkConfig(BaseModel):
dim: int
depth: int
num_timesteps: int = None
num_time_embeds: int = 1
num_image_embeds: int = 1
num_text_embeds: int = 1
dim_head: int = 64
heads: int = 8
ff_mult: int = 4
norm_out: bool = True
attn_dropout: float = 0.
ff_dropout: float = 0.
final_proj: bool = True
normformer: bool = False
rotary_emb: bool = True
def create(self):
kwargs = self.dict()
return DiffusionPriorNetwork(**kwargs)
class DiffusionPriorConfig(BaseModel):
clip: AdapterConfig = None
net: DiffusionPriorNetworkConfig
image_embed_dim: int
image_size: int
image_channels: int = 3
timesteps: int = 1000
cond_drop_prob: float = 0.
loss_type: str = 'l2'
predict_x_start: bool = True
beta_schedule: str = 'cosine'
condition_on_text_encodings: bool = True
class Config:
extra = "allow"
def create(self):
kwargs = self.dict()
has_clip = exists(kwargs.pop('clip'))
kwargs.pop('net')
clip = None
if has_clip:
clip = self.clip.create()
diffusion_prior_network = self.net.create()
return DiffusionPrior(net = diffusion_prior_network, clip = clip, **kwargs)
class DiffusionPriorTrainConfig(BaseModel):
epochs: int = 1
lr: float = 1.1e-4
wd: float = 6.02e-2
max_grad_norm: float = 0.5
use_ema: bool = True
ema_beta: float = 0.99
amp: bool = False
save_every: int = 10000 # what steps to save on
class DiffusionPriorDataConfig(BaseModel):
image_url: str # path to embeddings folder
meta_url: str # path to metadata (captions) for images
splits: TrainSplitConfig
batch_size: int = 64
class DiffusionPriorLoadConfig(BaseModel):
source: str = None
resume: bool = False
class TrainDiffusionPriorConfig(BaseModel):
prior: DiffusionPriorConfig
data: DiffusionPriorDataConfig
train: DiffusionPriorTrainConfig
load: DiffusionPriorLoadConfig
tracker: TrackerConfig
@classmethod
def from_json_path(cls, json_path):
with open(json_path) as f:
config = json.load(f)
return cls(**config)
# decoder pydantic classes
class UnetConfig(BaseModel):
dim: int
dim_mults: List[int]
dim_mults: ListOrTuple(int)
image_embed_dim: int = None
cond_dim: int = None
channels: int = 3
@@ -22,13 +167,22 @@ class UnetConfig(BaseModel):
extra = "allow"
class DecoderConfig(BaseModel):
unets: ListOrTuple(UnetConfig)
image_size: int = None
image_sizes: Union[List[int], Tuple[int]] = None
image_sizes: ListOrTuple(int) = None
channels: int = 3
timesteps: int = 1000
loss_type: str = 'l2'
beta_schedule: str = 'cosine'
beta_schedule: ListOrTuple(str) = 'cosine'
learned_variance: bool = True
image_cond_drop_prob: float = 0.1
text_cond_drop_prob: float = 0.5
def create(self):
decoder_kwargs = self.dict()
unet_configs = decoder_kwargs.pop('unets')
unets = [Unet(**config) for config in unet_configs]
return Decoder(unets, **decoder_kwargs)
@validator('image_sizes')
def check_image_sizes(cls, image_sizes, values):
@@ -39,17 +193,6 @@ class DecoderConfig(BaseModel):
class Config:
extra = "allow"
class TrainSplitConfig(BaseModel):
train: float = 0.75
val: float = 0.15
test: float = 0.1
@root_validator
def validate_all(cls, fields):
if sum([*fields.values()]) != 1.:
raise ValueError(f'{fields.keys()} must sum to 1.0')
return fields
class DecoderDataConfig(BaseModel):
webdataset_base_url: str # path to a webdataset with jpg images
embeddings_url: str # path to .npy files with embeddings
@@ -82,21 +225,21 @@ class DecoderDataConfig(BaseModel):
class DecoderTrainConfig(BaseModel):
epochs: int = 20
lr: float = 1e-4
wd: float = 0.01
max_grad_norm: float = 0.5
lr: SingularOrIterable(float) = 1e-4
wd: SingularOrIterable(float) = 0.01
max_grad_norm: SingularOrIterable(float) = 0.5
save_every_n_samples: int = 100000
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
device: str = 'cuda:0'
epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
validation_samples: int = None # Same as above but for validation.
epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
validation_samples: int = None # Same as above but for validation.
use_ema: bool = True
ema_beta: float = 0.99
ema_beta: float = 0.999
amp: bool = False
save_all: bool = False # Whether to preserve all checkpoints
save_latest: bool = True # Whether to always save the latest checkpoint
save_best: bool = True # Whether to save the best checkpoint
unet_training_mask: List[bool] = None # If None, use all unets
save_all: bool = False # Whether to preserve all checkpoints
save_latest: bool = True # Whether to always save the latest checkpoint
save_best: bool = True # Whether to save the best checkpoint
unet_training_mask: ListOrTuple(bool) = None # If None, use all unets
class DecoderEvaluateConfig(BaseModel):
n_evaluation_samples: int = 1000
@@ -105,14 +248,6 @@ class DecoderEvaluateConfig(BaseModel):
KID: Dict[str, Any] = None
LPIPS: Dict[str, Any] = None
class TrackerConfig(BaseModel):
tracker_type: str = 'console' # Decoder currently supports console and wandb
data_path: str = './models' # The path where files will be saved locally
init_config: Dict[str, Any] = None
wandb_entity: str = '' # Only needs to be set if tracker_type is wandb
wandb_project: str = ''
verbose: bool = False # Whether to print console logging for non-console trackers
class DecoderLoadConfig(BaseModel):
source: str = None # Supports file and wandb
run_path: str = '' # Used only if source is wandb
@@ -120,13 +255,13 @@ class DecoderLoadConfig(BaseModel):
resume: bool = False # If using wandb, whether to resume the run
class TrainDecoderConfig(BaseModel):
unets: List[UnetConfig]
decoder: DecoderConfig
data: DecoderDataConfig
train: DecoderTrainConfig
evaluate: DecoderEvaluateConfig
tracker: TrackerConfig
load: DecoderLoadConfig
seed: int = 0
@classmethod
def from_json_path(cls, json_path):

View File

@@ -11,6 +11,12 @@ from torch.cuda.amp import autocast, GradScaler
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
from dalle2_pytorch.optimizer import get_optimizer
from dalle2_pytorch.version import __version__
from packaging import version
from ema_pytorch import EMA
from accelerate import Accelerator
import numpy as np
@@ -20,7 +26,9 @@ def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
if exists(val):
return val
return d() if callable(d) else d
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
@@ -56,10 +64,6 @@ def num_to_groups(num, divisor):
arr.append(remainder)
return arr
def get_pkg_version():
from pkg_resources import get_distribution
return get_distribution('dalle2_pytorch').version
# decorators
def cast_torch_tensor(fn):
@@ -133,105 +137,6 @@ def split_args_and_kwargs(*args, split_size = None, **kwargs):
chunk_size_frac = chunk_size / batch_size
yield chunk_size_frac, (chunked_args, chunked_kwargs)
# saving and loading functions
# for diffusion prior
def load_diffusion_model(dprior_path, device):
dprior_path = Path(dprior_path)
assert dprior_path.exists(), 'Dprior model file does not exist'
loaded_obj = torch.load(str(dprior_path), map_location='cpu')
# Get hyperparameters of loaded model
dpn_config = loaded_obj['hparams']['diffusion_prior_network']
dp_config = loaded_obj['hparams']['diffusion_prior']
image_embed_dim = loaded_obj['image_embed_dim']['image_embed_dim']
# Create DiffusionPriorNetwork and DiffusionPrior with loaded hyperparameters
# DiffusionPriorNetwork
prior_network = DiffusionPriorNetwork( dim = image_embed_dim, **dpn_config).to(device)
# DiffusionPrior with text embeddings and image embeddings pre-computed
diffusion_prior = DiffusionPrior(net = prior_network, **dp_config, image_embed_dim = image_embed_dim).to(device)
# Load state dict from saved model
diffusion_prior.load_state_dict(loaded_obj['model'])
return diffusion_prior, loaded_obj
def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim):
# Saving State Dict
print_ribbon('Saving checkpoint')
state_dict = dict(model=model.state_dict(),
optimizer=optimizer.state_dict(),
scaler=scaler.state_dict(),
hparams = config,
image_embed_dim = {"image_embed_dim":image_embed_dim})
torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
# exponential moving average wrapper
class EMA(nn.Module):
def __init__(
self,
model,
beta = 0.9999,
update_after_step = 1000,
update_every = 10,
):
super().__init__()
self.beta = beta
self.online_model = model
self.ema_model = copy.deepcopy(model)
self.update_every = update_every
self.update_after_step = update_after_step // update_every # only start EMA after this step number, starting at 0
self.register_buffer('initted', torch.Tensor([False]))
self.register_buffer('step', torch.tensor([0]))
def restore_ema_model_device(self):
device = self.initted.device
self.ema_model.to(device)
def copy_params_from_model_to_ema(self):
self.ema_model.state_dict(self.online_model.state_dict())
def update(self):
self.step += 1
if (self.step % self.update_every) != 0:
return
if self.step <= self.update_after_step:
self.copy_params_from_model_to_ema()
return
if not self.initted:
self.copy_params_from_model_to_ema()
self.initted.data.copy_(torch.Tensor([True]))
self.update_moving_average(self.ema_model, self.online_model)
def update_moving_average(self, ma_model, current_model):
def calculate_ema(beta, old, new):
if not exists(old):
return new
return old * beta + (1 - beta) * new
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
old_weight, up_weight = ma_params.data, current_params.data
ma_params.data = calculate_ema(self.beta, old_weight, up_weight)
for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer)
ma_buffer.copy_(new_buffer_value)
def __call__(self, *args, **kwargs):
return self.ema_model(*args, **kwargs)
# diffusion prior trainer
def prior_sample_in_chunks(fn):
@@ -254,84 +159,190 @@ class DiffusionPriorTrainer(nn.Module):
eps = 1e-6,
max_grad_norm = None,
amp = False,
group_wd_params = True,
device = None,
accelerator = None,
**kwargs
):
super().__init__()
assert isinstance(diffusion_prior, DiffusionPrior)
assert not exists(accelerator) or isinstance(accelerator, Accelerator)
assert exists(accelerator) or exists(device), "You must supply some method of obtaining a device."
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
# assign some helpful member vars
self.accelerator = accelerator
self.device = accelerator.device if exists(accelerator) else device
self.text_conditioned = diffusion_prior.condition_on_text_encodings
# save model
self.diffusion_prior = diffusion_prior
# exponential moving average
self.use_ema = use_ema
if self.use_ema:
self.ema_diffusion_prior = EMA(diffusion_prior, **ema_kwargs)
# optimizer and mixed precision stuff
self.amp = amp
self.scaler = GradScaler(enabled = amp)
self.optim_kwargs = dict(lr=lr, wd=wd, eps=eps, group_wd_params=group_wd_params)
self.optimizer = get_optimizer(
diffusion_prior.parameters(),
lr = lr,
wd = wd,
eps = eps,
self.diffusion_prior.parameters(),
**self.optim_kwargs,
**kwargs
)
# distribute the model if using HFA
if exists(self.accelerator):
self.diffusion_prior, self.optimizer = self.accelerator.prepare(self.diffusion_prior, self.optimizer)
# exponential moving average stuff
self.use_ema = use_ema
if self.use_ema:
self.ema_diffusion_prior = EMA(self.unwrap_model(self.diffusion_prior), **ema_kwargs)
# gradient clipping if needed
self.max_grad_norm = max_grad_norm
# track steps internally
self.register_buffer('step', torch.tensor([0]))
def save(self, path, overwrite = True):
path = Path(path)
assert not (path.exists() and not overwrite)
path.parent.mkdir(parents = True, exist_ok = True)
# accelerator wrappers
save_obj = dict(
scaler = self.scaler.state_dict(),
optimizer = self.optimizer.state_dict(),
model = self.diffusion_prior.state_dict(),
version = get_pkg_version(),
step = self.step.item()
)
def print(self, msg):
if exists(self.accelerator):
self.accelerator.print(msg)
else:
print(msg)
if self.use_ema:
save_obj = {**save_obj, 'ema': self.ema_diffusion_prior.state_dict()}
def unwrap_model(self, model):
if exists(self.accelerator):
return self.accelerator.unwrap_model(model)
else:
return model
torch.save(save_obj, str(path))
def wait_for_everyone(self):
if exists(self.accelerator):
self.accelerator.wait_for_everyone()
def load(self, path, only_model = False, strict = True):
def is_main_process(self):
if exists(self.accelerator):
return self.accelerator.is_main_process
else:
return True
def clip_grad_norm_(self, *args):
if exists(self.accelerator):
return self.accelerator.clip_grad_norm_(*args)
else:
return torch.nn.utils.clip_grad_norm_(*args)
def backprop(self, x):
if exists(self.accelerator):
self.accelerator.backward(x)
else:
try:
x.backward()
except Exception as e:
self.print(f"Caught error in backprop call: {e}")
# utility
def save(self, path, overwrite = True, **kwargs):
# ensure we sync gradients before continuing
self.wait_for_everyone()
# only save on the main process
if self.is_main_process():
self.print(f"Saving checkpoint at step: {self.step.item()}")
path = Path(path)
assert not (path.exists() and not overwrite)
path.parent.mkdir(parents = True, exist_ok = True)
save_obj = dict(
scaler = self.scaler.state_dict(),
optimizer = self.optimizer.state_dict(),
model = self.unwrap_model(self.diffusion_prior).state_dict(), # unwrap the model from distribution if applicable
version = version.parse(__version__),
step = self.step.item(),
**kwargs
)
if self.use_ema:
save_obj = {
**save_obj,
'ema': self.ema_diffusion_prior.state_dict(),
'ema_model': self.ema_diffusion_prior.ema_model.state_dict() # save the ema model specifically for easy ema-only reload
}
torch.save(save_obj, str(path))
def load(self, path, overwrite_lr = True, strict = True):
"""
Load a checkpoint of a diffusion prior trainer.
Will load the entire trainer, including the optimizer and EMA.
Params:
- path (str): a path to the DiffusionPriorTrainer checkpoint file
- overwrite_lr (bool): wether or not to overwrite the stored LR with the LR specified in the new trainer
- strict (bool): kwarg for `torch.nn.Module.load_state_dict`, will force an exact checkpoint match
Returns:
loaded_obj (dict): The loaded checkpoint dictionary
"""
# all processes need to load checkpoint. no restriction here
path = Path(path)
assert path.exists()
loaded_obj = torch.load(str(path))
loaded_obj = torch.load(str(path), map_location=self.device)
if get_pkg_version() != loaded_obj['version']:
print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {get_pkg_version()}')
if version.parse(__version__) != loaded_obj['version']:
print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {__version__}')
self.diffusion_prior.load_state_dict(loaded_obj['model'], strict = strict)
# unwrap the model when loading from checkpoint
self.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
if only_model:
return
self.scaler.load_state_dict(loaded_obj['scaler'])
self.optimizer.load_state_dict(loaded_obj['optimizer'])
if overwrite_lr:
new_lr = self.optim_kwargs["lr"]
self.print(f"Overriding LR to be {new_lr}")
for group in self.optimizer.param_groups:
group["lr"] = new_lr
if self.use_ema:
assert 'ema' in loaded_obj
self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
# below not be necessary, but I had a suspicion that this wasn't being loaded correctly
self.ema_diffusion_prior.ema_model.load_state_dict(loaded_obj["ema_model"])
# sync and inform
self.wait_for_everyone()
self.print(f"Loaded model")
return loaded_obj
# model functionality
def update(self):
# only continue with updates until all ranks finish
self.wait_for_everyone()
if exists(self.max_grad_norm):
self.scaler.unscale_(self.optimizer)
nn.utils.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
# utilize HFA clipping where applicable
self.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
self.scaler.step(self.optimizer)
self.scaler.update()
@@ -346,17 +357,26 @@ class DiffusionPriorTrainer(nn.Module):
@cast_torch_tensor
@prior_sample_in_chunks
def p_sample_loop(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
return model.p_sample_loop(*args, **kwargs)
@torch.no_grad()
@cast_torch_tensor
@prior_sample_in_chunks
def sample(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
return model.sample(*args, **kwargs)
@torch.no_grad()
def sample_batch_size(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
return model.sample_batch_size(*args, **kwargs)
@torch.no_grad()
@cast_torch_tensor
@prior_sample_in_chunks
def embed_text(self, *args, **kwargs):
return self.unwrap_model(self.diffusion_prior).clip.embed_text(*args, **kwargs)
@cast_torch_tensor
def forward(
@@ -374,8 +394,10 @@ class DiffusionPriorTrainer(nn.Module):
total_loss += loss.item()
# backprop with accelerate if applicable
if self.training:
self.scaler.scale(loss).backward()
self.backprop(self.scaler.scale(loss))
return total_loss
@@ -401,20 +423,23 @@ class DecoderTrainer(nn.Module):
def __init__(
self,
decoder,
accelerator = None,
use_ema = True,
lr = 1e-4,
wd = 1e-2,
eps = 1e-8,
max_grad_norm = 0.5,
amp = False,
group_wd_params = True,
**kwargs
):
super().__init__()
assert isinstance(decoder, Decoder)
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
self.decoder = decoder
self.num_unets = len(self.decoder.unets)
self.accelerator = default(accelerator, Accelerator)
self.num_unets = len(decoder.unets)
self.use_ema = use_ema
self.ema_unets = nn.ModuleList([])
@@ -426,107 +451,101 @@ class DecoderTrainer(nn.Module):
lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))
for ind, (unet, unet_lr, unet_wd, unet_eps) in enumerate(zip(self.decoder.unets, lr, wd, eps)):
optimizers = []
for unet, unet_lr, unet_wd, unet_eps in zip(decoder.unets, lr, wd, eps):
optimizer = get_optimizer(
unet.parameters(),
lr = unet_lr,
wd = unet_wd,
eps = unet_eps,
group_wd_params = group_wd_params,
**kwargs
)
setattr(self, f'optim{ind}', optimizer) # cannot use pytorch ModuleList for some reason with optimizers
optimizers.append(optimizer)
if self.use_ema:
self.ema_unets.append(EMA(unet, **ema_kwargs))
scaler = GradScaler(enabled = amp)
setattr(self, f'scaler{ind}', scaler)
# gradient clipping if needed
self.max_grad_norm = max_grad_norm
self.register_buffer('step', torch.tensor([0.]))
def save(self, path, overwrite = True):
decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
self.decoder = decoder
for opt_ind, optimizer in zip(range(len(optimizers)), optimizers):
setattr(self, f'optim{opt_ind}', optimizer)
def save(self, path, overwrite = True, **kwargs):
path = Path(path)
assert not (path.exists() and not overwrite)
path.parent.mkdir(parents = True, exist_ok = True)
save_obj = dict(
model = self.decoder.state_dict(),
version = get_pkg_version(),
step = self.step.item()
model = self.accelerator.unwrap_model(self.decoder).state_dict(),
version = __version__,
step = self.step.item(),
**kwargs
)
for ind in range(0, self.num_unets):
scaler_key = f'scaler{ind}'
optimizer_key = f'scaler{ind}'
scaler = getattr(self, scaler_key)
optimizer_key = f'optim{ind}'
optimizer = getattr(self, optimizer_key)
save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()}
save_obj = {**save_obj, optimizer_key: self.accelerator.unwrap_model(optimizer).state_dict()}
if self.use_ema:
save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
torch.save(save_obj, str(path))
self.accelerator.save(save_obj, str(path))
def load(self, path, only_model = False, strict = True):
path = Path(path)
assert path.exists()
loaded_obj = torch.load(str(path))
loaded_obj = torch.load(str(path), map_location = 'cpu')
if get_pkg_version() != loaded_obj['version']:
print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {get_pkg_version()}')
if version.parse(__version__) != version.parse(loaded_obj['version']):
self.accelerator.print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}')
self.decoder.load_state_dict(loaded_obj['model'], strict = strict)
self.accelerator.unwrap_model(self.decoder).load_state_dict(loaded_obj['model'], strict = strict)
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
if only_model:
return
return loaded_obj
for ind in range(0, self.num_unets):
scaler_key = f'scaler{ind}'
optimizer_key = f'scaler{ind}'
scaler = getattr(self, scaler_key)
optimizer_key = f'optim{ind}'
optimizer = getattr(self, optimizer_key)
scaler.load_state_dict(loaded_obj[scaler_key])
optimizer.load_state_dict(loaded_obj[optimizer_key])
self.accelerator.unwrap_model(optimizer).load_state_dict(loaded_obj[optimizer_key])
if self.use_ema:
assert 'ema' in loaded_obj
self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
return loaded_obj
@property
def unets(self):
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
def scale(self, loss, *, unet_number):
assert 1 <= unet_number <= self.num_unets
index = unet_number - 1
scaler = getattr(self, f'scaler{index}')
return scaler.scale(loss)
def update(self, unet_number = None):
if self.num_unets == 1:
unet_number = default(unet_number, 1)
assert exists(unet_number) and 1 <= unet_number <= self.num_unets
index = unet_number - 1
unet = self.decoder.unets[index]
optimizer = getattr(self, f'optim{index}')
scaler = getattr(self, f'scaler{index}')
if exists(self.max_grad_norm):
scaler.unscale_(optimizer)
nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
scaler.step(optimizer)
scaler.update()
self.accelerator.clip_grad_norm_(self.decoder.parameters(), self.max_grad_norm) # Automatically unscales gradients
optimizer.step()
optimizer.zero_grad()
if self.use_ema:
@@ -539,15 +558,17 @@ class DecoderTrainer(nn.Module):
@cast_torch_tensor
@decoder_sample_in_chunks
def sample(self, *args, **kwargs):
distributed = self.accelerator.num_processes > 1
base_decoder = self.accelerator.unwrap_model(self.decoder)
if kwargs.pop('use_non_ema', False) or not self.use_ema:
return self.decoder.sample(*args, **kwargs)
return base_decoder.sample(*args, **kwargs, distributed = distributed)
trainable_unets = self.decoder.unets
self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
trainable_unets = self.accelerator.unwrap_model(self.decoder).unets
base_decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
output = self.decoder.sample(*args, **kwargs)
output = base_decoder.sample(*args, **kwargs, distributed = distributed)
self.decoder.unets = trainable_unets # restore original training unets
base_decoder.unets = trainable_unets # restore original training unets
# cast the ema_model unets back to original device
for ema in self.ema_unets:
@@ -569,13 +590,14 @@ class DecoderTrainer(nn.Module):
total_loss = 0.
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
with autocast(enabled = self.amp):
# with autocast(enabled = self.amp):
with self.accelerator.autocast():
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
loss = loss * chunk_size_frac
total_loss += loss.item()
if self.training:
self.scale(loss, unet_number = unet_number).backward()
self.accelerator.backward(loss)
return total_loss

View File

@@ -1,4 +1,5 @@
import time
import importlib
# time helpers
@@ -17,3 +18,13 @@ class Timer:
def print_ribbon(s, symbol = '=', repeat = 40):
flank = symbol * repeat
return f'{flank} {s} {flank}'
# import helpers
def import_or_print_error(pkg_name, err_str = None):
try:
return importlib.import_module(pkg_name)
except ModuleNotFoundError as e:
if exists(err_str):
print(err_str)
exit()

View File

@@ -0,0 +1 @@
__version__ = '0.11.4'

View File

@@ -68,8 +68,8 @@ def group_dict_by_key(cond, d):
return_val[ind][key] = d[key]
return (*return_val,)
def string_begins_with(prefix, str):
return str.startswith(prefix)
def string_begins_with(prefix, string_input):
return string_input.startswith(prefix)
def group_by_key_prefix(prefix, d):
return group_dict_by_key(partial(string_begins_with, prefix), d)

View File

@@ -16,10 +16,11 @@ from torchvision.utils import make_grid, save_image
from einops import rearrange
from dalle2_pytorch.train import EMA
from dalle2_pytorch.vqgan_vae import VQGanVAE
from dalle2_pytorch.optimizer import get_optimizer
from ema_pytorch import EMA
# helpers
def exists(val):
@@ -97,7 +98,7 @@ class VQGanVAETrainer(nn.Module):
valid_frac = 0.05,
random_split_seed = 42,
ema_beta = 0.995,
ema_update_after_step = 2000,
ema_update_after_step = 500,
ema_update_every = 10,
apply_grad_penalty_every = 4,
amp = False

View File

@@ -1,4 +1,5 @@
from setuptools import setup, find_packages
exec(open('dalle2_pytorch/version.py').read())
setup(
name = 'dalle2-pytorch',
@@ -10,7 +11,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.4.7',
version = __version__,
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
@@ -23,14 +24,17 @@ setup(
'text to image'
],
install_requires=[
'accelerate',
'click',
'clip-anytorch',
'coca-pytorch>=0.0.5',
'ema-pytorch>=0.0.7',
'einops>=0.4',
'einops-exts>=0.0.3',
'embedding-reader',
'kornia>=0.5.4',
'numpy',
'packaging',
'pillow',
'pydantic',
'resize-right>=0.0.2',
@@ -40,7 +44,6 @@ setup(
'tqdm',
'vector-quantize-pytorch',
'x-clip>=0.4.4',
'youtokentome',
'webdataset>=0.2.5',
'fsspec>=2022.1.0',
'torchmetrics[image]>=0.8.0'

View File

@@ -1,9 +1,11 @@
from dalle2_pytorch import Unet, Decoder
from pathlib import Path
from dalle2_pytorch.trainer import DecoderTrainer
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker, DummyTracker
from dalle2_pytorch.train_configs import TrainDecoderConfig
from dalle2_pytorch.utils import Timer, print_ribbon
from dalle2_pytorch.dalle2_pytorch import resize_image_to
import torchvision
import torch
@@ -11,6 +13,8 @@ from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.kid import KernelInceptionDistance
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.utils import dataclasses as accelerate_dataclasses
import webdataset as wds
import click
@@ -41,6 +45,7 @@ def create_dataloaders(
train_prop = 0.75,
val_prop = 0.15,
test_prop = 0.10,
seed = 0,
**kwargs
):
"""
@@ -51,7 +56,7 @@ def create_dataloaders(
num_test = round(test_prop*len(available_shards))
num_val = len(available_shards) - num_train - num_test
assert num_train + num_test + num_val == len(available_shards), f"{num_train} + {num_test} + {num_val} = {num_train + num_test + num_val} != {len(available_shards)}"
train_split, test_split, val_split = torch.utils.data.random_split(available_shards, [num_train, num_test, num_val], generator=torch.Generator().manual_seed(0))
train_split, test_split, val_split = torch.utils.data.random_split(available_shards, [num_train, num_test, num_val], generator=torch.Generator().manual_seed(seed))
# The shard number in the webdataset file names has a fixed width. We zero pad the shard numbers so they correspond to a filename.
train_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in train_split]
@@ -85,20 +90,6 @@ def create_dataloaders(
"test_sampling": test_sampling_dataloader
}
def create_decoder(device, decoder_config, unets_config):
"""Creates a sample decoder"""
unets = [Unet(**config.dict()) for config in unets_config]
decoder = Decoder(
unet=unets,
**decoder_config.dict()
)
decoder.to(device=device)
return decoder
def get_dataset_keys(dataloader):
"""
It is sometimes neccesary to get the keys the dataloader is returning. Since the dataset is burried in the dataloader, we need to do a process to recover it.
@@ -130,7 +121,6 @@ def get_example_data(dataloader, device, n=5):
captions.extend(list(txt))
if len(images) >= n:
break
print("Generated {} examples".format(len(images)))
return list(zip(images[:n], embeddings[:n], captions[:n]))
def generate_samples(trainer, example_data, text_prepend=""):
@@ -150,6 +140,14 @@ def generate_grid_samples(trainer, examples, text_prepend=""):
Generates samples and uses torchvision to put them in a side by side grid for easy viewing
"""
real_images, generated_images, captions = generate_samples(trainer, examples, text_prepend)
real_image_size = real_images[0].shape[-1]
generated_image_size = generated_images[0].shape[-1]
# training images may be larger than the generated one
if real_image_size > generated_image_size:
real_images = [resize_image_to(image, generated_image_size) for image in real_images]
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
return grid_images, captions
@@ -160,27 +158,34 @@ def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID
metrics = {}
# Prepare the data
examples = get_example_data(dataloader, device, n_evaluation_samples)
if len(examples) == 0:
print("No data to evaluate. Check that your dataloader has shards.")
return metrics
real_images, generated_images, captions = generate_samples(trainer, examples)
real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
# Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
int_real_images = real_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
int_generated_images = generated_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
def null_sync(t, *args, **kwargs):
return [t]
if exists(FID):
fid = FrechetInceptionDistance(**FID)
fid = FrechetInceptionDistance(**FID, dist_sync_fn=null_sync)
fid.to(device=device)
fid.update(int_real_images, real=True)
fid.update(int_generated_images, real=False)
metrics["FID"] = fid.compute().item()
if exists(IS):
inception = InceptionScore(**IS)
inception = InceptionScore(**IS, dist_sync_fn=null_sync)
inception.to(device=device)
inception.update(int_real_images)
is_mean, is_std = inception.compute()
metrics["IS_mean"] = is_mean.item()
metrics["IS_std"] = is_std.item()
if exists(KID):
kernel_inception = KernelInceptionDistance(**KID)
kernel_inception = KernelInceptionDistance(**KID, dist_sync_fn=null_sync)
kernel_inception.to(device=device)
kernel_inception.update(int_real_images, real=True)
kernel_inception.update(int_generated_images, real=False)
@@ -191,39 +196,47 @@ def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID
# Convert from [0, 1] to [-1, 1]
renorm_real_images = real_images.mul(2).sub(1)
renorm_generated_images = generated_images.mul(2).sub(1)
lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS)
lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS, dist_sync_fn=null_sync)
lpips.to(device=device)
lpips.update(renorm_real_images, renorm_generated_images)
metrics["LPIPS"] = lpips.compute().item()
if trainer.accelerator.num_processes > 1:
# Then we should sync the metrics
metrics_order = sorted(metrics.keys())
metrics_tensor = torch.zeros(1, len(metrics), device=device, dtype=torch.float)
for i, metric_name in enumerate(metrics_order):
metrics_tensor[0, i] = metrics[metric_name]
metrics_tensor = trainer.accelerator.gather(metrics_tensor)
metrics_tensor = metrics_tensor.mean(dim=0)
for i, metric_name in enumerate(metrics_order):
metrics[metric_name] = metrics_tensor[i].item()
return metrics
def save_trainer(tracker, trainer, epoch, step, validation_losses, relative_paths):
def save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, relative_paths):
"""
Logs the model with an appropriate method depending on the tracker
"""
if isinstance(relative_paths, str):
relative_paths = [relative_paths]
trainer_state_dict = {}
trainer_state_dict["trainer"] = trainer.state_dict()
trainer_state_dict['epoch'] = epoch
trainer_state_dict['step'] = step
trainer_state_dict['validation_losses'] = validation_losses
for relative_path in relative_paths:
tracker.save_state_dict(trainer_state_dict, relative_path)
local_path = str(tracker.data_path / relative_path)
trainer.save(local_path, epoch=epoch, sample=sample, next_task=next_task, validation_losses=validation_losses)
tracker.save_file(local_path)
def recall_trainer(tracker, trainer, recall_source=None, **load_config):
"""
Loads the model with an appropriate method depending on the tracker
"""
print(print_ribbon(f"Loading model from {recall_source}"))
state_dict = tracker.recall_state_dict(recall_source, **load_config)
trainer.load_state_dict(state_dict["trainer"])
print("Model loaded")
return state_dict["epoch"], state_dict["step"], state_dict["validation_losses"]
trainer.accelerator.print(print_ribbon(f"Loading model from {recall_source}"))
local_filepath = tracker.recall_file(recall_source, **load_config)
state_dict = trainer.load(local_filepath)
return state_dict.get("epoch", 0), state_dict.get("validation_losses", []), state_dict.get("next_task", "train"), state_dict.get("sample", 0)
def train(
dataloaders,
decoder,
accelerator,
tracker,
inference_device,
load_config=None,
@@ -242,17 +255,30 @@ def train(
"""
Trains a decoder on a dataset.
"""
trainer = DecoderTrainer( # TODO: Change the get_optimizer function so that it can take arbitrary named args so we can just put **kwargs as an argument here
is_master = accelerator.process_index == 0
trainer = DecoderTrainer(
accelerator,
decoder,
**kwargs
)
# Set up starting model and parameters based on a recalled state dict
start_step = 0
start_epoch = 0
validation_losses = []
next_task = 'train'
sample = 0
val_sample = 0
step = lambda: int(trainer.step.item())
if exists(load_config) and exists(load_config.source):
start_epoch, start_step, validation_losses = recall_trainer(tracker, trainer, recall_source=load_config.source, **load_config)
start_epoch, validation_losses, next_task, recalled_sample = recall_trainer(tracker, trainer, recall_source=load_config.source, **load_config.dict())
if next_task == 'train':
sample = recalled_sample
if next_task == 'val':
val_sample = recalled_sample
accelerator.print(f"Loaded model from {load_config.source} on epoch {start_epoch} with minimum validation loss {min(validation_losses) if len(validation_losses) > 0 else 'N/A'}")
accelerator.print(f"Starting training from task {next_task} at sample {sample} and validation sample {val_sample}")
trainer.to(device=inference_device)
if not exists(unet_training_mask):
@@ -260,139 +286,185 @@ def train(
unet_training_mask = [True] * trainer.num_unets
assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"
print(print_ribbon("Generating Example Data", repeat=40))
print("This can take a while to load the shard lists...")
train_example_data = get_example_data(dataloaders["train_sampling"], inference_device, n_sample_images)
test_example_data = get_example_data(dataloaders["test_sampling"], inference_device, n_sample_images)
accelerator.print(print_ribbon("Generating Example Data", repeat=40))
accelerator.print("This can take a while to load the shard lists...")
if is_master:
train_example_data = get_example_data(dataloaders["train_sampling"], inference_device, n_sample_images)
accelerator.print("Generated training examples")
test_example_data = get_example_data(dataloaders["test_sampling"], inference_device, n_sample_images)
accelerator.print("Generated testing examples")
send_to_device = lambda arr: [x.to(device=inference_device, dtype=torch.float) for x in arr]
step = start_step
sample_length_tensor = torch.zeros(1, dtype=torch.int, device=inference_device)
unet_losses_tensor = torch.zeros(TRAIN_CALC_LOSS_EVERY_ITERS, trainer.num_unets, dtype=torch.float, device=inference_device)
for epoch in range(start_epoch, epochs):
print(print_ribbon(f"Starting epoch {epoch}", repeat=40))
accelerator.print(print_ribbon(f"Starting epoch {epoch}", repeat=40))
timer = Timer()
last_sample = sample
last_snapshot = sample
sample = 0
last_sample = 0
last_snapshot = 0
if next_task == 'train':
for i, (img, emb) in enumerate(dataloaders["train"]):
# We want to count the total number of samples across all processes
sample_length_tensor[0] = len(img)
all_samples = accelerator.gather(sample_length_tensor) # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
total_samples = all_samples.sum().item()
sample += total_samples
img, emb = send_to_device((img, emb))
losses = []
trainer.train()
for unet in range(1, trainer.num_unets+1):
# Check if this is a unet we are training
if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
continue
for i, (img, emb) in enumerate(dataloaders["train"]):
step += 1
sample += img.shape[0]
img, emb = send_to_device((img, emb))
trainer.train()
for unet in range(1, trainer.num_unets+1):
# Check if this is a unet we are training
if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
continue
loss = trainer.forward(img, image_embed=emb, unet_number=unet)
trainer.update(unet_number=unet)
unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet-1] = loss
samples_per_sec = (sample - last_sample) / timer.elapsed()
timer.reset()
last_sample = sample
loss = trainer.forward(img, image_embed=emb, unet_number=unet)
trainer.update(unet_number=unet)
losses.append(loss)
if i % TRAIN_CALC_LOSS_EVERY_ITERS == 0:
# We want to average losses across all processes
unet_all_losses = accelerator.gather(unet_losses_tensor)
mask = unet_all_losses != 0
unet_average_loss = (unet_all_losses * mask).sum(dim=0) / mask.sum(dim=0)
loss_map = { f"Unet {index} Training Loss": loss.item() for index, loss in enumerate(unet_average_loss) if loss != 0 }
log_data = {
"Epoch": epoch,
"Sample": sample,
"Step": i,
"Samples per second": samples_per_sec,
**loss_map
}
# print(f"I am rank {accelerator.state.process_index}. Example weight: {trainer.decoder.state_dict()['module.unets.0.init_conv.convs.0.weight'][0,0,0,0]}")
if is_master:
tracker.log(log_data, step=step(), verbose=True)
samples_per_sec = (sample - last_sample) / timer.elapsed()
if is_master and last_snapshot + save_every_n_samples < sample: # This will miss by some amount every time, but it's not a big deal... I hope
# It is difficult to gather this kind of info on the accelerator, so we have to do it on the master
print("Saving snapshot")
last_snapshot = sample
# We need to know where the model should be saved
save_paths = []
if save_latest:
save_paths.append("latest.pth")
if save_all:
save_paths.append(f"checkpoints/epoch_{epoch}_step_{step()}.pth")
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, save_paths)
if exists(n_sample_images) and n_sample_images > 0:
trainer.eval()
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
if epoch_samples is not None and sample >= epoch_samples:
break
next_task = 'val'
sample = 0
timer.reset()
last_sample = sample
all_average_val_losses = None
if next_task == 'val':
trainer.eval()
accelerator.print(print_ribbon(f"Starting Validation {epoch}", repeat=40))
last_val_sample = val_sample
val_sample_length_tensor = torch.zeros(1, dtype=torch.int, device=inference_device)
average_val_loss_tensor = torch.zeros(1, trainer.num_unets, dtype=torch.float, device=inference_device)
timer = Timer()
accelerator.wait_for_everyone()
i = 0
for i, (img, emb, txt) in enumerate(dataloaders["val"]):
val_sample_length_tensor[0] = len(img)
all_samples = accelerator.gather(val_sample_length_tensor)
total_samples = all_samples.sum().item()
val_sample += total_samples
img, emb = send_to_device((img, emb))
if i % TRAIN_CALC_LOSS_EVERY_ITERS == 0:
average_loss = sum(losses) / len(losses)
log_data = {
"Training loss": average_loss,
"Epoch": epoch,
"Sample": sample,
"Step": i,
"Samples per second": samples_per_sec
}
tracker.log(log_data, step=step, verbose=True)
losses = []
for unet in range(1, len(decoder.unets)+1):
if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
# No need to evaluate an unchanging unet
continue
loss = trainer.forward(img.float(), image_embed=emb.float(), unet_number=unet)
average_val_loss_tensor[0, unet-1] += loss
if last_snapshot + save_every_n_samples < sample: # This will miss by some amount every time, but it's not a big deal... I hope
last_snapshot = sample
# We need to know where the model should be saved
if i % VALID_CALC_LOSS_EVERY_ITERS == 0:
samples_per_sec = (val_sample - last_val_sample) / timer.elapsed()
timer.reset()
last_val_sample = val_sample
accelerator.print(f"Epoch {epoch}/{epochs} Val Step {i} - Sample {val_sample} - {samples_per_sec:.2f} samples/sec")
accelerator.print(f"Loss: {(average_val_loss_tensor / (i+1))}")
accelerator.print("")
if validation_samples is not None and val_sample >= validation_samples:
break
print(f"Rank {accelerator.state.process_index} finished validation after {i} steps")
accelerator.wait_for_everyone()
average_val_loss_tensor /= i+1
# Gather all the average loss tensors
all_average_val_losses = accelerator.gather(average_val_loss_tensor)
if is_master:
unet_average_val_loss = all_average_val_losses.mean(dim=0)
val_loss_map = { f"Unet {index} Validation Loss": loss.item() for index, loss in enumerate(unet_average_val_loss) if loss != 0 }
tracker.log(val_loss_map, step=step(), verbose=True)
next_task = 'eval'
if next_task == 'eval':
if exists(evaluate_config):
accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict())
if is_master:
tracker.log(evaluation, step=step(), verbose=True)
next_task = 'sample'
val_sample = 0
if next_task == 'sample':
if is_master:
# Generate examples and save the model if we are the master
# Generate sample images
print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
test_images, test_captions = generate_grid_samples(trainer, test_example_data, "Test: ")
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step())
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
print(print_ribbon(f"Starting Saving {epoch}", repeat=40))
# Get the same paths
save_paths = []
if save_latest:
save_paths.append("latest.pth")
if save_all:
save_paths.append(f"checkpoints/epoch_{epoch}_step_{step}.pth")
if all_average_val_losses is not None:
average_loss = all_average_val_losses.mean(dim=0).item()
if save_best and (len(validation_losses) == 0 or average_loss < min(validation_losses)):
save_paths.append("best.pth")
validation_losses.append(average_loss)
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, save_paths)
next_task = 'train'
save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths)
if exists(n_sample_images) and n_sample_images > 0:
trainer.eval()
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step)
if exists(epoch_samples) and sample >= epoch_samples:
break
trainer.eval()
print(print_ribbon(f"Starting Validation {epoch}", repeat=40))
with torch.no_grad():
sample = 0
average_loss = 0
timer = Timer()
for i, (img, emb, txt) in enumerate(dataloaders["val"]):
sample += img.shape[0]
img, emb = send_to_device((img, emb))
for unet in range(1, len(decoder.unets)+1):
loss = trainer.forward(img.float(), image_embed=emb.float(), unet_number=unet)
average_loss += loss
if i % VALID_CALC_LOSS_EVERY_ITERS == 0:
print(f"Epoch {epoch}/{epochs} - {sample / timer.elapsed():.2f} samples/sec")
print(f"Loss: {average_loss / (i+1)}")
print("")
if exists(validation_samples) and sample >= validation_samples:
break
average_loss /= i+1
log_data = {
"Validation loss": average_loss
}
tracker.log(log_data, step=step, verbose=True)
# Compute evaluation metrics
if exists(evaluate_config):
print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config)
tracker.log(evaluation, step=step, verbose=True)
# Generate sample images
print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
test_images, test_captions = generate_grid_samples(trainer, test_example_data, "Test: ")
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step)
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step)
print(print_ribbon(f"Starting Saving {epoch}", repeat=40))
# Get the same paths
save_paths = []
if save_latest:
save_paths.append("latest.pth")
if save_best and (len(validation_losses) == 0 or average_loss < min(validation_losses)):
save_paths.append("best.pth")
validation_losses.append(average_loss)
save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths)
def create_tracker(config, tracker_type=None, data_path=None, **kwargs):
def create_tracker(accelerator, config, config_path, tracker_type=None, data_path=None):
"""
Creates a tracker of the specified type and initializes special features based on the full config
"""
tracker_config = config.tracker
init_config = {}
accelerator_config = {
"Distributed": accelerator.distributed_type != accelerate_dataclasses.DistributedType.NO,
"DistributedType": accelerator.distributed_type,
"NumProcesses": accelerator.num_processes,
"MixedPrecision": accelerator.mixed_precision
}
init_config = { "config": {**config.dict(), **accelerator_config} }
data_path = data_path or tracker_config.data_path
tracker_type = tracker_type or tracker_config.tracker_type
if exists(tracker_config.init_config):
init_config["config"] = tracker_config.init_config
if tracker_type == "console":
tracker = ConsoleTracker(**init_config)
if tracker_type == "dummy":
tracker = DummyTracker(data_path)
tracker.init(**init_config)
elif tracker_type == "console":
tracker = ConsoleTracker(data_path)
tracker.init(**init_config)
elif tracker_type == "wandb":
# We need to initialize the resume state here
load_config = config.load
@@ -406,51 +478,63 @@ def create_tracker(config, tracker_type=None, data_path=None, **kwargs):
init_config["project"] = tracker_config.wandb_project
tracker = WandbTracker(data_path)
tracker.init(**init_config)
tracker.save_file(str(config_path.absolute()), str(config_path.parent.absolute()))
else:
raise ValueError(f"Tracker type {tracker_type} not supported by decoder trainer")
return tracker
def initialize_training(config):
# Create the save path
if "cuda" in config.train.device:
assert torch.cuda.is_available(), "CUDA is not available"
device = torch.device(config.train.device)
torch.cuda.set_device(device)
all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
def initialize_training(config, config_path):
# Make sure if we are not loading, distributed models are initialized to the same values
torch.manual_seed(config.seed)
# Set up accelerator for configurable distributed training
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
# Set up data
all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
world_size = accelerator.num_processes
rank = accelerator.process_index
shards_per_process = len(all_shards) // world_size
assert shards_per_process > 0, "Not enough shards to split evenly"
my_shards = all_shards[rank * shards_per_process: (rank + 1) * shards_per_process]
dataloaders = create_dataloaders (
available_shards=all_shards,
available_shards=my_shards,
img_preproc = config.data.img_preproc,
train_prop = config.data.splits.train,
val_prop = config.data.splits.val,
test_prop = config.data.splits.test,
n_sample_images=config.train.n_sample_images,
**config.data.dict()
**config.data.dict(),
rank = rank,
seed = config.seed,
)
decoder = create_decoder(device, config.decoder, config.unets)
# Create the decoder model and print basic info
decoder = config.decoder.create()
num_parameters = sum(p.numel() for p in decoder.parameters())
print(print_ribbon("Loaded Config", repeat=40))
print(f"Number of parameters: {num_parameters}")
tracker = create_tracker(config, **config.tracker.dict())
# Create and initialize the tracker if we are the master
tracker = create_tracker(accelerator, config, config_path) if rank == 0 else create_tracker(accelerator, config, config_path, tracker_type="dummy")
train(dataloaders, decoder,
accelerator.print(print_ribbon("Loaded Config", repeat=40))
accelerator.print(f"Running training with {accelerator.num_processes} processes and {accelerator.distributed_type} distributed training")
accelerator.print(f"Number of parameters: {num_parameters}")
train(dataloaders, decoder, accelerator,
tracker=tracker,
inference_device=device,
inference_device=accelerator.device,
load_config=config.load,
evaluate_config=config.evaluate,
**config.train.dict(),
)
# Create a simple click command line interface to load the config and start the training
@click.command()
@click.option("--config_file", default="./train_decoder_config.json", help="Path to config file")
def main(config_file):
print("Recalling config from {}".format(config_file))
config = TrainDecoderConfig.from_json_path(config_file)
initialize_training(config)
config_file_path = Path(config_file)
config = TrainDecoderConfig.from_json_path(str(config_file_path))
initialize_training(config, config_path=config_file_path)
if __name__ == "__main__":
main()

View File

@@ -1,75 +1,135 @@
from pathlib import Path
# TODO: add start, num_data_points, eval_every and group to config
# TODO: switch back to repo's wandb
START = 0
NUM_DATA_POINTS = 250e6
EVAL_EVERY = 1000
GROUP = "distributed"
import os
import click
import math
import numpy as np
import wandb
import torch
import clip
from torch import nn
from torch.utils.data import DataLoader
from dalle2_pytorch.dataloaders import make_splits
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model
import numpy as np
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
from dalle2_pytorch.utils import Timer, print_ribbon
from accelerate import Accelerator
from embedding_reader import EmbeddingReader
from dalle2_pytorch.dataloaders import get_reader, make_splits
from dalle2_pytorch.utils import Timer
from dalle2_pytorch.train_configs import (
DiffusionPriorTrainConfig,
TrainDiffusionPriorConfig,
)
from dalle2_pytorch.trackers import BaseTracker, WandbTracker
from dalle2_pytorch import DiffusionPriorTrainer
from tqdm import tqdm
# constants
# helpers
REPORT_METRICS_EVERY = 250 # for cosine similarity and other metric reporting during training
tracker = WandbTracker()
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
# helpers functions
def exists(val):
val is not None
return val is not None
# functions
def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation"):
model.eval()
def make_model(
prior_config, train_config, device: str = None, accelerator: Accelerator = None
):
# create model from config
diffusion_prior = prior_config.create()
# instantiate the trainer
trainer = DiffusionPriorTrainer(
diffusion_prior=diffusion_prior,
lr=train_config.lr,
wd=train_config.wd,
max_grad_norm=train_config.max_grad_norm,
amp=train_config.amp,
use_ema=train_config.use_ema,
device=device,
accelerator=accelerator,
)
return trainer
# eval functions
def eval_model(
trainer: DiffusionPriorTrainer,
dataloader: DataLoader,
text_conditioned: bool,
loss_type: str,
tracker_context: str,
tracker: BaseTracker = None,
use_ema: bool = True,
):
trainer.eval()
if trainer.is_main_process():
click.secho(f"Measuring performance on {tracker_context}", fg="green", blink=True)
with torch.no_grad():
total_loss = 0.
total_samples = 0.
total_loss = 0.0
total_samples = 0.0
for image_embeddings, text_data in tqdm(dataloader):
for image_embeddings, text_data in dataloader:
image_embeddings = image_embeddings.to(trainer.device)
text_data = text_data.to(trainer.device)
batches = image_embeddings.shape[0]
input_args = dict(image_embed=image_embeddings)
if text_conditioned:
input_args = dict(**input_args, text = text_data)
input_args = dict(**input_args, text=text_data)
else:
input_args = dict(**input_args, text_embed=text_data)
loss = model(**input_args)
if use_ema:
loss = trainer.ema_diffusion_prior(**input_args)
else:
loss = trainer(**input_args)
total_loss += loss * batches
total_samples += batches
avg_loss = (total_loss / total_samples)
avg_loss = total_loss / total_samples
tracker.log({f'{phase} {loss_type}': avg_loss})
stats = {f"{tracker_context}-{loss_type}": avg_loss}
trainer.print(stats)
def report_cosine_sims(diffusion_prior, dataloader, text_conditioned):
diffusion_prior.eval()
if exists(tracker):
tracker.log(stats, step=trainer.step.item() + 1)
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
for test_image_embeddings, text_data in tqdm(dataloader):
def report_cosine_sims(
trainer: DiffusionPriorTrainer,
dataloader: DataLoader,
text_conditioned: bool,
tracker: BaseTracker,
tracker_context: str = "validation",
):
trainer.eval()
if trainer.is_main_process():
click.secho("Measuring Cosine-Similarity", fg="green", blink=True)
for test_image_embeddings, text_data in dataloader:
test_image_embeddings = test_image_embeddings.to(trainer.device)
text_data = text_data.to(trainer.device)
# we are text conditioned, we produce an embedding from the tokenized text
if text_conditioned:
text_embedding, text_encodings, text_mask = diffusion_prior.clip.embed_text(
text_data)
text_cond = dict(text_embed=text_embedding,
text_encodings=text_encodings, mask=text_mask)
text_embedding, text_encodings, text_mask = trainer.embed_text(text_data)
text_cond = dict(
text_embed=text_embedding, text_encodings=text_encodings, mask=text_mask
)
else:
text_embedding = text_data
text_cond = dict(text_embed=text_embedding)
@@ -80,8 +140,9 @@ def report_cosine_sims(diffusion_prior, dataloader, text_conditioned):
# roll the text to simulate "unrelated" captions
rolled_idx = torch.roll(torch.arange(text_embedding.shape[0]), 1)
text_embed_shuffled = text_embed_shuffled[rolled_idx]
text_embed_shuffled = text_embed_shuffled / \
text_embed_shuffled.norm(dim=1, keepdim=True)
text_embed_shuffled = text_embed_shuffled / text_embed_shuffled.norm(
dim=1, keepdim=True
)
if text_conditioned:
text_encodings_shuffled = text_encodings[rolled_idx]
@@ -90,276 +151,276 @@ def report_cosine_sims(diffusion_prior, dataloader, text_conditioned):
text_encodings_shuffled = None
text_mask_shuffled = None
text_cond_shuffled = dict(text_embed=text_embed_shuffled,
text_encodings=text_encodings_shuffled, mask=text_mask_shuffled)
text_cond_shuffled = dict(
text_embed=text_embed_shuffled,
text_encodings=text_encodings_shuffled,
mask=text_mask_shuffled,
)
# prepare the text embedding
text_embed = text_embedding / text_embedding.norm(dim=1, keepdim=True)
# prepare image embeddings
test_image_embeddings = test_image_embeddings / \
test_image_embeddings.norm(dim=1, keepdim=True)
test_image_embeddings = test_image_embeddings / test_image_embeddings.norm(
dim=1, keepdim=True
)
# predict on the unshuffled text embeddings
predicted_image_embeddings = diffusion_prior.p_sample_loop(
test_image_embeddings.shape, text_cond)
predicted_image_embeddings = predicted_image_embeddings / \
predicted_image_embeddings.norm(dim=1, keepdim=True)
predicted_image_embeddings = trainer.p_sample_loop(
test_image_embeddings.shape, text_cond
)
predicted_image_embeddings = (
predicted_image_embeddings
/ predicted_image_embeddings.norm(dim=1, keepdim=True)
)
# predict on the shuffled embeddings
predicted_unrelated_embeddings = diffusion_prior.p_sample_loop(
test_image_embeddings.shape, text_cond_shuffled)
predicted_unrelated_embeddings = predicted_unrelated_embeddings / \
predicted_unrelated_embeddings.norm(dim=1, keepdim=True)
predicted_unrelated_embeddings = trainer.p_sample_loop(
test_image_embeddings.shape, text_cond_shuffled
)
predicted_unrelated_embeddings = (
predicted_unrelated_embeddings
/ predicted_unrelated_embeddings.norm(dim=1, keepdim=True)
)
# calculate similarities
original_similarity = cos(
text_embed, test_image_embeddings).cpu().numpy()
predicted_similarity = cos(
text_embed, predicted_image_embeddings).cpu().numpy()
unrelated_similarity = cos(
text_embed, predicted_unrelated_embeddings).cpu().numpy()
predicted_img_similarity = cos(
test_image_embeddings, predicted_image_embeddings).cpu().numpy()
tracker.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity),
"CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity),
"CosineSimilarity(orig_image_embed,predicted_image_embed)":np.mean(predicted_img_similarity),
"CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(unrelated_similarity),
"Cosine similarity difference":np.mean(predicted_similarity - original_similarity)})
original_similarity = cos(text_embed, test_image_embeddings).cpu().numpy()
predicted_similarity = cos(text_embed, predicted_image_embeddings).cpu().numpy()
unrelated_similarity = (
cos(text_embed, predicted_unrelated_embeddings).cpu().numpy()
)
predicted_img_similarity = (
cos(test_image_embeddings, predicted_image_embeddings).cpu().numpy()
)
stats = {
f"{tracker_context}/baseline similarity": np.mean(original_similarity),
f"{tracker_context}/similarity with text": np.mean(predicted_similarity),
f"{tracker_context}/similarity with original image": np.mean(
predicted_img_similarity
),
f"{tracker_context}/similarity with unrelated caption": np.mean(unrelated_similarity),
f"{tracker_context}/difference from baseline similarity": np.mean(
predicted_similarity - original_similarity
),
}
for k, v in stats.items():
trainer.print(f"{tracker_context}/{k}: {v}")
if exists(tracker):
tracker.log(stats, step=trainer.step.item() + 1)
# training script
def train(
trainer: DiffusionPriorTrainer,
train_loader: DataLoader,
eval_loader: DataLoader,
test_loader: DataLoader,
config: DiffusionPriorTrainConfig,
):
# distributed tracking with wandb
if trainer.accelerator.num_processes > 1:
os.environ["WANDB_START_METHOD"] = "thread"
tracker = wandb.init(
name=f"RANK:{trainer.device}",
entity=config.tracker.wandb_entity,
project=config.tracker.wandb_project,
config=config.dict(),
group=GROUP,
)
# sync after tracker init
trainer.wait_for_everyone()
# init a timer
timer = Timer()
# do training
for img, txt in train_loader:
trainer.train()
current_step = trainer.step.item() + 1
# place data on device
img = img.to(trainer.device)
txt = txt.to(trainer.device)
# pass to model
loss = trainer(text=txt, image_embed=img)
# display & log loss (will only print from main process)
trainer.print(f"Step {current_step}: Loss {loss}")
# perform backprop & apply EMA updates
trainer.update()
# track samples/sec/rank
samples_per_sec = img.shape[0] / timer.elapsed()
# samples seen
samples_seen = (
config.data.batch_size * trainer.accelerator.num_processes * current_step
)
# ema decay
ema_decay = trainer.ema_diffusion_prior.get_current_decay()
# Log on all processes for debugging
tracker.log(
{
"tracking/samples-sec": samples_per_sec,
"tracking/samples-seen": samples_seen,
"tracking/ema-decay": ema_decay,
"metrics/training-loss": loss,
},
step=current_step,
)
# Metric Tracking & Checkpointing (outside of timer's scope)
if current_step % EVAL_EVERY == 0:
eval_model(
trainer=trainer,
dataloader=eval_loader,
text_conditioned=config.prior.condition_on_text_encodings,
loss_type=config.prior.loss_type,
tracker_context="metrics/online-model-validation",
tracker=tracker,
use_ema=False,
)
eval_model(
trainer=trainer,
dataloader=eval_loader,
text_conditioned=config.prior.condition_on_text_encodings,
loss_type=config.prior.loss_type,
tracker_context="metrics/ema-model-validation",
tracker=tracker,
use_ema=True,
)
report_cosine_sims(
trainer=trainer,
dataloader=eval_loader,
text_conditioned=config.prior.condition_on_text_encodings,
tracker=tracker,
tracker_context="metrics",
)
if current_step % config.train.save_every == 0:
trainer.save(f"{config.tracker.data_path}/chkpt_step_{current_step}.pth")
# reset timer for next round
timer.reset()
# evaluate on test data
eval_model(
trainer=trainer,
dataloader=test_loader,
text_conditioned=config.prior.condition_on_text_encodings,
loss_type=config.prior.loss_type,
tracker_context="test",
tracker=tracker,
)
report_cosine_sims(
trainer,
test_loader,
config.prior.condition_on_text_encodings,
tracker,
tracker_context="test",
)
def initialize_training(config, accelerator=None):
"""
Parse the configuration file, and prepare everything necessary for training
"""
# get a device
if accelerator:
device = accelerator.device
click.secho(f"Accelerating on: {device}", fg="yellow")
else:
if torch.cuda.is_available():
click.secho("GPU detected, defaulting to cuda:0", fg="yellow")
device = "cuda:0"
else:
click.secho("No GPU detected...using cpu", fg="yellow")
device = "cpu"
# make the trainer (will automatically distribute if possible & configured)
trainer = make_model(config.prior, config.train, device, accelerator).to(device)
# reload from chcekpoint
if config.load.resume == True:
click.secho(f"Loading checkpoint: {config.load.source}", fg="cyan")
trainer.load(config.load.source)
# fetch and prepare data
if trainer.is_main_process():
click.secho("Grabbing data from source", fg="blue", blink=True)
img_reader = get_reader(
text_conditioned=trainer.text_conditioned,
img_url=config.data.image_url,
meta_url=config.data.meta_url,
)
train_loader, eval_loader, test_loader = make_splits(
text_conditioned=trainer.text_conditioned,
batch_size=config.data.batch_size,
num_data_points=NUM_DATA_POINTS,
train_split=config.data.splits.train,
eval_split=config.data.splits.val,
image_reader=img_reader,
rank=accelerator.state.process_index if exists(accelerator) else 0,
world_size=accelerator.state.num_processes if exists(accelerator) else 1,
start=START,
)
# wait for everyone to load data before continuing
trainer.wait_for_everyone()
# start training
train(
trainer=trainer,
train_loader=train_loader,
eval_loader=eval_loader,
test_loader=test_loader,
config=config,
)
@click.command()
@click.option("--wandb-entity", default="laion")
@click.option("--wandb-project", default="diffusion-prior")
@click.option("--wandb-dataset", default="LAION-5B")
@click.option("--wandb-arch", default="DiffusionPrior")
@click.option("--image-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
@click.option("--text-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
@click.option("--meta-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/laion2B-en-metadata/")
@click.option("--learning-rate", default=1.1e-4)
@click.option("--weight-decay", default=6.02e-2)
@click.option("--dropout", default=5e-2)
@click.option("--max-grad-norm", default=0.5)
@click.option("--num-data-points", default=250e6)
@click.option("--batch-size", default=320)
@click.option("--num-epochs", default=5)
@click.option("--image-embed-dim", default=768)
@click.option("--train-percent", default=0.9)
@click.option("--val-percent", default=1e-7)
@click.option("--test-percent", default=0.0999999)
@click.option("--dpn-depth", default=12)
@click.option("--dpn-dim-head", default=64)
@click.option("--dpn-heads", default=12)
@click.option("--dp-condition-on-text-encodings", default=True)
@click.option("--dp-timesteps", default=1000)
@click.option("--dp-normformer", default=True)
@click.option("--dp-cond-drop-prob", default=0.1)
@click.option("--dp-loss-type", default="l2")
@click.option("--clip", default="ViT-L/14")
@click.option("--amp", default=False)
@click.option("--save-interval", default=120)
@click.option("--save-path", default="./diffusion_prior_checkpoints")
@click.option("--pretrained-model-path", default=None)
@click.option("--gpu-device", default=0)
def train(
wandb_entity,
wandb_project,
wandb_dataset,
wandb_arch,
image_embed_url,
text_embed_url,
meta_url,
learning_rate,
weight_decay,
dropout,
max_grad_norm,
num_data_points,
batch_size,
num_epochs,
image_embed_dim,
train_percent,
val_percent,
test_percent,
dpn_depth,
dpn_dim_head,
dpn_heads,
dp_condition_on_text_encodings,
dp_timesteps,
dp_normformer,
dp_cond_drop_prob,
dp_loss_type,
clip,
amp,
save_interval,
save_path,
pretrained_model_path,
gpu_device
):
config = {
"learning_rate": learning_rate,
"architecture": wandb_arch,
"dataset": wandb_dataset,
"weight_decay": weight_decay,
"max_gradient_clipping_norm": max_grad_norm,
"batch_size": batch_size,
"epochs": num_epochs,
"diffusion_prior_network": {
"depth": dpn_depth,
"dim_head": dpn_dim_head,
"heads": dpn_heads,
"normformer": dp_normformer
},
"diffusion_prior": {
"condition_on_text_encodings": dp_condition_on_text_encodings,
"timesteps": dp_timesteps,
"cond_drop_prob": dp_cond_drop_prob,
"loss_type": dp_loss_type,
"clip": clip
}
}
# Check if DPRIOR_PATH exists(saved model path)
DPRIOR_PATH = pretrained_model_path
RESUME = exists(DPRIOR_PATH)
if not RESUME:
tracker.init(
entity = wandb_entity,
project = wandb_project,
config = config
)
# Obtain the utilized device.
has_cuda = torch.cuda.is_available()
if has_cuda:
device = torch.device(f"cuda:{gpu_device}")
torch.cuda.set_device(device)
# Training loop
# diffusion prior network
prior_network = DiffusionPriorNetwork(
dim = image_embed_dim,
depth = dpn_depth,
dim_head = dpn_dim_head,
heads = dpn_heads,
attn_dropout = dropout,
ff_dropout = dropout,
normformer = dp_normformer
)
# Load clip model if text-conditioning
if dp_condition_on_text_encodings:
clip_adapter = OpenAIClipAdapter(clip)
@click.option("--hfa", default=True)
@click.option("--config_path", default="configs/prior.json")
def main(hfa, config_path):
# start HFA if requested
if hfa:
accelerator = Accelerator()
else:
clip_adapter = None
# diffusion prior with text embeddings and image embeddings pre-computed
accelerator = None
diffusion_prior = DiffusionPrior(
net = prior_network,
clip = clip_adapter,
image_embed_dim = image_embed_dim,
timesteps = dp_timesteps,
cond_drop_prob = dp_cond_drop_prob,
loss_type = dp_loss_type,
condition_on_text_encodings = dp_condition_on_text_encodings
)
# load the configuration file on main process
if not exists(accelerator) or accelerator.is_main_process:
click.secho(f"Loading configuration from {config_path}", fg="green")
# Load pre-trained model from DPRIOR_PATH
config = TrainDiffusionPriorConfig.from_json_path(config_path)
if RESUME:
diffusion_prior, loaded_obj = load_diffusion_model(DPRIOR_PATH, device)
tracker.init(entity = wandb_entity, project = wandb_project, config = config)
# diffusion prior trainer
trainer = DiffusionPriorTrainer(
diffusion_prior = diffusion_prior,
lr = learning_rate,
wd = weight_decay,
max_grad_norm = max_grad_norm,
amp = amp,
).to(device)
# load optimizer and scaler
if RESUME:
trainer.optimizer.load_state_dict(loaded_obj['optimizer'])
trainer.scaler.load_state_dict(loaded_obj['scaler'])
# Create save_path if it doesn't exist
Path(save_path).mkdir(exist_ok = True, parents = True)
# Utilize wrapper to abstract away loader logic
print_ribbon("Downloading Embeddings")
loader_args = dict(text_conditioned=dp_condition_on_text_encodings, batch_size=batch_size, num_data_points=num_data_points,
train_split=train_percent, eval_split=val_percent, device=device, img_url=image_embed_url)
if dp_condition_on_text_encodings:
loader_args = dict(**loader_args, meta_url=meta_url)
else:
loader_args = dict(**loader_args, txt_url=text_embed_url)
train_loader, eval_loader, test_loader = make_splits(**loader_args)
### Training code ###
step = 1
timer = Timer()
epochs = num_epochs
for _ in range(epochs):
for image, text in tqdm(train_loader):
diffusion_prior.train()
input_args = dict(image_embed=image)
if dp_condition_on_text_encodings:
input_args = dict(**input_args, text = text)
else:
input_args = dict(**input_args, text_embed=text)
loss = trainer(**input_args)
# Samples per second
samples_per_sec = batch_size * step / timer.elapsed()
# Save checkpoint every save_interval minutes
if(int(timer.elapsed()) >= 60 * save_interval):
timer.reset()
save_diffusion_model(
save_path,
diffusion_prior,
trainer.optimizer,
trainer.scaler,
config,
image_embed_dim)
# Log to wandb
tracker.log({"Training loss": loss,
"Steps": step,
"Samples per second": samples_per_sec})
# Log cosineSim(text_embed,predicted_image_embed) - cosineSim(text_embed,image_embed)
# Use NUM_TEST_EMBEDDINGS samples from the test set each time
# Get embeddings from the most recently saved model
if(step % REPORT_METRICS_EVERY) == 0:
report_cosine_sims(diffusion_prior, eval_loader, dp_condition_on_text_encodings)
### Evaluate model(validation run) ###
eval_model(diffusion_prior, eval_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Validation")
step += 1
trainer.update()
### Test run ###
eval_model(diffusion_prior, test_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Test")
# send config to get processed
initialize_training(config, accelerator)
if __name__ == "__main__":
train()
main()