Compare commits

..

1 Commits

Author SHA1 Message Date
Phil Wang
9340d33d5f fix wandb logging in tracker, and do some cleanup 2022-05-20 17:10:33 -07:00
24 changed files with 945 additions and 3274 deletions

6
.gitignore vendored
View File

@@ -1,12 +1,6 @@
# default experiment tracker data
.tracker-data/
# Configuration Files
configs/*
!configs/*.example
!configs/*_defaults.py
!configs/README.md
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

119
README.md
View File

@@ -12,7 +12,7 @@ This model is SOTA for text-to-image for now.
Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication with the <a href="https://laion.ai/">LAION</a> community | <a href="https://www.youtube.com/watch?v=AIOE1l1W0Tw">Yannic Interview</a>
As of 5/23/22, it is no longer SOTA. SOTA will be <a href="https://github.com/lucidrains/imagen-pytorch">here</a>. Jax versions as well as text-to-video project will be shifted towards the Imagen architecture, as it is way simpler.
There was enough interest for a <a href="https://github.com/lucidrains/dalle2-jax">Jax version</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
## Status
@@ -24,30 +24,6 @@ As of 5/23/22, it is no longer SOTA. SOTA will be <a href="https://github.com/lu
*ongoing at 21k steps*
- <a href="https://twitter.com/Buntworthy/status/1529475416775434240?t=0GEge3Kr9I36cjcUVCQUTg">Justin Pinkney</a> successfully trained the diffusion prior in the repository for his CLIP to Stylegan2 text-to-image application
## Pre-Trained Models
- LAION is training prior models. Checkpoints are available on <a href="https://huggingface.co/zenglishuci/conditioned-prior">🤗huggingface</a> and the training statistics are available on <a href="https://wandb.ai/nousr_laion/conditioned-prior/reports/LAION-DALLE2-PyTorch-Prior--VmlldzoyMDI2OTIx">🐝WANDB</a>.
- Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/jkrtg0so?workspace=user-veldrovive">In-progress test run</a> 🚧
- Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/3d5rytsa?workspace=">Another test run with sparse attention</a>
- DALL-E 2 🚧
## Appreciation
This library would not have gotten to this working state without the help of
- <a href="https://github.com/nousr">Zion</a> for the distributed training code for the diffusion prior
- <a href="https://github.com/Veldrovive">Aidan</a> for the distributed training code for the decoder as well as the dataloaders
- <a href="https://github.com/krish240574">Kumar</a> for working on the initial diffusion training script
- <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management
- <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying of critical bugs
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice
- <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship
- <a href="https://huggingface.co">🤗 Huggingface</a> and in particular <a href="https://github.com/sgugger">Sylvain</a> for the <a href="https://github.com/huggingface/accelerate">Accelerate</a> library
... and many others. Thank you! 🙏
## Install
```bash
@@ -368,8 +344,7 @@ unet1 = Unet(
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults=(1, 2, 4, 8),
cond_on_text_encodings = True # set to True for any unets that need to be conditioned on text encodings
dim_mults=(1, 2, 4, 8)
).cuda()
unet2 = Unet(
@@ -386,7 +361,8 @@ decoder = Decoder(
clip = clip,
timesteps = 100,
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5
text_cond_drop_prob = 0.5,
condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
).cuda()
for unet_number in (1, 2):
@@ -960,7 +936,7 @@ from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embed
# Create a dataloader directly.
dataloader = create_image_embedding_dataloader(
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses bracket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses braket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
num_workers=4,
batch_size=32,
@@ -1017,6 +993,33 @@ The most significant parameters for the script are as follows:
- `clip`, default = `None` # Signals the prior to use pre-computed embeddings
#### Loading and Saving the DiffusionPrior model
Two methods are provided, load_diffusion_model and save_diffusion_model, the names being self-explanatory.
```python
from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model
```
##### Loading
load_diffusion_model(dprior_path, device)
dprior_path : path to saved model(.pth)
device : the cuda device you're running on
##### Saving
save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim)
save_path : path to save at
model : object of Diffusion_Prior
optimizer : optimizer object - see train_diffusion_prior.py for how to create one.
e.g: optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
scaler : a GradScaler object.
e.g: scaler = GradScaler(enabled=amp)
config : config object created in train_diffusion_prior.py - see file for example.
image_embed_dim - the dimension of the image_embedding
e.g: 768
## CLI (wip)
```bash
@@ -1061,18 +1064,22 @@ Once built, images will be saved to the same directory the command is invoked
- [x] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
- [x] cross embed layers for downsampling, as an option
- [x] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
- [x] use pydantic for config drive training
- [x] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
- [x] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
- [x] allow for creation of diffusion prior model off pydantic config classes - consider the same for tracker configs
- [x] bring in skip-layer excitations (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training (doesnt work well)
- [x] test out grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697 (keeping, seems to be fine)
- [x] allow for unet to be able to condition non-cross attention style as well
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
- [ ] speed up inference, read up on papers (ddim or diffusion-gan, etc)
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
- [ ] train on a toy task, offer in colab
- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
- [ ] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
- [ ] decoder needs one day worth of refactor for tech debt
- [ ] allow for unet to be able to condition non-cross attention style as well
- [ ] for all model classes with hyperparameters that changes the network architecture, make it requirement that they must expose a config property, and write a simple function that asserts that it restores the object correctly
- [ ] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
- [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89
## Citations
@@ -1112,6 +1119,14 @@ Once built, images will be saved to the same directory the command is invoked
}
```
```bibtex
@inproceedings{Tu2022MaxViTMV,
title = {MaxViT: Multi-Axis Vision Transformer},
author = {Zhe-Wei Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
year = {2022}
}
```
```bibtex
@article{Yu2021VectorquantizedIM,
title = {Vector-quantized Image Modeling with Improved VQGAN},
@@ -1162,32 +1177,4 @@ Once built, images will be saved to the same directory the command is invoked
}
```
```bibtex
@misc{Saharia2022,
title = {Imagen: unprecedented photorealism × deep level of language understanding},
author = {Chitwan Saharia*, William Chan*, Saurabh Saxena†, Lala Li†, Jay Whang†, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S. Sara Mahdavi, Rapha Gontijo Lopes, Tim Salimans, Jonathan Ho†, David Fleet†, Mohammad Norouzi*},
year = {2022}
}
```
```bibtex
@article{Choi2022PerceptionPT,
title = {Perception Prioritized Training of Diffusion Models},
author = {Jooyoung Choi and Jungbeom Lee and Chaehun Shin and Sungwon Kim and Hyunwoo J. Kim and Sung-Hoon Yoon},
journal = {ArXiv},
year = {2022},
volume = {abs/2204.00227}
}
```
```bibtex
@article{Saharia2021PaletteID,
title = {Palette: Image-to-Image Diffusion Models},
author = {Chitwan Saharia and William Chan and Huiwen Chang and Chris A. Lee and Jonathan Ho and Tim Salimans and David J. Fleet and Mohammad Norouzi},
journal = {ArXiv},
year = {2021},
volume = {abs/2111.05826}
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>

View File

@@ -1,173 +0,0 @@
## DALLE2 Training Configurations
For more complex configuration, we provide the option of using a configuration file instead of command line arguments.
### Decoder Trainer
The decoder trainer has 7 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.example.json).
**<ins>Unet</ins>:**
This is a single unet config, which belongs as an array nested under the decoder config as a list of `unets`
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `dim` | Yes | N/A | The starting channels of the unet. |
| `image_embed_dim` | Yes | N/A | The dimension of the image embeddings. |
| `dim_mults` | No | `(1, 2, 4, 8)` | The growth factors of the channels. |
Any parameter from the `Unet` constructor can also be given here.
**<ins>Decoder</ins>:**
Defines the configuration options for the decoder model. The unets defined above will automatically be inserted.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `unets` | Yes | N/A | A list of unets, using the configuration above |
| `image_sizes` | Yes | N/A | The resolution of the image after each upsampling step. The length of this array should be the number of unets defined. |
| `image_size` | Yes | N/A | Not used. Can be any number. |
| `timesteps` | No | `1000` | The number of diffusion timesteps used for generation. |
| `loss_type` | No | `l2` | The loss function. Options are `l1`, `huber`, or `l2`. |
| `beta_schedule` | No | `cosine` | The noising schedule. Options are `cosine`, `linear`, `quadratic`, `jsd`, or `sigmoid`. |
| `learned_variance` | No | `True` | Whether to learn the variance. |
Any parameter from the `Decoder` constructor can also be given here.
**<ins>Data</ins>:**
Settings for creation of the dataloaders.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `webdataset_base_url` | Yes | N/A | The url of a shard in the webdataset with the shard replaced with `{}`[^1]. |
| `embeddings_url` | No | N/A | The url of the folder containing embeddings shards. Not required if embeddings are in webdataset. |
| `num_workers` | No | `4` | The number of workers used in the dataloader. |
| `batch_size` | No | `64` | The batch size. |
| `start_shard` | No | `0` | Defines the start of the shard range the dataset will recall. |
| `end_shard` | No | `9999999` | Defines the end of the shard range the dataset will recall. |
| `shard_width` | No | `6` | Defines the width of one webdataset shard number[^2]. |
| `index_width` | No | `4` | Defines the width of the index of a file inside a shard[^3]. |
| `splits` | No | `{ "train": 0.75, "val": 0.15, "test": 0.1 }` | Defines the proportion of shards that will be allocated to the training, validation, and testing datasets. |
| `shuffle_train` | No | `True` | Whether to shuffle the shards of the training dataset. |
| `resample_train` | No | `False` | If true, shards will be randomly sampled with replacement from the datasets making the epoch length infinite if a limit is not set. Cannot be enabled if `shuffle_train` is enabled. |
| `preprocessing` | No | `{ "ToTensor": True }` | Defines preprocessing applied to images from the datasets. |
[^1]: If your shard files have the paths `protocol://path/to/shard/00104.tar`, then the base url would be `protocol://path/to/shard/{}.tar`. If you are using a protocol like `s3`, you need to pipe the tars. For example `pipe:s3cmd get s3://bucket/path/{}.tar -`.
[^2]: This refers to the string length of the shard number for your webdataset shards. For instance, if your webdataset shard has the filename `00104.tar`, your shard length is 5.
[^3]: Inside the webdataset `tar`, you have files named something like `001045945.jpg`. 5 of these characters refer to the shard, and 4 refer to the index of the file in the webdataset (shard is `001041` and index is `5945`). The `index_width` in this case is 4.
**<ins>Train</ins>:**
Settings for controlling the training hyperparameters.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `epochs` | No | `20` | The number of epochs in the training run. |
| `lr` | No | `1e-4` | The learning rate. |
| `wd` | No | `0.01` | The weight decay. |
| `max_grad_norm`| No | `0.5` | The grad norm clipping. |
| `save_every_n_samples` | No | `100000` | Samples will be generated and a checkpoint will be saved every `save_every_n_samples` samples. |
| `device` | No | `cuda:0` | The device to train on. |
| `epoch_samples` | No | `None` | Limits the number of samples iterated through in each epoch. This must be set if resampling. None means no limit. |
| `validation_samples` | No | `None` | The number of samples to use for validation. None mean the entire validation set. |
| `use_ema` | No | `True` | Whether to use exponential moving average models for sampling. |
| `ema_beta` | No | `0.99` | The ema coefficient. |
| `save_all` | No | `False` | If True, preserves a checkpoint for every epoch. |
| `save_latest` | No | `True` | If True, overwrites the `latest.pth` every time the model is saved. |
| `save_best` | No | `True` | If True, overwrites the `best.pth` every time the model has a lower validation loss than all previous models. |
| `unet_training_mask` | No | `None` | A boolean array of the same length as the number of unets. If false, the unet is frozen. A value of `None` trains all unets. |
**<ins>Evaluate</ins>:**
Defines which evaluation metrics will be used to test the model.
Each metric can be enabled by setting its configuration. The configuration keys for each metric are defined by the torchmetrics constructors which will be linked.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `n_evaluation_samples` | No | `1000` | The number of samples to generate to test the model. |
| `FID` | No | `None` | Setting to an object enables the [Frechet Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/frechet_inception_distance.html) metric.
| `IS` | No | `None` | Setting to an object enables the [Inception Score](https://torchmetrics.readthedocs.io/en/stable/image/inception_score.html) metric.
| `KID` | No | `None` | Setting to an object enables the [Kernel Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/kernel_inception_distance.html) metric. |
| `LPIPS` | No | `None` | Setting to an object enables the [Learned Perceptual Image Patch Similarity](https://torchmetrics.readthedocs.io/en/stable/image/learned_perceptual_image_patch_similarity.html) metric. |
**<ins>Tracker</ins>:**
Selects how the experiment will be tracked.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `data_path` | No | `./.tracker-data` | The path to the folder where temporary tracker data will be saved. |
| `overwrite_data_path` | No | `False` | If true, the data path will be overwritten. Otherwise, you need to delete it yourself. |
| `log` | Yes | N/A | Logging configuration. |
| `load` | No | `None` | Checkpoint loading configuration. |
| `save` | Yes | N/A | Checkpoint/Model saving configuration. |
Tracking is split up into three sections:
* Log: Where to save run metadata and image output. Options are `console` or `wandb`.
* Load: Where to load a checkpoint from. Options are `local`, `url`, or `wandb`.
* Save: Where to save a checkpoint to. Options are `local`, `huggingface`, or `wandb`.
**Logging:**
If using `console` there is no further configuration than setting `log_type` to `console`.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `log_type` | Yes | N/A | Must be `console`. |
If using `wandb`
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `log_type` | Yes | N/A | Must be `wandb`. |
| `wandb_entity` | Yes | N/A | The wandb entity to log to. |
| `wandb_project` | Yes | N/A | The wandb project save the run to. |
| `wandb_run_name` | No | `None` | The wandb run name. |
| `wandb_run_id` | No | `None` | The wandb run id. Used if resuming an old run. |
| `wandb_resume` | No | `False` | Whether to resume an old run. |
**Loading:**
If using `local`
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `load_from` | Yes | N/A | Must be `local`. |
| `file_path` | Yes | N/A | The path to the checkpoint file. |
If using `url`
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `load_from` | Yes | N/A | Must be `url`. |
| `url` | Yes | N/A | The url of the checkpoint file. |
If using `wandb`
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `load_from` | Yes | N/A | Must be `wandb`. |
| `wandb_run_path` | No | `None` | The wandb run path. If `None`, uses the run that is being resumed. |
| `wandb_file_path` | Yes | N/A | The path to the checkpoint file in the W&B file system. |
**Saving:**
Unlike `log` and `load`, `save` may be an array of options so that you can save to different locations in a run.
All save locations have these configuration options
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `save_to` | Yes | N/A | Must be `local`, `huggingface`, or `wandb`. |
| `save_latest_to` | No | `latest.pth` | Sets the relative path to save the latest model to. |
| `save_best_to` | No | `best.pth` | Sets the relative path to save the best model to every time the model has a lower validation loss than all previous models. |
| `save_type` | No | `'checkpoint'` | The type of save. `'checkpoint'` saves a checkpoint, `'model'` saves a model without any fluff (Saves with ema if ema is enabled). |
If using `local`
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `save_to` | Yes | N/A | Must be `local`. |
If using `huggingface`
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `save_to` | Yes | N/A | Must be `huggingface`. |
| `huggingface_repo` | Yes | N/A | The huggingface repository to save to. |
| `huggingface_base_path` | Yes | N/A | The base path that checkpoints will be saved under. |
| `token_path` | No | `None` | If logging in with the huggingface cli is not possible, point to a token file instead. |
If using `wandb`
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `save_to` | Yes | N/A | Must be `wandb`. |
| `wandb_run_path` | No | `None` | The wandb run path. If `None`, uses the current run. You will almost always want this to be `None`. |

View File

@@ -1,111 +0,0 @@
{
"decoder": {
"unets": [
{
"dim": 128,
"image_embed_dim": 768,
"cond_dim": 64,
"channels": 3,
"dim_mults": [1, 2, 4, 8],
"attn_dim_head": 32,
"attn_heads": 16
}
],
"image_sizes": [64],
"channels": 3,
"timesteps": 1000,
"loss_type": "l2",
"beta_schedule": ["cosine"],
"learned_variance": true
},
"data": {
"webdataset_base_url": "pipe:s3cmd get s3://bucket/path/{}.tar -",
"embeddings_url": "s3://bucket/embeddings/path/",
"num_workers": 4,
"batch_size": 64,
"start_shard": 0,
"end_shard": 9999999,
"shard_width": 6,
"index_width": 4,
"splits": {
"train": 0.75,
"val": 0.15,
"test": 0.1
},
"shuffle_train": true,
"resample_train": false,
"preprocessing": {
"RandomResizedCrop": {
"size": [128, 128],
"scale": [0.75, 1.0],
"ratio": [1.0, 1.0]
},
"ToTensor": true
}
},
"train": {
"epochs": 20,
"lr": 1e-4,
"wd": 0.01,
"max_grad_norm": 0.5,
"save_every_n_samples": 100000,
"n_sample_images": 6,
"device": "cuda:0",
"epoch_samples": null,
"validation_samples": null,
"use_ema": true,
"ema_beta": 0.99,
"amp": false,
"save_all": false,
"save_latest": true,
"save_best": true,
"unet_training_mask": [true]
},
"evaluate": {
"n_evaluation_samples": 1000,
"FID": {
"feature": 64
},
"IS": {
"feature": 64,
"splits": 10
},
"KID": {
"feature": 64,
"subset_size": 10
},
"LPIPS": {
"net_type": "vgg",
"reduction": "mean"
}
},
"tracker": {
"overwrite_data_path": true,
"log": {
"log_type": "wandb",
"wandb_entity": "your_wandb",
"wandb_project": "your_project",
"verbose": true
},
"load": {
"load_from": null
},
"save": [{
"save_to": "wandb"
}, {
"save_to": "huggingface",
"huggingface_repo": "Veldrovive/test_model",
"save_all": true,
"save_latest": true,
"save_best": true,
"save_type": "model"
}]
}
}

View File

@@ -1,70 +0,0 @@
{
"prior": {
"clip": {
"make": "x-clip",
"model": "ViT-L/14",
"base_model_kwargs": {
"dim_text": 768,
"dim_image": 768,
"dim_latent": 768
}
},
"net": {
"dim": 768,
"depth": 12,
"num_timesteps": 1000,
"num_time_embeds": 1,
"num_image_embeds": 1,
"num_text_embeds": 1,
"dim_head": 64,
"heads": 12,
"ff_mult": 4,
"norm_out": true,
"attn_dropout": 0.0,
"ff_dropout": 0.0,
"final_proj": true,
"normformer": true,
"rotary_emb": true
},
"image_embed_dim": 768,
"image_size": 224,
"image_channels": 3,
"timesteps": 1000,
"cond_drop_prob": 0.1,
"loss_type": "l2",
"predict_x_start": true,
"beta_schedule": "cosine",
"condition_on_text_encodings": true
},
"data": {
"image_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/",
"text_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/",
"meta_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/laion2B-en-metadata/",
"batch_size": 256,
"splits": {
"train": 0.9,
"val": 1e-7,
"test": 0.0999999
}
},
"train": {
"epochs": 1,
"lr": 1.1e-4,
"wd": 6.02e-2,
"max_grad_norm": 0.5,
"use_ema": true,
"amp": false,
"save_every": 10000
},
"load": {
"source": null,
"resume": false
},
"tracker": {
"tracker_type": "wandb",
"data_path": "./prior_checkpoints",
"wandb_entity": "laion",
"wandb_project": "diffusion-prior",
"verbose": true
}
}

View File

@@ -1,4 +1,3 @@
from dalle2_pytorch.version import __version__
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer

File diff suppressed because it is too large Load Diff

View File

@@ -4,7 +4,7 @@ In order to make loading data simple and efficient, we include some general data
### Decoder: Image Embedding Dataset
When training the decoder (and up samplers if training together) in isolation, you will need to load images and corresponding image embeddings. This dataset can read two similar types of datasets. First, it can read a [webdataset](https://github.com/webdataset/webdataset) that contains `.jpg` and `.npy` files in the `.tar`s that contain the images and associated image embeddings respectively. Alternatively, you can also specify a source for the embeddings outside of the webdataset. In this case, the path to the embeddings should contain `.npy` files with the same shard numbers as the webdataset and there should be a correspondence between the filename of the `.jpg` and the index of the embedding in the `.npy`. So, for example, `0001.tar` from the webdataset with image `00010509.jpg` (the first 4 digits are the shard number and the last 4 are the index) in it should be paralleled by a `img_emb_0001.npy` which contains a NumPy array with the embedding at index 509.
Generating a dataset of this type:
Generating a dataset of this type:
1. Use [img2dataset](https://github.com/rom1504/img2dataset) to generate a webdataset.
2. Use [clip-retrieval](https://github.com/rom1504/clip-retrieval) to convert the images to embeddings.
3. Use [embedding-dataset-reordering](https://github.com/Veldrovive/embedding-dataset-reordering) to reorder the embeddings into the expected format.
@@ -15,7 +15,7 @@ from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embed
# Create a dataloader directly.
dataloader = create_image_embedding_dataloader(
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses bracket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses braket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
num_workers=4,
batch_size=32,
@@ -39,37 +39,3 @@ dataset = ImageEmbeddingDataset(
)
```
### Diffusion Prior: Prior Embedding Dataset
When training the prior it is much more efficient to work with pre-computed embeddings. The `PriorEmbeddingDataset` class enables you to leverage the same script (with minimal modification) for both embedding-only and text-conditioned prior training. This saves you from having to worry about a lot of the boilerplate code.
To utilize the `PriorEmbeddingDataset`, all you need to do is make a single call to `get_reader()` which will create `EmbeddingReader` object(s) for you. Afterwards, you can utilize `make_splits()` to cleanly create DataLoader objects from for your training run.
If you are training in a distributed manner, `make_splits()` accepts `rank` and `world_size` arguments to properly distribute to each process. The defaults for these values are `rank=0` and `world_size=1`, so single-process training can safely ignore these parameters.
Usage:
```python
from dalle2_pytorch.dataloaders import get_reader, make_splits
# grab embeddings from some specified location
IMG_URL = "data/img_emb/"
META_URL = "data/meta/"
reader = get_reader(text_conditioned=True, img_url=IMG_URL, meta_url=META_URL)
# some config for training
TRAIN_ARGS = {
"world_size": 3,
"text_conditioned": True,
"start": 0,
"num_data_points": 10000,
"batch_size": 2,
"train_split": 0.5,
"eval_split": 0.25,
"image_reader": reader,
}
# specifying a rank will handle allocation internally
rank0_train, rank0_eval, rank0_test = make_splits(rank=0, **TRAIN_ARGS)
rank1_train, rank1_eval, rank1_test = make_splits(rank=1, **TRAIN_ARGS)
rank2_train, rank2_eval, rank2_test = make_splits(rank=2, **TRAIN_ARGS)
```

View File

@@ -1,2 +1,2 @@
from dalle2_pytorch.dataloaders.decoder_loader import ImageEmbeddingDataset, create_image_embedding_dataloader
from dalle2_pytorch.dataloaders.prior_loader import make_splits, get_reader, PriorEmbeddingDataset
from dalle2_pytorch.dataloaders.embedding_wrapper import make_splits

View File

@@ -21,7 +21,7 @@ def get_example_file(fs, path, file_format):
"""
return fs.glob(os.path.join(path, f"*.{file_format}"))[0]
def embedding_inserter(samples, embeddings_url, index_width, sample_key='npy', handler=wds.handlers.reraise_exception):
def embedding_inserter(samples, embeddings_url, index_width, handler=wds.handlers.reraise_exception):
"""Given a datum of {"__key__": str, "__url__": str, ...} adds the cooresponding embedding and yields"""
previous_tar_url = None
current_embeddings = None
@@ -56,7 +56,7 @@ def embedding_inserter(samples, embeddings_url, index_width, sample_key='npy', h
# We need to check if this sample is nonzero. If it is, this embedding is not valid and we should continue to the next loop
if torch.count_nonzero(embedding) == 0:
raise RuntimeError(f"Webdataset had a sample, but no embedding was found. ImgShard: {key[:-index_width]} - Index: {key[-index_width:]}")
sample[sample_key] = embedding
sample["npy"] = embedding
yield sample
except Exception as exn: # From wds implementation
if handler(exn):
@@ -84,43 +84,24 @@ def unassociated_shard_skipper(tarfiles, embeddings_url, handler=wds.handlers.re
continue
else:
break
skip_unassociated_shards = wds.filters.pipelinefilter(unassociated_shard_skipper)
def join_embeddings(samples, handler=wds.handlers.reraise_exception):
"""
Takes the img_emb and text_emb keys and turns them into one key "emb": { "text": text_emb, "img": img_emb }
either or both of text_emb and img_emb may not be in the sample so we only add the ones that exist
"""
for sample in samples:
try:
sample['emb'] = {}
if 'text_emb' in sample:
sample['emb']['text'] = sample['text_emb']
if 'img_emb' in sample:
sample['emb']['img'] = sample['img_emb']
yield sample
except Exception as exn: # From wds implementation
if handler(exn):
continue
else:
break
def verify_keys(samples, required_keys, handler=wds.handlers.reraise_exception):
def verify_keys(samples, handler=wds.handlers.reraise_exception):
"""
Requires that both the image and embedding are present in the sample
This is important to do as a user may forget they do not have embeddings in their webdataset and neglect to add them using the embedding_folder_url parameter.
"""
for sample in samples:
try:
for key in required_keys:
assert key in sample, f"Sample {sample['__key__']} missing {key}. Has keys {sample.keys()}"
assert "jpg" in sample, f"Sample {sample['__key__']} missing image"
assert "npy" in sample, f"Sample {sample['__key__']} missing embedding. Did you set embedding_folder_url?"
yield sample
except Exception as exn: # From wds implementation
if handler(exn):
continue
else:
break
key_verifier = wds.filters.pipelinefilter(verify_keys)
class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
"""
@@ -131,8 +112,7 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
def __init__(
self,
urls,
img_embedding_folder_url=None,
text_embedding_folder_url=None,
embedding_folder_url=None,
index_width=None,
img_preproc=None,
extra_keys=[],
@@ -156,12 +136,7 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
"""
super().__init__()
keys = ["jpg", "emb"] + extra_keys
# if img_embedding_folder_url is not None:
# keys.append("img_emb")
# if text_embedding_folder_url is not None:
# keys.append("text_emb")
# keys.extend(extra_keys)
keys = ["jpg", "npy"] + extra_keys
self.key_map = {key: i for i, key in enumerate(keys)}
self.resampling = resample
self.img_preproc = img_preproc
@@ -170,7 +145,7 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
# Then this has an s3 link for the webdataset and we need extra packages
if shutil.which("s3cmd") is None:
raise RuntimeError("s3cmd is required for s3 webdataset")
if (img_embedding_folder_url is not None and "s3:" in img_embedding_folder_url) or (text_embedding_folder_url is not None and "s3:" in text_embedding_folder_url):
if "s3:" in embedding_folder_url:
# Then the embeddings are being loaded from s3 and fsspec requires s3fs
try:
import s3fs
@@ -185,24 +160,20 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
if shuffle_shards:
self.append(wds.filters.shuffle(1000))
if img_embedding_folder_url is not None:
if embedding_folder_url is not None:
# There may be webdataset shards that do not have a embedding shard associated with it. If we do not skip these, they would cause issues.
self.append(skip_unassociated_shards(embeddings_url=img_embedding_folder_url, handler=handler))
if text_embedding_folder_url is not None:
self.append(skip_unassociated_shards(embeddings_url=text_embedding_folder_url, handler=handler))
self.append(skip_unassociated_shards(embeddings_url=embedding_folder_url, handler=handler))
self.append(wds.split_by_node)
self.append(wds.split_by_worker)
self.append(wds.tarfile_to_samples(handler=handler))
self.append(wds.decode("pilrgb", handler=handler))
if img_embedding_folder_url is not None:
# Then we are loading image embeddings for a remote source
if embedding_folder_url is not None:
# Then we are loading embeddings for a remote source
assert index_width is not None, "Reading embeddings separately requires index width length to be given"
self.append(insert_embedding(embeddings_url=img_embedding_folder_url, index_width=index_width, sample_key='img_emb', handler=handler))
if text_embedding_folder_url is not None:
# Then we are loading image embeddings for a remote source
assert index_width is not None, "Reading embeddings separately requires index width length to be given"
self.append(insert_embedding(embeddings_url=text_embedding_folder_url, index_width=index_width, sample_key='text_emb', handler=handler))
self.append(join_embeddings)
self.append(key_verifier(required_keys=keys, handler=handler))
self.append(insert_embedding(embeddings_url=embedding_folder_url, index_width=index_width, handler=handler))
self.append(verify_keys)
# Apply preprocessing
self.append(wds.map(self.preproc))
self.append(wds.to_tuple(*keys))
@@ -217,8 +188,7 @@ def create_image_embedding_dataloader(
tar_url,
num_workers,
batch_size,
img_embeddings_url=None,
text_embeddings_url=None,
embeddings_url=None,
index_width=None,
shuffle_num = None,
shuffle_shards = True,
@@ -244,8 +214,7 @@ def create_image_embedding_dataloader(
"""
ds = ImageEmbeddingDataset(
tar_url,
img_embedding_folder_url=img_embeddings_url,
text_embedding_folder_url=text_embeddings_url,
embeddings_url,
index_width=index_width,
shuffle_shards=shuffle_shards,
resample=resample_shards,
@@ -262,4 +231,4 @@ def create_image_embedding_dataloader(
prefetch_factor=2, # This might be good to have high so the next npy file is prefetched
pin_memory=True,
shuffle=False
)
)

View File

@@ -0,0 +1,180 @@
from torch.utils.data import IterableDataset
from torch import from_numpy
from clip import tokenize
from embedding_reader import EmbeddingReader
class PriorEmbeddingLoader(IterableDataset):
def __init__(
self,
text_conditioned: bool,
batch_size: int,
start: int,
stop: int,
image_reader,
text_reader: EmbeddingReader = None,
device: str = "cpu",
) -> None:
super(PriorEmbeddingLoader).__init__()
self.text_conditioned = text_conditioned
if not self.text_conditioned:
self.text_reader = text_reader
self.image_reader = image_reader
self.batch_size = batch_size
self.start = start
self.stop = stop
self.device = device
def __iter__(self):
self.n = 0
loader_args = dict(
batch_size=self.batch_size,
start=self.start,
end=self.stop,
show_progress=False,
)
if self.text_conditioned:
self.loader = self.image_reader(**loader_args)
else:
self.loader = zip(
self.image_reader(**loader_args), self.text_reader(**loader_args)
)
return self
def __next__(self):
try:
return self.get_sample()
except StopIteration:
raise StopIteration
def get_sample(self):
"""
pre-proocess data from either reader into a common format
"""
self.n += 1
if self.text_conditioned:
image_embedding, caption = next(self.loader)
image_embedding = from_numpy(image_embedding).to(self.device)
tokenized_caption = tokenize(
caption["caption"].to_list(), truncate=True
).to(self.device)
return image_embedding, tokenized_caption
else:
(image_embedding, _), (text_embedding, _) = next(self.loader)
image_embedding = from_numpy(image_embedding).to(self.device)
text_embedding = from_numpy(text_embedding).to(self.device)
return image_embedding, text_embedding
def make_splits(
text_conditioned: bool,
batch_size: int,
num_data_points: int,
train_split: float,
eval_split: float,
device: str,
img_url: str,
meta_url: str = None,
txt_url: str = None,
):
assert img_url is not None, "Must supply some image embeddings"
if text_conditioned:
assert meta_url is not None, "Must supply metadata url if text-conditioning"
image_reader = EmbeddingReader(
embeddings_folder=img_url,
file_format="parquet_npy",
meta_columns=["caption"],
metadata_folder=meta_url,
)
# compute split points
if num_data_points > image_reader.count:
print("Specified point count is larger than the number of points available...defaulting to max length of reader.")
num_data_points = image_reader.count
train_set_size = int(train_split * num_data_points)
eval_set_size = int(eval_split * num_data_points)
eval_stop = int(train_set_size + eval_set_size)
train_loader = PriorEmbeddingLoader(
text_conditioned=text_conditioned,
image_reader=image_reader,
batch_size=batch_size,
start=0,
stop=train_set_size,
device=device,
)
eval_loader = PriorEmbeddingLoader(
text_conditioned=text_conditioned,
image_reader=image_reader,
batch_size=batch_size,
start=train_set_size,
stop=eval_stop,
device=device,
)
test_loader = PriorEmbeddingLoader(
text_conditioned=text_conditioned,
image_reader=image_reader,
batch_size=batch_size,
start=eval_stop,
stop=int(num_data_points),
device=device,
)
else:
assert (
txt_url is not None
), "Must supply text embedding url if not text-conditioning"
image_reader = EmbeddingReader(img_url, file_format="npy")
text_reader = EmbeddingReader(txt_url, file_format="npy")
# compute split points
if num_data_points > image_reader.count:
print("Specified point count is larger than the number of points available...defaulting to max length of reader.")
num_data_points = image_reader.count
train_set_size = int(train_split * num_data_points)
eval_set_size = int(eval_split * num_data_points)
eval_stop = int(train_set_size + eval_set_size)
train_loader = PriorEmbeddingLoader(
text_conditioned=text_conditioned,
image_reader=image_reader,
text_reader=text_reader,
batch_size=batch_size,
start=0,
stop=train_set_size,
device=device,
)
eval_loader = PriorEmbeddingLoader(
text_conditioned=text_conditioned,
image_reader=image_reader,
text_reader=text_reader,
batch_size=batch_size,
start=train_set_size,
stop=eval_stop,
device=device,
)
test_loader = PriorEmbeddingLoader(
text_conditioned=text_conditioned,
image_reader=image_reader,
text_reader=text_reader,
batch_size=batch_size,
start=eval_stop,
stop=int(num_data_points),
device=device,
)
return train_loader, eval_loader, test_loader

View File

@@ -1,273 +0,0 @@
from math import ceil
from clip import tokenize
from embedding_reader import EmbeddingReader
from torch import from_numpy
from torch.utils.data import IterableDataset, DataLoader
class PriorEmbeddingDataset(IterableDataset):
"""
PriorEmbeddingDataset is a wrapper of EmbeddingReader.
It enables one to simplify the logic necessary to yield samples from
the different EmbeddingReader configurations available.
"""
def __init__(
self,
text_conditioned: bool,
batch_size: int,
start: int,
stop: int,
image_reader,
text_reader: EmbeddingReader = None,
) -> None:
super(PriorEmbeddingDataset).__init__()
self.text_conditioned = text_conditioned
if not self.text_conditioned:
self.text_reader = text_reader
self.image_reader = image_reader
self.start = start
self.stop = stop
self.batch_size = batch_size
def __len__(self):
return self.stop - self.start
def __iter__(self):
# D.R.Y loader args
loader_args = dict(
batch_size=self.batch_size,
start=self.start,
end=self.stop,
show_progress=False,
)
# if the data requested is text conditioned, only load images
if self.text_conditioned:
self.loader = self.image_reader(**loader_args)
# otherwise, include text embeddings and bypass metadata
else:
self.loader = zip(
self.image_reader(**loader_args), self.text_reader(**loader_args)
)
# return the data loader in its formatted state
return self
def __next__(self):
try:
return self.get_sample()
except StopIteration:
raise StopIteration
def __str__(self):
return f"<PriorEmbeddingDataset: start: {self.start}, stop: {self.stop}, len: {self.__len__()}>"
def get_sample(self):
"""
pre-proocess data from either reader into a common format
"""
if self.text_conditioned:
image_embedding, caption = next(self.loader)
image_embedding = from_numpy(image_embedding)
tokenized_caption = tokenize(caption["caption"].to_list(), truncate=True)
return image_embedding, tokenized_caption
else:
(image_embedding, _), (text_embedding, _) = next(self.loader)
image_embedding = from_numpy(image_embedding)
text_embedding = from_numpy(text_embedding)
return image_embedding, text_embedding
# helper functions
def distribute_to_rank(start, stop, rank, world_size):
"""
Distribute data to each rank given the world size.
Return:
- New start and stop points for this rank.
"""
num_samples = int(stop - start)
per_rank = int(ceil((num_samples) / float(world_size)))
assert (
per_rank > 0
), f"Number of samples per rank must be larger than 0, (found: {per_rank})"
rank_start = start + rank * per_rank
rank_stop = min(rank_start + per_rank, stop)
new_length = rank_stop - rank_start
assert (
new_length > 0
), "Calculated start and stop points result in a length of zero for this rank."
return rank_start, rank_stop
def get_reader(
text_conditioned: bool, img_url: str, meta_url: str = None, txt_url: str = None
):
"""
Create an EmbeddingReader object from the specified URLs
get_reader() will always expect a url to image embeddings.
If text-conditioned, it will also expect a meta_url for the captions.
Otherwise, it will need txt_url for the matching text embeddings.
Returns an image_reader object if text-conditioned.
Otherwise it returns both an image_reader and a text_reader
"""
assert img_url is not None, "Must supply a image url"
if text_conditioned:
assert meta_url is not None, "Must supply meta url if text-conditioned"
image_reader = EmbeddingReader(
embeddings_folder=img_url,
file_format="parquet_npy",
# will assume the caption column exists and is the only one requested
meta_columns=["caption"],
metadata_folder=meta_url,
)
return image_reader
# otherwise we will require text embeddings as well and return two readers
assert (
txt_url is not None
), "Must supply text embedding url if not text-conditioning"
image_reader = EmbeddingReader(img_url, file_format="npy")
text_reader = EmbeddingReader(txt_url, file_format="npy")
return image_reader, text_reader
def make_splits(
text_conditioned: bool,
batch_size: int,
num_data_points: int,
train_split: float,
eval_split: float,
image_reader: EmbeddingReader,
text_reader: EmbeddingReader = None,
start=0,
rank=0,
world_size=1,
):
"""
Split an embedding reader object as needed.
NOTE: make_splits() will infer the test set size from your train and eval.
Input:
- text_conditioned: whether to prepare text-conditioned training data
- batch_size: the batch size for a single gpu
- num_data_points: the total number of data points you wish to train on
- train_split: the percentage of data you wish to train on
- eval_split: the percentage of data you wish to validate on
- image_reader: the image_reader you wish to split
- text_reader: the text_reader you want to split (if !text_conditioned)
- start: the starting point within your dataset
- rank: the rank of your worker
- world_size: the total world size of your distributed training run
Returns:
- PyTorch Dataloaders that yield tuples of (img, txt) data.
"""
assert start < image_reader.count, "start position cannot exceed reader count."
# verify that the num_data_points does not exceed the max points
if num_data_points > (image_reader.count - start):
print(
"Specified count is larger than what's available...defaulting to reader's count."
)
num_data_points = image_reader.count
# compute split points
train_set_size = int(train_split * num_data_points)
eval_set_size = int(eval_split * num_data_points)
eval_start = train_set_size
eval_stop = int(eval_start + eval_set_size)
assert (
train_split + eval_split
) < 1.0, "Specified train and eval split is too large to infer a test split."
# distribute to rank
rank_train_start, rank_train_stop = distribute_to_rank(
start, train_set_size, rank, world_size
)
rank_eval_start, rank_eval_stop = distribute_to_rank(
train_set_size, eval_stop, rank, world_size
)
rank_test_start, rank_test_stop = distribute_to_rank(
eval_stop, num_data_points, rank, world_size
)
# wrap up splits into a dict
train_split_args = dict(
start=rank_train_start, stop=rank_train_stop, batch_size=batch_size
)
eval_split_args = dict(
start=rank_eval_start, stop=rank_eval_stop, batch_size=batch_size
)
test_split_args = dict(
start=rank_test_start, stop=rank_test_stop, batch_size=batch_size
)
if text_conditioned:
# add the text-conditioned args to a unified dict
reader_args = dict(
text_conditioned=text_conditioned,
image_reader=image_reader,
)
train_split_args = dict(**reader_args, **train_split_args)
eval_split_args = dict(**reader_args, **eval_split_args)
test_split_args = dict(**reader_args, **test_split_args)
train = PriorEmbeddingDataset(**train_split_args)
val = PriorEmbeddingDataset(**eval_split_args)
test = PriorEmbeddingDataset(**test_split_args)
else:
# add the non-conditioned args to a unified dict
reader_args = dict(
text_conditioned=text_conditioned,
image_reader=image_reader,
text_reader=text_reader,
)
train_split_args = dict(**reader_args, **train_split_args)
eval_split_args = dict(**reader_args, **eval_split_args)
test_split_args = dict(**reader_args, **test_split_args)
train = PriorEmbeddingDataset(**train_split_args)
val = PriorEmbeddingDataset(**eval_split_args)
test = PriorEmbeddingDataset(**test_split_args)
# true batch size is specifed in the PriorEmbeddingDataset
train_loader = DataLoader(train, batch_size=None)
eval_loader = DataLoader(val, batch_size=None)
test_loader = DataLoader(test, batch_size=None)
return train_loader, eval_loader, test_loader

View File

@@ -1,21 +1,17 @@
from torch.optim import AdamW, Adam
def separate_weight_decayable_params(params):
wd_params, no_wd_params = [], []
for param in params:
param_list = no_wd_params if param.ndim < 2 else wd_params
param_list.append(param)
no_wd_params = set([param for param in params if param.ndim < 2])
wd_params = set(params) - no_wd_params
return wd_params, no_wd_params
def get_optimizer(
params,
lr = 1e-4,
wd = 1e-2,
betas = (0.9, 0.99),
betas = (0.9, 0.999),
eps = 1e-8,
filter_by_requires_grad = False,
group_wd_params = True,
**kwargs
filter_by_requires_grad = False
):
if filter_by_requires_grad:
params = list(filter(lambda t: t.requires_grad, params))
@@ -23,12 +19,12 @@ def get_optimizer(
if wd == 0:
return Adam(params, lr = lr, betas = betas, eps = eps)
if group_wd_params:
wd_params, no_wd_params = separate_weight_decayable_params(params)
params = set(params)
wd_params, no_wd_params = separate_weight_decayable_params(params)
params = [
{'params': wd_params},
{'params': no_wd_params, 'weight_decay': 0},
]
param_groups = [
{'params': list(wd_params)},
{'params': list(no_wd_params), 'weight_decay': 0},
]
return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas, eps = eps)

View File

@@ -2,6 +2,7 @@
# to give users a quick easy start to training DALL-E without doing BPE
import torch
import youtokentome as yttm
import html
import os
@@ -10,8 +11,6 @@ import regex as re
from functools import lru_cache
from pathlib import Path
from dalle2_pytorch.utils import import_or_print_error
# OpenAI simple tokenizer
@lru_cache()
@@ -157,9 +156,7 @@ class YttmTokenizer:
bpe_path = Path(bpe_path)
assert bpe_path.exists(), f'BPE json path {str(bpe_path)} does not exist'
self.yttm = import_or_print_error('youtokentome', 'you need to install youtokentome by `pip install youtokentome`')
tokenizer = self.yttm.BPE(model = str(bpe_path))
tokenizer = yttm.BPE(model = str(bpe_path))
self.tokenizer = tokenizer
self.vocab_size = tokenizer.vocab_size()
@@ -170,7 +167,7 @@ class YttmTokenizer:
return self.tokenizer.decode(tokens, ignore_ids = pad_tokens.union({0}))
def encode(self, texts):
encoded = self.tokenizer.encode(texts, output_type = self.yttm.OutputType.ID)
encoded = self.tokenizer.encode(texts, output_type = yttm.OutputType.ID)
return list(map(torch.tensor, encoded))
def tokenize(self, texts, context_length = 256, truncate_text = False):

View File

@@ -1,15 +1,11 @@
import urllib.request
import os
from pathlib import Path
import shutil
from enum import Enum
import importlib
from itertools import zip_longest
from typing import Optional, List, Union
from pydantic import BaseModel
import torch
from dalle2_pytorch.utils import import_or_print_error
from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
from torch import nn
# constants
@@ -20,494 +16,102 @@ DEFAULT_DATA_PATH = './.tracker-data'
def exists(val):
return val is not None
# load file functions
def import_or_print_error(pkg_name, err_str = None):
try:
return importlib.import_module(pkg_name)
except ModuleNotFoundError as e:
if exists(err_str):
print(err_str)
exit()
def load_wandb_file(run_path, file_path, **kwargs):
# load state dict functions
def load_wandb_state_dict(run_path, file_path, **kwargs):
wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function')
file_reference = wandb.restore(file_path, run_path=run_path)
return file_reference.name
return torch.load(file_reference.name)
def load_local_file(file_path, **kwargs):
return file_path
def load_local_state_dict(file_path, **kwargs):
return torch.load(file_path)
class BaseLogger:
"""
An abstract class representing an object that can log data.
Parameters:
data_path (str): A file path for storing temporary data.
verbose (bool): Whether of not to always print logs to the console.
"""
def __init__(self, data_path: str, verbose: bool = False, **kwargs):
# base class
class BaseTracker(nn.Module):
def __init__(self, data_path = DEFAULT_DATA_PATH):
super().__init__()
assert data_path is not None, "Tracker must have a data_path to save local content"
self.data_path = Path(data_path)
self.verbose = verbose
self.data_path.mkdir(parents = True, exist_ok = True)
def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
def init(self, config, **kwargs):
raise NotImplementedError
def log(self, log, **kwargs):
raise NotImplementedError
def log_images(self, images, **kwargs):
raise NotImplementedError
def save_state_dict(self, state_dict, relative_path, **kwargs):
raise NotImplementedError
def recall_state_dict(self, recall_source, *args, **kwargs):
"""
Initializes the logger.
Errors if the logger is invalid.
Loads a state dict from any source.
Since a user may wish to load a model from a different source than their own tracker (i.e. tracking using wandb but recalling from disk),
this should not be linked to any individual tracker.
"""
raise NotImplementedError
# TODO: Pull this into a dict or something similar so that we can add more sources without having a massive switch statement
if recall_source == 'wandb':
return load_wandb_state_dict(*args, **kwargs)
elif recall_source == 'local':
return load_local_state_dict(*args, **kwargs)
else:
raise ValueError('`recall_source` must be one of `wandb` or `local`')
def log(self, log, **kwargs) -> None:
raise NotImplementedError
def log_images(self, images, captions=[], image_section="images", **kwargs) -> None:
raise NotImplementedError
# basic stdout class
def log_file(self, file_path, **kwargs) -> None:
raise NotImplementedError
class ConsoleTracker(BaseTracker):
def init(self, **config):
print(config)
def log_error(self, error_string, **kwargs) -> None:
raise NotImplementedError
class ConsoleLogger(BaseLogger):
def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
print("Logging to console")
def log(self, log, **kwargs) -> None:
def log(self, log, **kwargs):
print(log)
def log_images(self, images, captions=[], image_section="images", **kwargs) -> None:
def log_images(self, images, **kwargs): # noop for logging images
pass
def save_state_dict(self, state_dict, relative_path, **kwargs):
torch.save(state_dict, str(self.data_path / relative_path))
def log_file(self, file_path, **kwargs) -> None:
pass
# basic wandb class
def log_error(self, error_string, **kwargs) -> None:
print(error_string)
class WandbLogger(BaseLogger):
"""
Logs to a wandb run.
Parameters:
data_path (str): A file path for storing temporary data.
wandb_entity (str): The wandb entity to log to.
wandb_project (str): The wandb project to log to.
wandb_run_id (str): The wandb run id to resume.
wandb_run_name (str): The wandb run name to use.
wandb_resume (bool): Whether to resume a wandb run.
"""
def __init__(self,
data_path: str,
wandb_entity: str,
wandb_project: str,
wandb_run_id: Optional[str] = None,
wandb_run_name: Optional[str] = None,
wandb_resume: bool = False,
**kwargs
):
super().__init__(data_path, **kwargs)
self.entity = wandb_entity
self.project = wandb_project
self.run_id = wandb_run_id
self.run_name = wandb_run_name
self.resume = wandb_resume
def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
assert self.entity is not None, "wandb_entity must be specified for wandb logger"
assert self.project is not None, "wandb_project must be specified for wandb logger"
self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb logger')
class WandbTracker(BaseTracker):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb experiment tracker')
os.environ["WANDB_SILENT"] = "true"
# Initializes the wandb run
init_object = {
"entity": self.entity,
"project": self.project,
"config": {**full_config.dict(), **extra_config}
}
if self.run_name is not None:
init_object['name'] = self.run_name
if self.resume:
assert self.run_id is not None, '`wandb_run_id` must be provided if `wandb_resume` is True'
if self.run_name is not None:
print("You are renaming a run. I hope that is what you intended.")
init_object['resume'] = 'must'
init_object['id'] = self.run_id
self.wandb.init(**init_object)
print(f"Logging to wandb run {self.wandb.run.path}-{self.wandb.run.name}")
def init(self, **config):
self.wandb.init(**config)
def log(self, log, **kwargs) -> None:
if self.verbose:
def log(self, log, verbose=False, **kwargs):
if verbose:
print(log)
self.wandb.log(log, **kwargs)
def log_images(self, images, captions=[], image_section="images", **kwargs) -> None:
def log_images(self, images, captions=[], image_section="images", **kwargs):
"""
Takes a tensor of images and a list of captions and logs them to wandb.
"""
wandb_images = [self.wandb.Image(image, caption=caption) for image, caption in zip_longest(images, captions)]
self.wandb.log({ image_section: wandb_images }, **kwargs)
def log_file(self, file_path, base_path: Optional[str] = None, **kwargs) -> None:
if base_path is None:
# Then we take the basepath as the parent of the file_path
base_path = Path(file_path).parent
self.wandb.save(str(file_path), base_path = str(base_path))
def log_error(self, error_string, step=None, **kwargs) -> None:
if self.verbose:
print(error_string)
self.wandb.log({"error": error_string, **kwargs}, step=step)
logger_type_map = {
'console': ConsoleLogger,
'wandb': WandbLogger,
}
def create_logger(logger_type: str, data_path: str, **kwargs) -> BaseLogger:
if logger_type == 'custom':
raise NotImplementedError('Custom loggers are not supported yet. Please use a different logger type.')
try:
logger_class = logger_type_map[logger_type]
except KeyError:
raise ValueError(f'Unknown logger type: {logger_type}. Must be one of {list(logger_type_map.keys())}')
return logger_class(data_path, **kwargs)
class BaseLoader:
"""
An abstract class representing an object that can load a model checkpoint.
Parameters:
data_path (str): A file path for storing temporary data.
"""
def __init__(self, data_path: str, **kwargs):
self.data_path = Path(data_path)
def init(self, logger: BaseLogger, **kwargs) -> None:
raise NotImplementedError
def recall() -> dict:
raise NotImplementedError
class UrlLoader(BaseLoader):
"""
A loader that downloads the file from a url and loads it
Parameters:
data_path (str): A file path for storing temporary data.
url (str): The url to download the file from.
"""
def __init__(self, data_path: str, url: str, **kwargs):
super().__init__(data_path, **kwargs)
self.url = url
def init(self, logger: BaseLogger, **kwargs) -> None:
# Makes sure the file exists to be downloaded
pass # TODO: Actually implement that
def recall(self) -> dict:
# Download the file
save_path = self.data_path / 'loaded_checkpoint.pth'
urllib.request.urlretrieve(self.url, str(save_path))
# Load the file
return torch.load(str(save_path), map_location='cpu')
class LocalLoader(BaseLoader):
"""
A loader that loads a file from a local path
Parameters:
data_path (str): A file path for storing temporary data.
file_path (str): The path to the file to load.
"""
def __init__(self, data_path: str, file_path: str, **kwargs):
super().__init__(data_path, **kwargs)
self.file_path = Path(file_path)
def init(self, logger: BaseLogger, **kwargs) -> None:
# Makes sure the file exists to be loaded
if not self.file_path.exists():
raise FileNotFoundError(f'Model not found at {self.file_path}')
def recall(self) -> dict:
# Load the file
return torch.load(str(self.file_path), map_location='cpu')
class WandbLoader(BaseLoader):
"""
A loader that loads a model from an existing wandb run
"""
def __init__(self, data_path: str, wandb_file_path: str, wandb_run_path: Optional[str] = None, **kwargs):
super().__init__(data_path, **kwargs)
self.run_path = wandb_run_path
self.file_path = wandb_file_path
def init(self, logger: BaseLogger, **kwargs) -> None:
self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function')
# Make sure the file can be downloaded
if self.wandb.run is not None and self.run_path is None:
self.run_path = self.wandb.run.path
assert self.run_path is not None, 'wandb run was not found to load from. If not using the wandb logger must specify the `wandb_run_path`.'
assert self.run_path is not None, '`wandb_run_path` must be provided for the wandb loader'
assert self.file_path is not None, '`wandb_file_path` must be provided for the wandb loader'
os.environ["WANDB_SILENT"] = "true"
pass # TODO: Actually implement that
def recall(self) -> dict:
file_reference = self.wandb.restore(self.file_path, run_path=self.run_path)
return torch.load(file_reference.name, map_location='cpu')
loader_type_map = {
'url': UrlLoader,
'local': LocalLoader,
'wandb': WandbLoader,
}
def create_loader(loader_type: str, data_path: str, **kwargs) -> BaseLoader:
if loader_type == 'custom':
raise NotImplementedError('Custom loaders are not supported yet. Please use a different loader type.')
try:
loader_class = loader_type_map[loader_type]
except KeyError:
raise ValueError(f'Unknown loader type: {loader_type}. Must be one of {list(loader_type_map.keys())}')
return loader_class(data_path, **kwargs)
class BaseSaver:
def __init__(self,
data_path: str,
save_latest_to: Optional[Union[str, bool]] = 'latest.pth',
save_best_to: Optional[Union[str, bool]] = 'best.pth',
save_meta_to: str = './',
save_type: str = 'checkpoint',
**kwargs
):
self.data_path = Path(data_path)
self.save_latest_to = save_latest_to
self.saving_latest = save_latest_to is not None and save_latest_to is not False
self.save_best_to = save_best_to
self.saving_best = save_best_to is not None and save_best_to is not False
self.save_meta_to = save_meta_to
self.save_type = save_type
assert save_type in ['checkpoint', 'model'], '`save_type` must be one of `checkpoint` or `model`'
assert self.save_meta_to is not None, '`save_meta_to` must be provided'
assert self.saving_latest or self.saving_best, '`save_latest_to` or `save_best_to` must be provided'
def init(self, logger: BaseLogger, **kwargs) -> None:
raise NotImplementedError
def save_file(self, local_path: Path, save_path: str, is_best=False, is_latest=False, **kwargs) -> None:
"""
Save a general file under save_meta_to
"""
raise NotImplementedError
class LocalSaver(BaseSaver):
def __init__(self,
data_path: str,
**kwargs
):
super().__init__(data_path, **kwargs)
def init(self, logger: BaseLogger, **kwargs) -> None:
# Makes sure the directory exists to be saved to
print(f"Saving {self.save_type} locally")
if not self.data_path.exists():
self.data_path.mkdir(parents=True)
def save_file(self, local_path: str, save_path: str, **kwargs) -> None:
# Copy the file to save_path
save_path_file_name = Path(save_path).name
print(f"Saving {save_path_file_name} {self.save_type} to local path {save_path}")
shutil.copy(local_path, save_path)
class WandbSaver(BaseSaver):
def __init__(self, data_path: str, wandb_run_path: Optional[str] = None, **kwargs):
super().__init__(data_path, **kwargs)
self.run_path = wandb_run_path
def init(self, logger: BaseLogger, **kwargs) -> None:
self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb logger')
os.environ["WANDB_SILENT"] = "true"
# Makes sure that the user can upload tot his run
if self.run_path is not None:
entity, project, run_id = self.run_path.split("/")
self.run = self.wandb.init(entity=entity, project=project, id=run_id)
else:
assert self.wandb.run is not None, 'You must be using the wandb logger if you are saving to wandb and have not set `wandb_run_path`'
self.run = self.wandb.run
# TODO: Now actually check if upload is possible
print(f"Saving to wandb run {self.run.path}-{self.run.name}")
def save_file(self, local_path: Path, save_path: str, **kwargs) -> None:
# In order to log something in the correct place in wandb, we need to have the same file structure here
save_path_file_name = Path(save_path).name
print(f"Saving {save_path_file_name} {self.save_type} to wandb run {self.run.path}-{self.run.name}")
save_path = Path(self.data_path) / save_path
save_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(local_path, save_path)
self.run.save(str(save_path), base_path = str(self.data_path), policy='now')
class HuggingfaceSaver(BaseSaver):
def __init__(self, data_path: str, huggingface_repo: str, token_path: Optional[str] = None, **kwargs):
super().__init__(data_path, **kwargs)
self.huggingface_repo = huggingface_repo
self.token_path = token_path
def init(self, logger: BaseLogger, **kwargs):
# Makes sure this user can upload to the repo
self.hub = import_or_print_error('huggingface_hub', '`pip install huggingface_hub` to use the huggingface saver')
try:
identity = self.hub.whoami() # Errors if not logged in
# Then we are logged in
except:
# We are not logged in. Use the token_path to set the token.
if not os.path.exists(self.token_path):
raise Exception("Not logged in to huggingface and no token_path specified. Please login with `huggingface-cli login` or if that does not work set the token_path.")
with open(self.token_path, "r") as f:
token = f.read().strip()
self.hub.HfApi.set_access_token(token)
identity = self.hub.whoami()
print(f"Saving to huggingface repo {self.huggingface_repo}")
def save_file(self, local_path: Path, save_path: str, **kwargs) -> None:
# Saving to huggingface is easy, we just need to upload the file with the correct name
save_path_file_name = Path(save_path).name
print(f"Saving {save_path_file_name} {self.save_type} to huggingface repo {self.huggingface_repo}")
self.hub.upload_file(
path_or_fileobj=str(local_path),
path_in_repo=str(save_path),
repo_id=self.huggingface_repo
)
saver_type_map = {
'local': LocalSaver,
'wandb': WandbSaver,
'huggingface': HuggingfaceSaver
}
def create_saver(saver_type: str, data_path: str, **kwargs) -> BaseSaver:
if saver_type == 'custom':
raise NotImplementedError('Custom savers are not supported yet. Please use a different saver type.')
try:
saver_class = saver_type_map[saver_type]
except KeyError:
raise ValueError(f'Unknown saver type: {saver_type}. Must be one of {list(saver_type_map.keys())}')
return saver_class(data_path, **kwargs)
class Tracker:
def __init__(self, data_path: Optional[str] = DEFAULT_DATA_PATH, overwrite_data_path: bool = False, dummy_mode: bool = False):
self.data_path = Path(data_path)
if not dummy_mode:
if overwrite_data_path:
if self.data_path.exists():
shutil.rmtree(self.data_path)
self.data_path.mkdir(parents=True)
else:
assert not self.data_path.exists(), f'Data path {self.data_path} already exists. Set overwrite_data_path to True to overwrite.'
if not self.data_path.exists():
self.data_path.mkdir(parents=True)
self.logger: BaseLogger = None
self.loader: Optional[BaseLoader] = None
self.savers: List[BaseSaver]= []
self.dummy_mode = dummy_mode
def init(self, full_config: BaseModel, extra_config: dict):
assert self.logger is not None, '`logger` must be set before `init` is called'
if self.dummy_mode:
# The only thing we need is a loader
if self.loader is not None:
self.loader.init(self.logger)
return
assert len(self.savers) > 0, '`savers` must be set before `init` is called'
self.logger.init(full_config, extra_config)
if self.loader is not None:
self.loader.init(self.logger)
for saver in self.savers:
saver.init(self.logger)
def add_logger(self, logger: BaseLogger):
self.logger = logger
def add_loader(self, loader: BaseLoader):
self.loader = loader
def add_saver(self, saver: BaseSaver):
self.savers.append(saver)
def log(self, *args, **kwargs):
if self.dummy_mode:
return
self.logger.log(*args, **kwargs)
def log_images(self, *args, **kwargs):
if self.dummy_mode:
return
self.logger.log_images(*args, **kwargs)
def log_file(self, *args, **kwargs):
if self.dummy_mode:
return
self.logger.log_file(*args, **kwargs)
def save_config(self, current_config_path: str, config_name = 'config.json'):
if self.dummy_mode:
return
# Save the config under config_name in the root folder of data_path
shutil.copy(current_config_path, self.data_path / config_name)
for saver in self.savers:
remote_path = Path(saver.save_meta_to) / config_name
saver.save_file(current_config_path, str(remote_path))
def _save_state_dict(self, trainer: Union[DiffusionPriorTrainer, DecoderTrainer], save_type: str, file_path: str, **kwargs) -> Path:
def save_state_dict(self, state_dict, relative_path, **kwargs):
"""
Gets the state dict to be saved and writes it to file_path.
If save_type is 'checkpoint', we save the entire trainer state dict.
If save_type is 'model', we save only the model state dict.
Saves a state_dict to disk and uploads it
"""
assert save_type in ['checkpoint', 'model']
if save_type == 'checkpoint':
trainer.save(file_path, overwrite=True, **kwargs)
elif save_type == 'model':
if isinstance(trainer, DiffusionPriorTrainer):
prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior
state_dict = trainer.unwrap_model(prior).state_dict()
torch.save(state_dict, file_path)
elif isinstance(trainer, DecoderTrainer):
decoder = trainer.accelerator.unwrap_model(trainer.decoder)
if trainer.use_ema:
trainable_unets = decoder.unets
decoder.unets = trainer.unets # Swap EMA unets in
state_dict = decoder.state_dict()
decoder.unets = trainable_unets # Swap back
else:
state_dict = decoder.state_dict()
torch.save(state_dict, file_path)
else:
raise NotImplementedError('Saving this type of model with EMA mode enabled is not yet implemented. Actually, how did you get here?')
return Path(file_path)
def save(self, trainer, is_best: bool, is_latest: bool, **kwargs):
if self.dummy_mode:
return
if not is_best and not is_latest:
# Nothing to do
return
# Save the checkpoint and model to data_path
checkpoint_path = self.data_path / 'checkpoint.pth'
self._save_state_dict(trainer, 'checkpoint', checkpoint_path, **kwargs)
model_path = self.data_path / 'model.pth'
self._save_state_dict(trainer, 'model', model_path, **kwargs)
print("Saved cached models")
# Call the save methods on the savers
for saver in self.savers:
local_path = checkpoint_path if saver.save_type == 'checkpoint' else model_path
if saver.saving_latest and is_latest:
latest_checkpoint_path = saver.save_latest_to.format(**kwargs)
try:
saver.save_file(local_path, latest_checkpoint_path, is_latest=True, **kwargs)
except Exception as e:
self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs)
print(f'Error saving checkpoint: {e}')
if saver.saving_best and is_best:
best_checkpoint_path = saver.save_best_to.format(**kwargs)
try:
saver.save_file(local_path, best_checkpoint_path, is_best=True, **kwargs)
except Exception as e:
self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs)
print(f'Error saving checkpoint: {e}')
def recall(self):
if self.loader is not None:
return self.loader.recall()
else:
raise ValueError('No loader specified')
full_path = str(self.data_path / relative_path)
torch.save(state_dict, full_path)
self.wandb.save(full_path, base_path = str(self.data_path)) # Upload and keep relative to data_path

View File

@@ -1,362 +0,0 @@
import json
from torchvision import transforms as T
from pydantic import BaseModel, validator, root_validator
from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
from x_clip import CLIP as XCLIP
from coca_pytorch import CoCa
from dalle2_pytorch.dalle2_pytorch import (
CoCaAdapter,
OpenAIClipAdapter,
Unet,
Decoder,
DiffusionPrior,
DiffusionPriorNetwork,
XClipAdapter
)
from dalle2_pytorch.trackers import Tracker, create_loader, create_logger, create_saver
# helper functions
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def ListOrTuple(inner_type):
return Union[List[inner_type], Tuple[inner_type]]
def SingularOrIterable(inner_type):
return Union[inner_type, ListOrTuple(inner_type)]
# general pydantic classes
class TrainSplitConfig(BaseModel):
train: float = 0.75
val: float = 0.15
test: float = 0.1
@root_validator
def validate_all(cls, fields):
actual_sum = sum([*fields.values()])
if actual_sum != 1.:
raise ValueError(f'{fields.keys()} must sum to 1.0. Found: {actual_sum}')
return fields
class TrackerLogConfig(BaseModel):
log_type: str = 'console'
verbose: bool = False
class Config:
# Each individual log type has it's own arguments that will be passed through the config
extra = "allow"
def create(self, data_path: str):
kwargs = self.dict()
return create_logger(self.log_type, data_path, **kwargs)
class TrackerLoadConfig(BaseModel):
load_from: Optional[str] = None
class Config:
extra = "allow"
def create(self, data_path: str):
kwargs = self.dict()
if self.load_from is None:
return None
return create_loader(self.load_from, data_path, **kwargs)
class TrackerSaveConfig(BaseModel):
save_to: str = 'local'
save_all: bool = False
save_latest: bool = True
save_best: bool = True
class Config:
extra = "allow"
def create(self, data_path: str):
kwargs = self.dict()
return create_saver(self.save_to, data_path, **kwargs)
class TrackerConfig(BaseModel):
data_path: str = '.tracker_data'
overwrite_data_path: bool = False
log: TrackerLogConfig
load: Optional[TrackerLoadConfig]
save: Union[List[TrackerSaveConfig], TrackerSaveConfig]
def create(self, full_config: BaseModel, extra_config: dict, dummy_mode: bool = False) -> Tracker:
tracker = Tracker(self.data_path, dummy_mode=dummy_mode, overwrite_data_path=self.overwrite_data_path)
# Add the logger
tracker.add_logger(self.log.create(self.data_path))
# Add the loader
if self.load is not None:
tracker.add_loader(self.load.create(self.data_path))
# Add the saver or savers
if isinstance(self.save, list):
for save_config in self.save:
tracker.add_saver(save_config.create(self.data_path))
else:
tracker.add_saver(self.save.create(self.data_path))
# Initialize all the components and verify that all data is valid
tracker.init(full_config, extra_config)
return tracker
# diffusion prior pydantic classes
class AdapterConfig(BaseModel):
make: str = "openai"
model: str = "ViT-L/14"
base_model_kwargs: Dict[str, Any] = None
def create(self):
if self.make == "openai":
return OpenAIClipAdapter(self.model)
elif self.make == "x-clip":
return XClipAdapter(XCLIP(**self.base_model_kwargs))
elif self.make == "coca":
return CoCaAdapter(CoCa(**self.base_model_kwargs))
else:
raise AttributeError("No adapter with that name is available.")
class DiffusionPriorNetworkConfig(BaseModel):
dim: int
depth: int
num_timesteps: int = None
num_time_embeds: int = 1
num_image_embeds: int = 1
num_text_embeds: int = 1
dim_head: int = 64
heads: int = 8
ff_mult: int = 4
norm_out: bool = True
attn_dropout: float = 0.
ff_dropout: float = 0.
final_proj: bool = True
normformer: bool = False
rotary_emb: bool = True
def create(self):
kwargs = self.dict()
return DiffusionPriorNetwork(**kwargs)
class DiffusionPriorConfig(BaseModel):
clip: AdapterConfig = None
net: DiffusionPriorNetworkConfig
image_embed_dim: int
image_size: int
image_channels: int = 3
timesteps: int = 1000
cond_drop_prob: float = 0.
loss_type: str = 'l2'
predict_x_start: bool = True
beta_schedule: str = 'cosine'
condition_on_text_encodings: bool = True
class Config:
extra = "allow"
def create(self):
kwargs = self.dict()
has_clip = exists(kwargs.pop('clip'))
kwargs.pop('net')
clip = None
if has_clip:
clip = self.clip.create()
diffusion_prior_network = self.net.create()
return DiffusionPrior(net = diffusion_prior_network, clip = clip, **kwargs)
class DiffusionPriorTrainConfig(BaseModel):
epochs: int = 1
lr: float = 1.1e-4
wd: float = 6.02e-2
max_grad_norm: float = 0.5
use_ema: bool = True
ema_beta: float = 0.99
amp: bool = False
save_every: int = 10000 # what steps to save on
class DiffusionPriorDataConfig(BaseModel):
image_url: str # path to embeddings folder
meta_url: str # path to metadata (captions) for images
splits: TrainSplitConfig
batch_size: int = 64
class DiffusionPriorLoadConfig(BaseModel):
source: str = None
resume: bool = False
class TrainDiffusionPriorConfig(BaseModel):
prior: DiffusionPriorConfig
data: DiffusionPriorDataConfig
train: DiffusionPriorTrainConfig
load: DiffusionPriorLoadConfig
tracker: TrackerConfig
@classmethod
def from_json_path(cls, json_path):
with open(json_path) as f:
config = json.load(f)
return cls(**config)
# decoder pydantic classes
class UnetConfig(BaseModel):
dim: int
dim_mults: ListOrTuple(int)
image_embed_dim: int = None
text_embed_dim: int = None
cond_on_text_encodings: bool = None
cond_dim: int = None
channels: int = 3
self_attn: ListOrTuple(int)
attn_dim_head: int = 32
attn_heads: int = 16
class Config:
extra = "allow"
class DecoderConfig(BaseModel):
unets: ListOrTuple(UnetConfig)
image_size: int = None
image_sizes: ListOrTuple(int) = None
clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
channels: int = 3
timesteps: int = 1000
loss_type: str = 'l2'
beta_schedule: ListOrTuple(str) = 'cosine'
learned_variance: bool = True
image_cond_drop_prob: float = 0.1
text_cond_drop_prob: float = 0.5
def create(self):
decoder_kwargs = self.dict()
unet_configs = decoder_kwargs.pop('unets')
unets = [Unet(**config) for config in unet_configs]
has_clip = exists(decoder_kwargs.pop('clip'))
clip = None
if has_clip:
clip = self.clip.create()
return Decoder(unets, clip=clip, **decoder_kwargs)
@validator('image_sizes')
def check_image_sizes(cls, image_sizes, values):
if exists(values.get('image_size')) ^ exists(image_sizes):
return image_sizes
raise ValueError('either image_size or image_sizes is required, but not both')
class Config:
extra = "allow"
class DecoderDataConfig(BaseModel):
webdataset_base_url: str # path to a webdataset with jpg images
img_embeddings_url: Optional[str] # path to .npy files with embeddings
text_embeddings_url: Optional[str] # path to .npy files with embeddings
num_workers: int = 4
batch_size: int = 64
start_shard: int = 0
end_shard: int = 9999999
shard_width: int = 6
index_width: int = 4
splits: TrainSplitConfig
shuffle_train: bool = True
resample_train: bool = False
preprocessing: Dict[str, Any] = {'ToTensor': True}
@property
def img_preproc(self):
def _get_transformation(transformation_name, **kwargs):
if transformation_name == "RandomResizedCrop":
return T.RandomResizedCrop(**kwargs)
elif transformation_name == "RandomHorizontalFlip":
return T.RandomHorizontalFlip()
elif transformation_name == "ToTensor":
return T.ToTensor()
transforms = []
for transform_name, transform_kwargs_or_bool in self.preprocessing.items():
transform_kwargs = {} if not isinstance(transform_kwargs_or_bool, dict) else transform_kwargs_or_bool
transforms.append(_get_transformation(transform_name, **transform_kwargs))
return T.Compose(transforms)
class DecoderTrainConfig(BaseModel):
epochs: int = 20
lr: SingularOrIterable(float) = 1e-4
wd: SingularOrIterable(float) = 0.01
find_unused_parameters: bool = True
max_grad_norm: SingularOrIterable(float) = 0.5
save_every_n_samples: int = 100000
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
device: str = 'cuda:0'
epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
validation_samples: int = None # Same as above but for validation.
use_ema: bool = True
ema_beta: float = 0.999
amp: bool = False
unet_training_mask: ListOrTuple(bool) = None # If None, use all unets
class DecoderEvaluateConfig(BaseModel):
n_evaluation_samples: int = 1000
FID: Dict[str, Any] = None
IS: Dict[str, Any] = None
KID: Dict[str, Any] = None
LPIPS: Dict[str, Any] = None
class DecoderLoadConfig(BaseModel):
source: str = None # Supports file and wandb
run_path: str = '' # Used only if source is wandb
file_path: str = '' # The local filepath if source is file. If source is wandb, the relative path to the model file in wandb.
resume: bool = False # If using wandb, whether to resume the run
class TrainDecoderConfig(BaseModel):
decoder: DecoderConfig
data: DecoderDataConfig
train: DecoderTrainConfig
evaluate: DecoderEvaluateConfig
tracker: TrackerConfig
seed: int = 0
@classmethod
def from_json_path(cls, json_path):
with open(json_path) as f:
config = json.load(f)
return cls(**config)
@root_validator
def check_has_embeddings(cls, values):
# Makes sure that enough information is provided to get the embeddings specified for training
data_config, decoder_config = values.get('data'), values.get('decoder')
if not exists(data_config) or not exists(decoder_config):
# Then something else errored and we should just pass through
return values
using_text_embeddings = any([unet.cond_on_text_encodings for unet in decoder_config.unets])
using_clip = exists(decoder_config.clip)
img_emb_url = data_config.img_embeddings_url
text_emb_url = data_config.text_embeddings_url
if using_text_embeddings:
# Then we need some way to get the embeddings
assert using_clip or exists(text_emb_url), 'If text conditioning, either clip or text_embeddings_url must be provided'
if using_clip:
if using_text_embeddings:
assert not exists(text_emb_url) or not exists(img_emb_url), 'Loaded clip, but also provided text_embeddings_url and img_embeddings_url. This is redundant. Remove the clip model or the text embeddings'
else:
assert not exists(img_emb_url), 'Loaded clip, but also provided img_embeddings_url. This is redundant. Remove the clip model or the embeddings'
if text_emb_url:
assert using_text_embeddings, "Text embeddings are being loaded, but text embeddings are not being conditioned on. This will slow down the dataloader for no reason."
return values

View File

@@ -1,6 +1,5 @@
import time
import copy
from pathlib import Path
from math import ceil
from functools import partial, wraps
from collections.abc import Iterable
@@ -11,12 +10,6 @@ from torch.cuda.amp import autocast, GradScaler
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
from dalle2_pytorch.optimizer import get_optimizer
from dalle2_pytorch.version import __version__
from packaging import version
from ema_pytorch import EMA
from accelerate import Accelerator
import numpy as np
@@ -26,9 +19,7 @@ def exists(val):
return val is not None
def default(val, d):
if exists(val):
return val
return d() if callable(d) else d
return val if exists(val) else d
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
@@ -137,6 +128,111 @@ def split_args_and_kwargs(*args, split_size = None, **kwargs):
chunk_size_frac = chunk_size / batch_size
yield chunk_size_frac, (chunked_args, chunked_kwargs)
# print helpers
def print_ribbon(s, symbol = '=', repeat = 40):
flank = symbol * repeat
return f'{flank} {s} {flank}'
# saving and loading functions
# for diffusion prior
def load_diffusion_model(dprior_path, device):
dprior_path = Path(dprior_path)
assert dprior_path.exists(), 'Dprior model file does not exist'
loaded_obj = torch.load(str(dprior_path), map_location='cpu')
# Get hyperparameters of loaded model
dpn_config = loaded_obj['hparams']['diffusion_prior_network']
dp_config = loaded_obj['hparams']['diffusion_prior']
image_embed_dim = loaded_obj['image_embed_dim']['image_embed_dim']
# Create DiffusionPriorNetwork and DiffusionPrior with loaded hyperparameters
# DiffusionPriorNetwork
prior_network = DiffusionPriorNetwork( dim = image_embed_dim, **dpn_config).to(device)
# DiffusionPrior with text embeddings and image embeddings pre-computed
diffusion_prior = DiffusionPrior(net = prior_network, **dp_config, image_embed_dim = image_embed_dim).to(device)
# Load state dict from saved model
diffusion_prior.load_state_dict(loaded_obj['model'])
return diffusion_prior, loaded_obj
def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim):
# Saving State Dict
print_ribbon('Saving checkpoint')
state_dict = dict(model=model.state_dict(),
optimizer=optimizer.state_dict(),
scaler=scaler.state_dict(),
hparams = config,
image_embed_dim = {"image_embed_dim":image_embed_dim})
torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
# exponential moving average wrapper
class EMA(nn.Module):
def __init__(
self,
model,
beta = 0.9999,
update_after_step = 1000,
update_every = 10,
):
super().__init__()
self.beta = beta
self.online_model = model
self.ema_model = copy.deepcopy(model)
self.update_every = update_every
self.update_after_step = update_after_step // update_every # only start EMA after this step number, starting at 0
self.register_buffer('initted', torch.Tensor([False]))
self.register_buffer('step', torch.tensor([0.]))
def restore_ema_model_device(self):
device = self.initted.device
self.ema_model.to(device)
def copy_params_from_model_to_ema(self):
self.ema_model.state_dict(self.online_model.state_dict())
def update(self):
self.step += 1
if (self.step % self.update_every) != 0:
return
if self.step <= self.update_after_step:
self.copy_params_from_model_to_ema()
return
if not self.initted:
self.copy_params_from_model_to_ema()
self.initted.data.copy_(torch.Tensor([True]))
self.update_moving_average(self.ema_model, self.online_model)
def update_moving_average(self, ma_model, current_model):
def calculate_ema(beta, old, new):
if not exists(old):
return new
return old * beta + (1 - beta) * new
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
old_weight, up_weight = ma_params.data, current_params.data
ma_params.data = calculate_ema(self.beta, old_weight, up_weight)
for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer)
ma_buffer.copy_(new_buffer_value)
def __call__(self, *args, **kwargs):
return self.ema_model(*args, **kwargs)
# diffusion prior trainer
def prior_sample_in_chunks(fn):
@@ -159,190 +255,44 @@ class DiffusionPriorTrainer(nn.Module):
eps = 1e-6,
max_grad_norm = None,
amp = False,
group_wd_params = True,
device = None,
accelerator = None,
**kwargs
):
super().__init__()
assert isinstance(diffusion_prior, DiffusionPrior)
assert not exists(accelerator) or isinstance(accelerator, Accelerator)
assert exists(accelerator) or exists(device), "You must supply some method of obtaining a device."
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
# assign some helpful member vars
self.accelerator = accelerator
self.device = accelerator.device if exists(accelerator) else device
self.text_conditioned = diffusion_prior.condition_on_text_encodings
# save model
self.diffusion_prior = diffusion_prior
# exponential moving average
self.use_ema = use_ema
if self.use_ema:
self.ema_diffusion_prior = EMA(diffusion_prior, **ema_kwargs)
# optimizer and mixed precision stuff
self.amp = amp
self.scaler = GradScaler(enabled = amp)
self.optim_kwargs = dict(lr=lr, wd=wd, eps=eps, group_wd_params=group_wd_params)
self.optimizer = get_optimizer(
self.diffusion_prior.parameters(),
**self.optim_kwargs,
diffusion_prior.parameters(),
lr = lr,
wd = wd,
eps = eps,
**kwargs
)
# distribute the model if using HFA
if exists(self.accelerator):
self.diffusion_prior, self.optimizer = self.accelerator.prepare(self.diffusion_prior, self.optimizer)
# exponential moving average stuff
self.use_ema = use_ema
if self.use_ema:
self.ema_diffusion_prior = EMA(self.unwrap_model(self.diffusion_prior), **ema_kwargs)
# gradient clipping if needed
self.max_grad_norm = max_grad_norm
# track steps internally
self.register_buffer('step', torch.tensor([0]))
# accelerator wrappers
def print(self, msg):
if exists(self.accelerator):
self.accelerator.print(msg)
else:
print(msg)
def unwrap_model(self, model):
if exists(self.accelerator):
return self.accelerator.unwrap_model(model)
else:
return model
def wait_for_everyone(self):
if exists(self.accelerator):
self.accelerator.wait_for_everyone()
def is_main_process(self):
if exists(self.accelerator):
return self.accelerator.is_main_process
else:
return True
def clip_grad_norm_(self, *args):
if exists(self.accelerator):
return self.accelerator.clip_grad_norm_(*args)
else:
return torch.nn.utils.clip_grad_norm_(*args)
def backprop(self, x):
if exists(self.accelerator):
self.accelerator.backward(x)
else:
try:
x.backward()
except Exception as e:
self.print(f"Caught error in backprop call: {e}")
# utility
def save(self, path, overwrite = True, **kwargs):
# ensure we sync gradients before continuing
self.wait_for_everyone()
# only save on the main process
if self.is_main_process():
self.print(f"Saving checkpoint at step: {self.step.item()}")
path = Path(path)
assert not (path.exists() and not overwrite)
path.parent.mkdir(parents = True, exist_ok = True)
save_obj = dict(
scaler = self.scaler.state_dict(),
optimizer = self.optimizer.state_dict(),
model = self.unwrap_model(self.diffusion_prior).state_dict(), # unwrap the model from distribution if applicable
version = version.parse(__version__),
step = self.step.item(),
**kwargs
)
if self.use_ema:
save_obj = {
**save_obj,
'ema': self.ema_diffusion_prior.state_dict(),
'ema_model': self.ema_diffusion_prior.ema_model.state_dict() # save the ema model specifically for easy ema-only reload
}
torch.save(save_obj, str(path))
def load(self, path, overwrite_lr = True, strict = True):
"""
Load a checkpoint of a diffusion prior trainer.
Will load the entire trainer, including the optimizer and EMA.
Params:
- path (str): a path to the DiffusionPriorTrainer checkpoint file
- overwrite_lr (bool): wether or not to overwrite the stored LR with the LR specified in the new trainer
- strict (bool): kwarg for `torch.nn.Module.load_state_dict`, will force an exact checkpoint match
Returns:
loaded_obj (dict): The loaded checkpoint dictionary
"""
# all processes need to load checkpoint. no restriction here
path = Path(path)
assert path.exists()
loaded_obj = torch.load(str(path), map_location=self.device)
if version.parse(__version__) != loaded_obj['version']:
print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {__version__}')
# unwrap the model when loading from checkpoint
self.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
self.scaler.load_state_dict(loaded_obj['scaler'])
self.optimizer.load_state_dict(loaded_obj['optimizer'])
if overwrite_lr:
new_lr = self.optim_kwargs["lr"]
self.print(f"Overriding LR to be {new_lr}")
for group in self.optimizer.param_groups:
group["lr"] = new_lr
if self.use_ema:
assert 'ema' in loaded_obj
self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
# below not be necessary, but I had a suspicion that this wasn't being loaded correctly
self.ema_diffusion_prior.ema_model.load_state_dict(loaded_obj["ema_model"])
# sync and inform
self.wait_for_everyone()
self.print(f"Loaded model")
return loaded_obj
# model functionality
self.register_buffer('step', torch.tensor([0.]))
def update(self):
# only continue with updates until all ranks finish
self.wait_for_everyone()
if exists(self.max_grad_norm):
self.scaler.unscale_(self.optimizer)
# utilize HFA clipping where applicable
self.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
nn.utils.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
self.scaler.step(self.optimizer)
self.scaler.update()
@@ -357,26 +307,17 @@ class DiffusionPriorTrainer(nn.Module):
@cast_torch_tensor
@prior_sample_in_chunks
def p_sample_loop(self, *args, **kwargs):
model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
return model.p_sample_loop(*args, **kwargs)
return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
@torch.no_grad()
@cast_torch_tensor
@prior_sample_in_chunks
def sample(self, *args, **kwargs):
model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
return model.sample(*args, **kwargs)
return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
@torch.no_grad()
def sample_batch_size(self, *args, **kwargs):
model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
return model.sample_batch_size(*args, **kwargs)
@torch.no_grad()
@cast_torch_tensor
@prior_sample_in_chunks
def embed_text(self, *args, **kwargs):
return self.unwrap_model(self.diffusion_prior).clip.embed_text(*args, **kwargs)
return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
@cast_torch_tensor
def forward(
@@ -394,10 +335,8 @@ class DiffusionPriorTrainer(nn.Module):
total_loss += loss.item()
# backprop with accelerate if applicable
if self.training:
self.backprop(self.scaler.scale(loss))
self.scaler.scale(loss).backward()
return total_loss
@@ -423,23 +362,20 @@ class DecoderTrainer(nn.Module):
def __init__(
self,
decoder,
accelerator = None,
use_ema = True,
lr = 1e-4,
wd = 1e-2,
eps = 1e-8,
max_grad_norm = 0.5,
amp = False,
group_wd_params = True,
**kwargs
):
super().__init__()
assert isinstance(decoder, Decoder)
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
self.accelerator = default(accelerator, Accelerator)
self.num_unets = len(decoder.unets)
self.decoder = decoder
self.num_unets = len(self.decoder.unets)
self.use_ema = use_ema
self.ema_unets = nn.ModuleList([])
@@ -451,106 +387,56 @@ class DecoderTrainer(nn.Module):
lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))
assert all([unet_lr < 1e-3 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
optimizers = []
for unet, unet_lr, unet_wd, unet_eps in zip(decoder.unets, lr, wd, eps):
for ind, (unet, unet_lr, unet_wd, unet_eps) in enumerate(zip(self.decoder.unets, lr, wd, eps)):
optimizer = get_optimizer(
unet.parameters(),
lr = unet_lr,
wd = unet_wd,
eps = unet_eps,
group_wd_params = group_wd_params,
**kwargs
)
optimizers.append(optimizer)
setattr(self, f'optim{ind}', optimizer) # cannot use pytorch ModuleList for some reason with optimizers
if self.use_ema:
self.ema_unets.append(EMA(unet, **ema_kwargs))
scaler = GradScaler(enabled = amp)
setattr(self, f'scaler{ind}', scaler)
# gradient clipping if needed
self.max_grad_norm = max_grad_norm
self.register_buffer('step', torch.tensor([0.]))
decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
self.decoder = decoder
for opt_ind, optimizer in zip(range(len(optimizers)), optimizers):
setattr(self, f'optim{opt_ind}', optimizer)
def save(self, path, overwrite = True, **kwargs):
path = Path(path)
assert not (path.exists() and not overwrite)
path.parent.mkdir(parents = True, exist_ok = True)
save_obj = dict(
model = self.accelerator.unwrap_model(self.decoder).state_dict(),
version = __version__,
step = self.step.item(),
**kwargs
)
for ind in range(0, self.num_unets):
optimizer_key = f'optim{ind}'
optimizer = getattr(self, optimizer_key)
save_obj = {**save_obj, optimizer_key: self.accelerator.unwrap_model(optimizer).state_dict()}
if self.use_ema:
save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
self.accelerator.save(save_obj, str(path))
def load_state_dict(self, loaded_obj, only_model = False, strict = True):
if version.parse(__version__) != version.parse(loaded_obj['version']):
self.accelerator.print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}')
self.accelerator.unwrap_model(self.decoder).load_state_dict(loaded_obj['model'], strict = strict)
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
if only_model:
return loaded_obj
for ind in range(0, self.num_unets):
optimizer_key = f'optim{ind}'
optimizer = getattr(self, optimizer_key)
self.accelerator.unwrap_model(optimizer).load_state_dict(loaded_obj[optimizer_key])
if self.use_ema:
assert 'ema' in loaded_obj
self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
def load(self, path, only_model = False, strict = True):
path = Path(path)
assert path.exists()
loaded_obj = torch.load(str(path), map_location = 'cpu')
self.load_state_dict(loaded_obj, only_model = only_model, strict = strict)
return loaded_obj
@property
def unets(self):
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
def scale(self, loss, *, unet_number):
assert 1 <= unet_number <= self.num_unets
index = unet_number - 1
scaler = getattr(self, f'scaler{index}')
return scaler.scale(loss)
def update(self, unet_number = None):
if self.num_unets == 1:
unet_number = default(unet_number, 1)
assert exists(unet_number) and 1 <= unet_number <= self.num_unets
index = unet_number - 1
unet = self.decoder.unets[index]
optimizer = getattr(self, f'optim{index}')
scaler = getattr(self, f'scaler{index}')
if exists(self.max_grad_norm):
self.accelerator.clip_grad_norm_(self.decoder.parameters(), self.max_grad_norm) # Automatically unscales gradients
optimizer.step()
scaler.unscale_(optimizer)
nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
if self.use_ema:
@@ -563,17 +449,15 @@ class DecoderTrainer(nn.Module):
@cast_torch_tensor
@decoder_sample_in_chunks
def sample(self, *args, **kwargs):
distributed = self.accelerator.num_processes > 1
base_decoder = self.accelerator.unwrap_model(self.decoder)
if kwargs.pop('use_non_ema', False) or not self.use_ema:
return base_decoder.sample(*args, **kwargs, distributed = distributed)
return self.decoder.sample(*args, **kwargs)
trainable_unets = self.accelerator.unwrap_model(self.decoder).unets
base_decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
trainable_unets = self.decoder.unets
self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
output = base_decoder.sample(*args, **kwargs, distributed = distributed)
output = self.decoder.sample(*args, **kwargs)
base_decoder.unets = trainable_unets # restore original training unets
self.decoder.unets = trainable_unets # restore original training unets
# cast the ema_model unets back to original device
for ema in self.ema_unets:
@@ -581,18 +465,6 @@ class DecoderTrainer(nn.Module):
return output
@torch.no_grad()
@cast_torch_tensor
@prior_sample_in_chunks
def embed_text(self, *args, **kwargs):
return self.accelerator.unwrap_model(self.decoder).clip.embed_text(*args, **kwargs)
@torch.no_grad()
@cast_torch_tensor
@prior_sample_in_chunks
def embed_image(self, *args, **kwargs):
return self.accelerator.unwrap_model(self.decoder).clip.embed_image(*args, **kwargs)
@cast_torch_tensor
def forward(
self,
@@ -607,14 +479,13 @@ class DecoderTrainer(nn.Module):
total_loss = 0.
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
# with autocast(enabled = self.amp):
with self.accelerator.autocast():
with autocast(enabled = self.amp):
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
loss = loss * chunk_size_frac
total_loss += loss.item()
if self.training:
self.accelerator.backward(loss)
self.scale(loss, unet_number = unet_number).backward()
return total_loss

View File

@@ -1,35 +0,0 @@
import time
import importlib
# helper functions
def exists(val):
return val is not None
# time helpers
class Timer:
def __init__(self):
self.reset()
def reset(self):
self.last_time = time.time()
def elapsed(self):
return time.time() - self.last_time
# print helpers
def print_ribbon(s, symbol = '=', repeat = 40):
flank = symbol * repeat
return f'{flank} {s} {flank}'
# import helpers
def import_or_print_error(pkg_name, err_str = None):
try:
return importlib.import_module(pkg_name)
except ModuleNotFoundError as e:
if exists(err_str):
print(err_str)
exit()

View File

@@ -1 +0,0 @@
__version__ = '0.15.4'

View File

@@ -68,8 +68,8 @@ def group_dict_by_key(cond, d):
return_val[ind][key] = d[key]
return (*return_val,)
def string_begins_with(prefix, string_input):
return string_input.startswith(prefix)
def string_begins_with(prefix, str):
return str.startswith(prefix)
def group_by_key_prefix(prefix, d):
return group_dict_by_key(partial(string_begins_with, prefix), d)

View File

@@ -16,11 +16,10 @@ from torchvision.utils import make_grid, save_image
from einops import rearrange
from dalle2_pytorch.train import EMA
from dalle2_pytorch.vqgan_vae import VQGanVAE
from dalle2_pytorch.optimizer import get_optimizer
from ema_pytorch import EMA
# helpers
def exists(val):
@@ -98,7 +97,7 @@ class VQGanVAETrainer(nn.Module):
valid_frac = 0.05,
random_split_seed = 42,
ema_beta = 0.995,
ema_update_after_step = 500,
ema_update_after_step = 2000,
ema_update_every = 10,
apply_grad_penalty_every = 4,
amp = False

View File

@@ -1,5 +1,4 @@
from setuptools import setup, find_packages
exec(open('dalle2_pytorch/version.py').read())
setup(
name = 'dalle2-pytorch',
@@ -11,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = __version__,
version = '0.3.4',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
@@ -24,19 +23,15 @@ setup(
'text to image'
],
install_requires=[
'accelerate',
'click',
'clip-anytorch',
'coca-pytorch>=0.0.5',
'ema-pytorch>=0.0.7',
'einops>=0.4',
'einops-exts>=0.0.3',
'embedding-reader',
'kornia>=0.5.4',
'numpy',
'packaging',
'pillow',
'pydantic',
'resize-right>=0.0.2',
'rotary-embedding-torch',
'torch>=1.10',
@@ -44,9 +39,9 @@ setup(
'tqdm',
'vector-quantize-pytorch',
'x-clip>=0.4.4',
'youtokentome',
'webdataset>=0.2.5',
'fsspec>=2022.1.0',
'torchmetrics[image]>=0.8.0'
'fsspec>=2022.1.0'
],
classifiers=[
'Development Status :: 4 - Beta',

View File

@@ -1,596 +0,0 @@
from pathlib import Path
from typing import List
from dalle2_pytorch.trainer import DecoderTrainer
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
from dalle2_pytorch.trackers import Tracker
from dalle2_pytorch.train_configs import DecoderConfig, TrainDecoderConfig
from dalle2_pytorch.utils import Timer, print_ribbon
from dalle2_pytorch.dalle2_pytorch import Decoder, resize_image_to
from clip import tokenize
import torchvision
import torch
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.kid import KernelInceptionDistance
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.utils import dataclasses as accelerate_dataclasses
import webdataset as wds
import click
# constants
TRAIN_CALC_LOSS_EVERY_ITERS = 10
VALID_CALC_LOSS_EVERY_ITERS = 10
# helpers functions
def exists(val):
return val is not None
# main functions
def create_dataloaders(
available_shards,
webdataset_base_url,
img_embeddings_url=None,
text_embeddings_url=None,
shard_width=6,
num_workers=4,
batch_size=32,
n_sample_images=6,
shuffle_train=True,
resample_train=False,
img_preproc = None,
index_width=4,
train_prop = 0.75,
val_prop = 0.15,
test_prop = 0.10,
seed = 0,
**kwargs
):
"""
Randomly splits the available shards into train, val, and test sets and returns a dataloader for each
"""
assert train_prop + test_prop + val_prop == 1
num_train = round(train_prop*len(available_shards))
num_test = round(test_prop*len(available_shards))
num_val = len(available_shards) - num_train - num_test
assert num_train + num_test + num_val == len(available_shards), f"{num_train} + {num_test} + {num_val} = {num_train + num_test + num_val} != {len(available_shards)}"
train_split, test_split, val_split = torch.utils.data.random_split(available_shards, [num_train, num_test, num_val], generator=torch.Generator().manual_seed(seed))
# The shard number in the webdataset file names has a fixed width. We zero pad the shard numbers so they correspond to a filename.
train_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in train_split]
test_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in test_split]
val_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in val_split]
create_dataloader = lambda tar_urls, shuffle=False, resample=False, for_sampling=False: create_image_embedding_dataloader(
tar_url=tar_urls,
num_workers=num_workers,
batch_size=batch_size if not for_sampling else n_sample_images,
img_embeddings_url=img_embeddings_url,
text_embeddings_url=text_embeddings_url,
index_width=index_width,
shuffle_num = None,
extra_keys= ["txt"],
shuffle_shards = shuffle,
resample_shards = resample,
img_preproc=img_preproc,
handler=wds.handlers.warn_and_continue
)
train_dataloader = create_dataloader(train_urls, shuffle=shuffle_train, resample=resample_train)
train_sampling_dataloader = create_dataloader(train_urls, shuffle=False, for_sampling=True)
val_dataloader = create_dataloader(val_urls, shuffle=False)
test_dataloader = create_dataloader(test_urls, shuffle=False)
test_sampling_dataloader = create_dataloader(test_urls, shuffle=False, for_sampling=True)
return {
"train": train_dataloader,
"train_sampling": train_sampling_dataloader,
"val": val_dataloader,
"test": test_dataloader,
"test_sampling": test_sampling_dataloader
}
def get_dataset_keys(dataloader):
"""
It is sometimes neccesary to get the keys the dataloader is returning. Since the dataset is burried in the dataloader, we need to do a process to recover it.
"""
# If the dataloader is actually a WebLoader, we need to extract the real dataloader
if isinstance(dataloader, wds.WebLoader):
dataloader = dataloader.pipeline[0]
return dataloader.dataset.key_map
def get_example_data(dataloader, device, n=5):
"""
Samples the dataloader and returns a zipped list of examples
"""
images = []
img_embeddings = []
text_embeddings = []
captions = []
for img, emb, txt in dataloader:
img_emb, text_emb = emb.get('img'), emb.get('text')
if img_emb is not None:
img_emb = img_emb.to(device=device, dtype=torch.float)
img_embeddings.extend(list(img_emb))
else:
# Then we add None img.shape[0] times
img_embeddings.extend([None]*img.shape[0])
if text_emb is not None:
text_emb = text_emb.to(device=device, dtype=torch.float)
text_embeddings.extend(list(text_emb))
else:
# Then we add None img.shape[0] times
text_embeddings.extend([None]*img.shape[0])
img = img.to(device=device, dtype=torch.float)
images.extend(list(img))
captions.extend(list(txt))
if len(images) >= n:
break
return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))
def generate_samples(trainer, example_data, condition_on_text_encodings=False, text_prepend=""):
"""
Takes example data and generates images from the embeddings
Returns three lists: real images, generated images, and captions
"""
real_images, img_embeddings, text_embeddings, txts = zip(*example_data)
sample_params = {}
if img_embeddings[0] is None:
# Generate image embeddings from clip
imgs_tensor = torch.stack(real_images)
img_embeddings, *_ = trainer.embed_image(imgs_tensor)
sample_params["image_embed"] = img_embeddings
else:
# Then we are using precomputed image embeddings
img_embeddings = torch.stack(img_embeddings)
sample_params["image_embed"] = img_embeddings
if condition_on_text_encodings:
if text_embeddings[0] is None:
# Generate text embeddings from text
tokenized_texts = tokenize(txts, truncate=True)
sample_params["text"] = tokenized_texts
else:
# Then we are using precomputed text embeddings
text_embeddings = torch.stack(text_embeddings)
sample_params["text_encodings"] = text_embeddings
samples = trainer.sample(**sample_params)
generated_images = list(samples)
captions = [text_prepend + txt for txt in txts]
return real_images, generated_images, captions
def generate_grid_samples(trainer, examples, condition_on_text_encodings=False, text_prepend=""):
"""
Generates samples and uses torchvision to put them in a side by side grid for easy viewing
"""
real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings, text_prepend)
real_image_size = real_images[0].shape[-1]
generated_image_size = generated_images[0].shape[-1]
# training images may be larger than the generated one
if real_image_size > generated_image_size:
real_images = [resize_image_to(image, generated_image_size) for image in real_images]
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
return grid_images, captions
def evaluate_trainer(trainer, dataloader, device, condition_on_text_encodings=False, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
"""
Computes evaluation metrics for the decoder
"""
metrics = {}
# Prepare the data
examples = get_example_data(dataloader, device, n_evaluation_samples)
if len(examples) == 0:
print("No data to evaluate. Check that your dataloader has shards.")
return metrics
real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings)
real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
# Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
int_real_images = real_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
int_generated_images = generated_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
def null_sync(t, *args, **kwargs):
return [t]
if exists(FID):
fid = FrechetInceptionDistance(**FID, dist_sync_fn=null_sync)
fid.to(device=device)
fid.update(int_real_images, real=True)
fid.update(int_generated_images, real=False)
metrics["FID"] = fid.compute().item()
if exists(IS):
inception = InceptionScore(**IS, dist_sync_fn=null_sync)
inception.to(device=device)
inception.update(int_real_images)
is_mean, is_std = inception.compute()
metrics["IS_mean"] = is_mean.item()
metrics["IS_std"] = is_std.item()
if exists(KID):
kernel_inception = KernelInceptionDistance(**KID, dist_sync_fn=null_sync)
kernel_inception.to(device=device)
kernel_inception.update(int_real_images, real=True)
kernel_inception.update(int_generated_images, real=False)
kid_mean, kid_std = kernel_inception.compute()
metrics["KID_mean"] = kid_mean.item()
metrics["KID_std"] = kid_std.item()
if exists(LPIPS):
# Convert from [0, 1] to [-1, 1]
renorm_real_images = real_images.mul(2).sub(1)
renorm_generated_images = generated_images.mul(2).sub(1)
lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS, dist_sync_fn=null_sync)
lpips.to(device=device)
lpips.update(renorm_real_images, renorm_generated_images)
metrics["LPIPS"] = lpips.compute().item()
if trainer.accelerator.num_processes > 1:
# Then we should sync the metrics
metrics_order = sorted(metrics.keys())
metrics_tensor = torch.zeros(1, len(metrics), device=device, dtype=torch.float)
for i, metric_name in enumerate(metrics_order):
metrics_tensor[0, i] = metrics[metric_name]
metrics_tensor = trainer.accelerator.gather(metrics_tensor)
metrics_tensor = metrics_tensor.mean(dim=0)
for i, metric_name in enumerate(metrics_order):
metrics[metric_name] = metrics_tensor[i].item()
return metrics
def save_trainer(tracker: Tracker, trainer: DecoderTrainer, epoch: int, sample: int, next_task: str, validation_losses: List[float], samples_seen: int, is_latest=True, is_best=False):
"""
Logs the model with an appropriate method depending on the tracker
"""
tracker.save(trainer, is_best=is_best, is_latest=is_latest, epoch=epoch, sample=sample, next_task=next_task, validation_losses=validation_losses, samples_seen=samples_seen)
def recall_trainer(tracker: Tracker, trainer: DecoderTrainer):
"""
Loads the model with an appropriate method depending on the tracker
"""
trainer.accelerator.print(print_ribbon(f"Loading model from {type(tracker.loader).__name__}"))
state_dict = tracker.recall()
trainer.load_state_dict(state_dict, only_model=False, strict=True)
return state_dict.get("epoch", 0), state_dict.get("validation_losses", []), state_dict.get("next_task", "train"), state_dict.get("sample", 0), state_dict.get("samples_seen", 0)
def train(
dataloaders,
decoder: Decoder,
accelerator: Accelerator,
tracker: Tracker,
inference_device,
evaluate_config=None,
epoch_samples = None, # If the training dataset is resampling, we have to manually stop an epoch
validation_samples = None,
epochs = 20,
n_sample_images = 5,
save_every_n_samples = 100000,
unet_training_mask=None,
condition_on_text_encodings=False,
**kwargs
):
"""
Trains a decoder on a dataset.
"""
is_master = accelerator.process_index == 0
trainer = DecoderTrainer(
decoder=decoder,
accelerator=accelerator,
**kwargs
)
# Set up starting model and parameters based on a recalled state dict
start_epoch = 0
validation_losses = []
next_task = 'train'
sample = 0
samples_seen = 0
val_sample = 0
step = lambda: int(trainer.step.item())
if tracker.loader is not None:
start_epoch, validation_losses, next_task, recalled_sample, samples_seen = recall_trainer(tracker, trainer)
if next_task == 'train':
sample = recalled_sample
if next_task == 'val':
val_sample = recalled_sample
accelerator.print(f"Loaded model from {type(tracker.loader).__name__} on epoch {start_epoch} having seen {samples_seen} samples with minimum validation loss {min(validation_losses) if len(validation_losses) > 0 else 'N/A'}")
accelerator.print(f"Starting training from task {next_task} at sample {sample} and validation sample {val_sample}")
trainer.to(device=inference_device)
if not exists(unet_training_mask):
# Then the unet mask should be true for all unets in the decoder
unet_training_mask = [True] * trainer.num_unets
assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"
accelerator.print(print_ribbon("Generating Example Data", repeat=40))
accelerator.print("This can take a while to load the shard lists...")
if is_master:
train_example_data = get_example_data(dataloaders["train_sampling"], inference_device, n_sample_images)
accelerator.print("Generated training examples")
test_example_data = get_example_data(dataloaders["test_sampling"], inference_device, n_sample_images)
accelerator.print("Generated testing examples")
send_to_device = lambda arr: [x.to(device=inference_device, dtype=torch.float) for x in arr]
sample_length_tensor = torch.zeros(1, dtype=torch.int, device=inference_device)
unet_losses_tensor = torch.zeros(TRAIN_CALC_LOSS_EVERY_ITERS, trainer.num_unets, dtype=torch.float, device=inference_device)
for epoch in range(start_epoch, epochs):
accelerator.print(print_ribbon(f"Starting epoch {epoch}", repeat=40))
timer = Timer()
last_sample = sample
last_snapshot = sample
if next_task == 'train':
for i, (img, emb, txt) in enumerate(dataloaders["train"]):
# We want to count the total number of samples across all processes
sample_length_tensor[0] = len(img)
all_samples = accelerator.gather(sample_length_tensor) # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
total_samples = all_samples.sum().item()
sample += total_samples
samples_seen += total_samples
img_emb = emb.get('img')
has_img_embedding = img_emb is not None
if has_img_embedding:
img_emb, = send_to_device((img_emb,))
text_emb = emb.get('text')
has_text_embedding = text_emb is not None
if has_text_embedding:
text_emb, = send_to_device((text_emb,))
img, = send_to_device((img,))
trainer.train()
for unet in range(1, trainer.num_unets+1):
# Check if this is a unet we are training
if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
continue
forward_params = {}
if has_img_embedding:
forward_params['image_embed'] = img_emb
else:
# Forward pass automatically generates embedding
pass
if condition_on_text_encodings:
if has_text_embedding:
forward_params['text_encodings'] = text_emb
else:
# Then we need to pass the text instead
tokenized_texts = tokenize(txt, truncate=True)
forward_params['text'] = tokenized_texts
loss = trainer.forward(img, **forward_params, unet_number=unet)
trainer.update(unet_number=unet)
unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet-1] = loss
samples_per_sec = (sample - last_sample) / timer.elapsed()
timer.reset()
last_sample = sample
if i % TRAIN_CALC_LOSS_EVERY_ITERS == 0:
# We want to average losses across all processes
unet_all_losses = accelerator.gather(unet_losses_tensor)
mask = unet_all_losses != 0
unet_average_loss = (unet_all_losses * mask).sum(dim=0) / mask.sum(dim=0)
loss_map = { f"Unet {index} Training Loss": loss.item() for index, loss in enumerate(unet_average_loss) if loss != 0 }
# gather decay rate on each UNet
ema_decay_list = {f"Unet {index} EMA Decay": ema_unet.get_current_decay() for index, ema_unet in enumerate(trainer.ema_unets)}
log_data = {
"Epoch": epoch,
"Sample": sample,
"Step": i,
"Samples per second": samples_per_sec,
"Samples Seen": samples_seen,
**ema_decay_list,
**loss_map
}
if is_master:
tracker.log(log_data, step=step())
if is_master and last_snapshot + save_every_n_samples < sample: # This will miss by some amount every time, but it's not a big deal... I hope
# It is difficult to gather this kind of info on the accelerator, so we have to do it on the master
print("Saving snapshot")
last_snapshot = sample
# We need to know where the model should be saved
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen)
if exists(n_sample_images) and n_sample_images > 0:
trainer.eval()
train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
if epoch_samples is not None and sample >= epoch_samples:
break
next_task = 'val'
sample = 0
all_average_val_losses = None
if next_task == 'val':
trainer.eval()
accelerator.print(print_ribbon(f"Starting Validation {epoch}", repeat=40))
last_val_sample = val_sample
val_sample_length_tensor = torch.zeros(1, dtype=torch.int, device=inference_device)
average_val_loss_tensor = torch.zeros(1, trainer.num_unets, dtype=torch.float, device=inference_device)
timer = Timer()
accelerator.wait_for_everyone()
i = 0
for i, (img, emb, txt) in enumerate(dataloaders["val"]):
val_sample_length_tensor[0] = len(img)
all_samples = accelerator.gather(val_sample_length_tensor)
total_samples = all_samples.sum().item()
val_sample += total_samples
img_emb = emb.get('img')
has_img_embedding = img_emb is not None
if has_img_embedding:
img_emb, = send_to_device((img_emb,))
text_emb = emb.get('text')
has_text_embedding = text_emb is not None
if has_text_embedding:
text_emb, = send_to_device((text_emb,))
img, = send_to_device((img,))
for unet in range(1, len(decoder.unets)+1):
if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
# No need to evaluate an unchanging unet
continue
forward_params = {}
if has_img_embedding:
forward_params['image_embed'] = img_emb.float()
else:
# Forward pass automatically generates embedding
pass
if condition_on_text_encodings:
if has_text_embedding:
forward_params['text_encodings'] = text_emb.float()
else:
# Then we need to pass the text instead
tokenized_texts = tokenize(txt, truncate=True)
forward_params['text'] = tokenized_texts
loss = trainer.forward(img.float(), **forward_params, unet_number=unet)
average_val_loss_tensor[0, unet-1] += loss
if i % VALID_CALC_LOSS_EVERY_ITERS == 0:
samples_per_sec = (val_sample - last_val_sample) / timer.elapsed()
timer.reset()
last_val_sample = val_sample
accelerator.print(f"Epoch {epoch}/{epochs} Val Step {i} - Sample {val_sample} - {samples_per_sec:.2f} samples/sec")
accelerator.print(f"Loss: {(average_val_loss_tensor / (i+1))}")
accelerator.print("")
if validation_samples is not None and val_sample >= validation_samples:
break
print(f"Rank {accelerator.state.process_index} finished validation after {i} steps")
accelerator.wait_for_everyone()
average_val_loss_tensor /= i+1
# Gather all the average loss tensors
all_average_val_losses = accelerator.gather(average_val_loss_tensor)
if is_master:
unet_average_val_loss = all_average_val_losses.mean(dim=0)
val_loss_map = { f"Unet {index} Validation Loss": loss.item() for index, loss in enumerate(unet_average_val_loss) if loss != 0 }
tracker.log(val_loss_map, step=step())
next_task = 'eval'
if next_task == 'eval':
if exists(evaluate_config):
accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings)
if is_master:
tracker.log(evaluation, step=step())
next_task = 'sample'
val_sample = 0
if next_task == 'sample':
if is_master:
# Generate examples and save the model if we are the master
# Generate sample images
print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
test_images, test_captions = generate_grid_samples(trainer, test_example_data, condition_on_text_encodings, "Test: ")
train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step())
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
print(print_ribbon(f"Starting Saving {epoch}", repeat=40))
is_best = False
if all_average_val_losses is not None:
average_loss = all_average_val_losses.mean(dim=0).item()
if len(validation_losses) == 0 or average_loss < min(validation_losses):
is_best = True
validation_losses.append(average_loss)
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen, is_best=is_best)
next_task = 'train'
def create_tracker(accelerator: Accelerator, config: TrainDecoderConfig, config_path: str, dummy: bool = False) -> Tracker:
tracker_config = config.tracker
accelerator_config = {
"Distributed": accelerator.distributed_type != accelerate_dataclasses.DistributedType.NO,
"DistributedType": accelerator.distributed_type,
"NumProcesses": accelerator.num_processes,
"MixedPrecision": accelerator.mixed_precision
}
tracker: Tracker = tracker_config.create(config, accelerator_config, dummy_mode=dummy)
tracker.save_config(config_path, config_name='decoder_config.json')
return tracker
def initialize_training(config: TrainDecoderConfig, config_path):
# Make sure if we are not loading, distributed models are initialized to the same values
torch.manual_seed(config.seed)
# Set up accelerator for configurable distributed training
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
# Set up data
all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
world_size = accelerator.num_processes
rank = accelerator.process_index
shards_per_process = len(all_shards) // world_size
assert shards_per_process > 0, "Not enough shards to split evenly"
my_shards = all_shards[rank * shards_per_process: (rank + 1) * shards_per_process]
dataloaders = create_dataloaders (
available_shards=my_shards,
img_preproc = config.data.img_preproc,
train_prop = config.data.splits.train,
val_prop = config.data.splits.val,
test_prop = config.data.splits.test,
n_sample_images=config.train.n_sample_images,
**config.data.dict(),
rank = rank,
seed = config.seed,
)
# Create the decoder model and print basic info
decoder = config.decoder.create()
num_parameters = sum(p.numel() for p in decoder.parameters())
# Create and initialize the tracker if we are the master
tracker = create_tracker(accelerator, config, config_path, dummy = rank!=0)
has_img_embeddings = config.data.img_embeddings_url is not None
has_text_embeddings = config.data.text_embeddings_url is not None
conditioning_on_text = any([unet.cond_on_text_encodings for unet in config.decoder.unets])
has_clip_model = config.decoder.clip is not None
data_source_string = ""
if has_img_embeddings:
data_source_string += "precomputed image embeddings"
elif has_clip_model:
data_source_string += "clip image embeddings generation"
else:
raise ValueError("No image embeddings source specified")
if conditioning_on_text:
if has_text_embeddings:
data_source_string += " and precomputed text embeddings"
elif has_clip_model:
data_source_string += " and clip text encoding generation"
else:
raise ValueError("No text embeddings source specified")
accelerator.print(print_ribbon("Loaded Config", repeat=40))
accelerator.print(f"Running training with {accelerator.num_processes} processes and {accelerator.distributed_type} distributed training")
accelerator.print(f"Training using {data_source_string}. {'conditioned on text' if conditioning_on_text else 'not conditioned on text'}")
accelerator.print(f"Number of parameters: {num_parameters}")
train(dataloaders, decoder, accelerator,
tracker=tracker,
inference_device=accelerator.device,
evaluate_config=config.evaluate,
condition_on_text_encodings=conditioning_on_text,
**config.train.dict(),
)
# Create a simple click command line interface to load the config and start the training
@click.command()
@click.option("--config_file", default="./train_decoder_config.json", help="Path to config file")
def main(config_file):
config_file_path = Path(config_file)
config = TrainDecoderConfig.from_json_path(str(config_file_path))
initialize_training(config, config_path=config_file_path)
if __name__ == "__main__":
main()

View File

@@ -1,135 +1,85 @@
# TODO: add start, num_data_points, eval_every and group to config
# TODO: switch back to repo's wandb
START = 0
NUM_DATA_POINTS = 250e6
EVAL_EVERY = 1000
GROUP = "distributed"
import os
from pathlib import Path
import click
import wandb
import torch
from torch import nn
from torch.utils.data import DataLoader
import math
import time
import numpy as np
from accelerate import Accelerator
import torch
import clip
from torch import nn
from dalle2_pytorch.dataloaders import get_reader, make_splits
from dalle2_pytorch.utils import Timer
from dalle2_pytorch.train_configs import (
DiffusionPriorTrainConfig,
TrainDiffusionPriorConfig,
)
from dalle2_pytorch.trackers import BaseTracker, WandbTracker
from dalle2_pytorch import DiffusionPriorTrainer
from dalle2_pytorch.dataloaders import make_splits
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
# helpers
from embedding_reader import EmbeddingReader
from tqdm import tqdm
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
# constants
REPORT_METRICS_EVERY = 250 # for cosine similarity and other metric reporting during training
tracker = WandbTracker()
# helpers functions
def exists(val):
return val is not None
val is not None
class Timer:
def __init__(self):
self.reset()
def make_model(
prior_config, train_config, device: str = None, accelerator: Accelerator = None
):
# create model from config
diffusion_prior = prior_config.create()
def reset(self):
self.last_time = time.time()
# instantiate the trainer
trainer = DiffusionPriorTrainer(
diffusion_prior=diffusion_prior,
lr=train_config.lr,
wd=train_config.wd,
max_grad_norm=train_config.max_grad_norm,
amp=train_config.amp,
use_ema=train_config.use_ema,
device=device,
accelerator=accelerator,
)
def elapsed(self):
return time.time() - self.last_time
return trainer
# functions
# eval functions
def eval_model(
trainer: DiffusionPriorTrainer,
dataloader: DataLoader,
text_conditioned: bool,
loss_type: str,
tracker_context: str,
tracker: BaseTracker = None,
use_ema: bool = True,
):
trainer.eval()
if trainer.is_main_process():
click.secho(f"Measuring performance on {tracker_context}", fg="green", blink=True)
def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation"):
model.eval()
with torch.no_grad():
total_loss = 0.0
total_samples = 0.0
total_loss = 0.
total_samples = 0.
for image_embeddings, text_data in dataloader:
image_embeddings = image_embeddings.to(trainer.device)
text_data = text_data.to(trainer.device)
for image_embeddings, text_data in tqdm(dataloader):
batches = image_embeddings.shape[0]
input_args = dict(image_embed=image_embeddings)
if text_conditioned:
input_args = dict(**input_args, text=text_data)
input_args = dict(**input_args, text = text_data)
else:
input_args = dict(**input_args, text_embed=text_data)
if use_ema:
loss = trainer.ema_diffusion_prior(**input_args)
else:
loss = trainer(**input_args)
loss = model(**input_args)
total_loss += loss * batches
total_samples += batches
avg_loss = total_loss / total_samples
avg_loss = (total_loss / total_samples)
stats = {f"{tracker_context}-{loss_type}": avg_loss}
trainer.print(stats)
tracker.log({f'{phase} {loss_type}': avg_loss})
if exists(tracker):
tracker.log(stats, step=trainer.step.item() + 1)
def report_cosine_sims(diffusion_prior, dataloader, text_conditioned):
diffusion_prior.eval()
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
def report_cosine_sims(
trainer: DiffusionPriorTrainer,
dataloader: DataLoader,
text_conditioned: bool,
tracker: BaseTracker,
tracker_context: str = "validation",
):
trainer.eval()
if trainer.is_main_process():
click.secho("Measuring Cosine-Similarity", fg="green", blink=True)
for test_image_embeddings, text_data in dataloader:
test_image_embeddings = test_image_embeddings.to(trainer.device)
text_data = text_data.to(trainer.device)
for test_image_embeddings, text_data in tqdm(dataloader):
# we are text conditioned, we produce an embedding from the tokenized text
if text_conditioned:
text_embedding, text_encodings, text_mask = trainer.embed_text(text_data)
text_cond = dict(
text_embed=text_embedding, text_encodings=text_encodings, mask=text_mask
)
text_embedding, text_encodings, text_mask = diffusion_prior.clip.embed_text(
text_data)
text_cond = dict(text_embed=text_embedding,
text_encodings=text_encodings, mask=text_mask)
else:
text_embedding = text_data
text_cond = dict(text_embed=text_embedding)
@@ -140,9 +90,8 @@ def report_cosine_sims(
# roll the text to simulate "unrelated" captions
rolled_idx = torch.roll(torch.arange(text_embedding.shape[0]), 1)
text_embed_shuffled = text_embed_shuffled[rolled_idx]
text_embed_shuffled = text_embed_shuffled / text_embed_shuffled.norm(
dim=1, keepdim=True
)
text_embed_shuffled = text_embed_shuffled / \
text_embed_shuffled.norm(dim=1, keepdim=True)
if text_conditioned:
text_encodings_shuffled = text_encodings[rolled_idx]
@@ -151,276 +100,276 @@ def report_cosine_sims(
text_encodings_shuffled = None
text_mask_shuffled = None
text_cond_shuffled = dict(
text_embed=text_embed_shuffled,
text_encodings=text_encodings_shuffled,
mask=text_mask_shuffled,
)
text_cond_shuffled = dict(text_embed=text_embed_shuffled,
text_encodings=text_encodings_shuffled, mask=text_mask_shuffled)
# prepare the text embedding
text_embed = text_embedding / text_embedding.norm(dim=1, keepdim=True)
# prepare image embeddings
test_image_embeddings = test_image_embeddings / test_image_embeddings.norm(
dim=1, keepdim=True
)
test_image_embeddings = test_image_embeddings / \
test_image_embeddings.norm(dim=1, keepdim=True)
# predict on the unshuffled text embeddings
predicted_image_embeddings = trainer.p_sample_loop(
test_image_embeddings.shape, text_cond
)
predicted_image_embeddings = (
predicted_image_embeddings
/ predicted_image_embeddings.norm(dim=1, keepdim=True)
)
predicted_image_embeddings = diffusion_prior.p_sample_loop(
test_image_embeddings.shape, text_cond)
predicted_image_embeddings = predicted_image_embeddings / \
predicted_image_embeddings.norm(dim=1, keepdim=True)
# predict on the shuffled embeddings
predicted_unrelated_embeddings = trainer.p_sample_loop(
test_image_embeddings.shape, text_cond_shuffled
)
predicted_unrelated_embeddings = (
predicted_unrelated_embeddings
/ predicted_unrelated_embeddings.norm(dim=1, keepdim=True)
)
predicted_unrelated_embeddings = diffusion_prior.p_sample_loop(
test_image_embeddings.shape, text_cond_shuffled)
predicted_unrelated_embeddings = predicted_unrelated_embeddings / \
predicted_unrelated_embeddings.norm(dim=1, keepdim=True)
# calculate similarities
original_similarity = cos(text_embed, test_image_embeddings).cpu().numpy()
predicted_similarity = cos(text_embed, predicted_image_embeddings).cpu().numpy()
unrelated_similarity = (
cos(text_embed, predicted_unrelated_embeddings).cpu().numpy()
)
predicted_img_similarity = (
cos(test_image_embeddings, predicted_image_embeddings).cpu().numpy()
)
stats = {
f"{tracker_context}/baseline similarity": np.mean(original_similarity),
f"{tracker_context}/similarity with text": np.mean(predicted_similarity),
f"{tracker_context}/similarity with original image": np.mean(
predicted_img_similarity
),
f"{tracker_context}/similarity with unrelated caption": np.mean(unrelated_similarity),
f"{tracker_context}/difference from baseline similarity": np.mean(
predicted_similarity - original_similarity
),
}
for k, v in stats.items():
trainer.print(f"{tracker_context}/{k}: {v}")
if exists(tracker):
tracker.log(stats, step=trainer.step.item() + 1)
# training script
def train(
trainer: DiffusionPriorTrainer,
train_loader: DataLoader,
eval_loader: DataLoader,
test_loader: DataLoader,
config: DiffusionPriorTrainConfig,
):
# distributed tracking with wandb
if trainer.accelerator.num_processes > 1:
os.environ["WANDB_START_METHOD"] = "thread"
tracker = wandb.init(
name=f"RANK:{trainer.device}",
entity=config.tracker.wandb_entity,
project=config.tracker.wandb_project,
config=config.dict(),
group=GROUP,
)
# sync after tracker init
trainer.wait_for_everyone()
# init a timer
timer = Timer()
# do training
for img, txt in train_loader:
trainer.train()
current_step = trainer.step.item() + 1
# place data on device
img = img.to(trainer.device)
txt = txt.to(trainer.device)
# pass to model
loss = trainer(text=txt, image_embed=img)
# display & log loss (will only print from main process)
trainer.print(f"Step {current_step}: Loss {loss}")
# perform backprop & apply EMA updates
trainer.update()
# track samples/sec/rank
samples_per_sec = img.shape[0] / timer.elapsed()
# samples seen
samples_seen = (
config.data.batch_size * trainer.accelerator.num_processes * current_step
)
# ema decay
ema_decay = trainer.ema_diffusion_prior.get_current_decay()
# Log on all processes for debugging
tracker.log(
{
"tracking/samples-sec": samples_per_sec,
"tracking/samples-seen": samples_seen,
"tracking/ema-decay": ema_decay,
"metrics/training-loss": loss,
},
step=current_step,
)
# Metric Tracking & Checkpointing (outside of timer's scope)
if current_step % EVAL_EVERY == 0:
eval_model(
trainer=trainer,
dataloader=eval_loader,
text_conditioned=config.prior.condition_on_text_encodings,
loss_type=config.prior.loss_type,
tracker_context="metrics/online-model-validation",
tracker=tracker,
use_ema=False,
)
eval_model(
trainer=trainer,
dataloader=eval_loader,
text_conditioned=config.prior.condition_on_text_encodings,
loss_type=config.prior.loss_type,
tracker_context="metrics/ema-model-validation",
tracker=tracker,
use_ema=True,
)
report_cosine_sims(
trainer=trainer,
dataloader=eval_loader,
text_conditioned=config.prior.condition_on_text_encodings,
tracker=tracker,
tracker_context="metrics",
)
if current_step % config.train.save_every == 0:
trainer.save(f"{config.tracker.data_path}/chkpt_step_{current_step}.pth")
# reset timer for next round
timer.reset()
# evaluate on test data
eval_model(
trainer=trainer,
dataloader=test_loader,
text_conditioned=config.prior.condition_on_text_encodings,
loss_type=config.prior.loss_type,
tracker_context="test",
tracker=tracker,
)
report_cosine_sims(
trainer,
test_loader,
config.prior.condition_on_text_encodings,
tracker,
tracker_context="test",
)
def initialize_training(config, accelerator=None):
"""
Parse the configuration file, and prepare everything necessary for training
"""
# get a device
if accelerator:
device = accelerator.device
click.secho(f"Accelerating on: {device}", fg="yellow")
else:
if torch.cuda.is_available():
click.secho("GPU detected, defaulting to cuda:0", fg="yellow")
device = "cuda:0"
else:
click.secho("No GPU detected...using cpu", fg="yellow")
device = "cpu"
# make the trainer (will automatically distribute if possible & configured)
trainer = make_model(config.prior, config.train, device, accelerator).to(device)
# reload from chcekpoint
if config.load.resume == True:
click.secho(f"Loading checkpoint: {config.load.source}", fg="cyan")
trainer.load(config.load.source)
# fetch and prepare data
if trainer.is_main_process():
click.secho("Grabbing data from source", fg="blue", blink=True)
img_reader = get_reader(
text_conditioned=trainer.text_conditioned,
img_url=config.data.image_url,
meta_url=config.data.meta_url,
)
train_loader, eval_loader, test_loader = make_splits(
text_conditioned=trainer.text_conditioned,
batch_size=config.data.batch_size,
num_data_points=NUM_DATA_POINTS,
train_split=config.data.splits.train,
eval_split=config.data.splits.val,
image_reader=img_reader,
rank=accelerator.state.process_index if exists(accelerator) else 0,
world_size=accelerator.state.num_processes if exists(accelerator) else 1,
start=START,
)
# wait for everyone to load data before continuing
trainer.wait_for_everyone()
# start training
train(
trainer=trainer,
train_loader=train_loader,
eval_loader=eval_loader,
test_loader=test_loader,
config=config,
)
original_similarity = cos(
text_embed, test_image_embeddings).cpu().numpy()
predicted_similarity = cos(
text_embed, predicted_image_embeddings).cpu().numpy()
unrelated_similarity = cos(
text_embed, predicted_unrelated_embeddings).cpu().numpy()
predicted_img_similarity = cos(
test_image_embeddings, predicted_image_embeddings).cpu().numpy()
tracker.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity),
"CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity),
"CosineSimilarity(orig_image_embed,predicted_image_embed)":np.mean(predicted_img_similarity),
"CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(unrelated_similarity),
"Cosine similarity difference":np.mean(predicted_similarity - original_similarity)})
@click.command()
@click.option("--hfa", default=True)
@click.option("--config_path", default="configs/prior.json")
def main(hfa, config_path):
# start HFA if requested
if hfa:
accelerator = Accelerator()
@click.option("--wandb-entity", default="laion")
@click.option("--wandb-project", default="diffusion-prior")
@click.option("--wandb-dataset", default="LAION-5B")
@click.option("--wandb-arch", default="DiffusionPrior")
@click.option("--image-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
@click.option("--text-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
@click.option("--meta-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/laion2B-en-metadata/")
@click.option("--learning-rate", default=1.1e-4)
@click.option("--weight-decay", default=6.02e-2)
@click.option("--dropout", default=5e-2)
@click.option("--max-grad-norm", default=0.5)
@click.option("--num-data-points", default=250e6)
@click.option("--batch-size", default=320)
@click.option("--num-epochs", default=5)
@click.option("--image-embed-dim", default=768)
@click.option("--train-percent", default=0.9)
@click.option("--val-percent", default=1e-7)
@click.option("--test-percent", default=0.0999999)
@click.option("--dpn-depth", default=12)
@click.option("--dpn-dim-head", default=64)
@click.option("--dpn-heads", default=12)
@click.option("--dp-condition-on-text-encodings", default=True)
@click.option("--dp-timesteps", default=1000)
@click.option("--dp-normformer", default=True)
@click.option("--dp-cond-drop-prob", default=0.1)
@click.option("--dp-loss-type", default="l2")
@click.option("--clip", default="ViT-L/14")
@click.option("--amp", default=False)
@click.option("--save-interval", default=120)
@click.option("--save-path", default="./diffusion_prior_checkpoints")
@click.option("--pretrained-model-path", default=None)
@click.option("--gpu-device", default=0)
def train(
wandb_entity,
wandb_project,
wandb_dataset,
wandb_arch,
image_embed_url,
text_embed_url,
meta_url,
learning_rate,
weight_decay,
dropout,
max_grad_norm,
num_data_points,
batch_size,
num_epochs,
image_embed_dim,
train_percent,
val_percent,
test_percent,
dpn_depth,
dpn_dim_head,
dpn_heads,
dp_condition_on_text_encodings,
dp_timesteps,
dp_normformer,
dp_cond_drop_prob,
dp_loss_type,
clip,
amp,
save_interval,
save_path,
pretrained_model_path,
gpu_device
):
config = {
"learning_rate": learning_rate,
"architecture": wandb_arch,
"dataset": wandb_dataset,
"weight_decay": weight_decay,
"max_gradient_clipping_norm": max_grad_norm,
"batch_size": batch_size,
"epochs": num_epochs,
"diffusion_prior_network": {
"depth": dpn_depth,
"dim_head": dpn_dim_head,
"heads": dpn_heads,
"normformer": dp_normformer
},
"diffusion_prior": {
"condition_on_text_encodings": dp_condition_on_text_encodings,
"timesteps": dp_timesteps,
"cond_drop_prob": dp_cond_drop_prob,
"loss_type": dp_loss_type,
"clip": clip
}
}
# Check if DPRIOR_PATH exists(saved model path)
DPRIOR_PATH = pretrained_model_path
RESUME = exists(DPRIOR_PATH)
if not RESUME:
tracker.init(
entity = wandb_entity,
project = wandb_project,
config = config
)
# Obtain the utilized device.
has_cuda = torch.cuda.is_available()
if has_cuda:
device = torch.device(f"cuda:{gpu_device}")
torch.cuda.set_device(device)
# Training loop
# diffusion prior network
prior_network = DiffusionPriorNetwork(
dim = image_embed_dim,
depth = dpn_depth,
dim_head = dpn_dim_head,
heads = dpn_heads,
attn_dropout = dropout,
ff_dropout = dropout,
normformer = dp_normformer
)
# Load clip model if text-conditioning
if dp_condition_on_text_encodings:
clip_adapter = OpenAIClipAdapter(clip)
else:
accelerator = None
clip_adapter = None
# diffusion prior with text embeddings and image embeddings pre-computed
# load the configuration file on main process
if not exists(accelerator) or accelerator.is_main_process:
click.secho(f"Loading configuration from {config_path}", fg="green")
diffusion_prior = DiffusionPrior(
net = prior_network,
clip = clip_adapter,
image_embed_dim = image_embed_dim,
timesteps = dp_timesteps,
cond_drop_prob = dp_cond_drop_prob,
loss_type = dp_loss_type,
condition_on_text_encodings = dp_condition_on_text_encodings
)
config = TrainDiffusionPriorConfig.from_json_path(config_path)
# Load pre-trained model from DPRIOR_PATH
# send config to get processed
initialize_training(config, accelerator)
if RESUME:
diffusion_prior, loaded_obj = load_diffusion_model(DPRIOR_PATH, device)
tracker.init(entity = wandb_entity, project = wandb_project, config = config)
# diffusion prior trainer
trainer = DiffusionPriorTrainer(
diffusion_prior = diffusion_prior,
lr = learning_rate,
wd = weight_decay,
max_grad_norm = max_grad_norm,
amp = amp,
).to(device)
# load optimizer and scaler
if RESUME:
trainer.optimizer.load_state_dict(loaded_obj['optimizer'])
trainer.scaler.load_state_dict(loaded_obj['scaler'])
# Create save_path if it doesn't exist
Path(save_path).mkdir(exist_ok = True, parents = True)
# Utilize wrapper to abstract away loader logic
print_ribbon("Downloading Embeddings")
loader_args = dict(text_conditioned=dp_condition_on_text_encodings, batch_size=batch_size, num_data_points=num_data_points,
train_split=train_percent, eval_split=val_percent, device=device, img_url=image_embed_url)
if dp_condition_on_text_encodings:
loader_args = dict(**loader_args, meta_url=meta_url)
else:
loader_args = dict(**loader_args, txt_url=text_embed_url)
train_loader, eval_loader, test_loader = make_splits(**loader_args)
### Training code ###
step = 1
timer = Timer()
epochs = num_epochs
for _ in range(epochs):
for image, text in tqdm(train_loader):
diffusion_prior.train()
input_args = dict(image_embed=image)
if dp_condition_on_text_encodings:
input_args = dict(**input_args, text = text)
else:
input_args = dict(**input_args, text_embed=text)
loss = trainer(**input_args)
# Samples per second
samples_per_sec = batch_size * step / timer.elapsed()
# Save checkpoint every save_interval minutes
if(int(timer.elapsed()) >= 60 * save_interval):
timer.reset()
save_diffusion_model(
save_path,
diffusion_prior,
trainer.optimizer,
trainer.scaler,
config,
image_embed_dim)
# Log to wandb
tracker.log({"Training loss": loss,
"Steps": step,
"Samples per second": samples_per_sec})
# Log cosineSim(text_embed,predicted_image_embed) - cosineSim(text_embed,image_embed)
# Use NUM_TEST_EMBEDDINGS samples from the test set each time
# Get embeddings from the most recently saved model
if(step % REPORT_METRICS_EVERY) == 0:
report_cosine_sims(diffusion_prior, eval_loader, dp_condition_on_text_encodings)
### Evaluate model(validation run) ###
eval_model(diffusion_prior, eval_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Validation")
step += 1
trainer.update()
### Test run ###
eval_model(diffusion_prior, test_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Test")
if __name__ == "__main__":
main()
train()