Compare commits

..

1 Commit

Author SHA1 Message Date
Phil Wang 88f516b5db fix condition_on_text_encodings in dalle2 orchestrator class, fix readme 2022-07-07 07:42:20 -07:00
31 changed files with 835 additions and 2680 deletions

.github/FUNDING.yml vendored (2 lines changed)

@@ -1 +1 @@
github: [nousr, Veldrovive, lucidrains]
github: [lucidrains]


@@ -1,33 +0,0 @@
name: Continuous integration
on:
push:
branches:
- main
pull_request:
branches:
- main
jobs:
tests:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.8]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install
run: |
python3 -m venv .env
source .env/bin/activate
make install
- name: Tests
run: |
source .env/bin/activate
make test

.gitignore vendored (2 lines changed)

@@ -136,5 +136,3 @@ dmypy.json
# Pyre type checker
.pyre/
.tracker_data
*.pth


@@ -1,6 +0,0 @@
install:
pip install -U pip
pip install -e .
test:
CUDA_VISIBLE_DEVICES= python train_decoder.py --config_file configs/train_decoder_config.test.json

README.md (207 lines changed)

@@ -44,12 +44,9 @@ This library would not have gotten to this working state without the help of
- <a href="https://github.com/krish240574">Kumar</a> for working on the initial diffusion training script
- <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management
- <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying of critical bugs
- <a href="https://github.com/marunine">Marunine</a> for identifying issues with resizing of the low resolution conditioner, when training the upsampler, in addition to various other bug fixes
- <a href="https://github.com/malumadev">MalumaDev</a> for proposing the use of pixel shuffle upsampler for fixing checkboard artifacts
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice
- <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship
- <a href="https://huggingface.co">🤗 Huggingface</a> and in particular <a href="https://github.com/sgugger">Sylvain</a> for the <a href="https://github.com/huggingface/accelerate">Accelerate</a> library
- <a href="https://github.com/arogozhnikov">Alex</a> for <a href="https://github.com/arogozhnikov/einops">einops</a>, indispensable tool for tensor manipulation
... and many others. Thank you! 🙏
@@ -357,8 +354,7 @@ prior_network = DiffusionPriorNetwork(
diffusion_prior = DiffusionPrior(
net = prior_network,
clip = clip,
timesteps = 1000,
sample_timesteps = 64,
timesteps = 100,
cond_drop_prob = 0.2
).cuda()
@@ -372,7 +368,6 @@ loss.backward()
unet1 = Unet(
dim = 128,
image_embed_dim = 512,
text_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults=(1, 2, 4, 8),
@@ -397,7 +392,7 @@ decoder = Decoder(
).cuda()
for unet_number in (1, 2):
loss = decoder(images, text = text, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
loss.backward()
# do above for many steps
@@ -423,7 +418,7 @@ For the layperson, no worries, training will all be automated into a CLI tool, a
## Training on Preprocessed CLIP Embeddings
It is likely, when scaling up, that you would first preprocess your images and text into corresponding embeddings before training the prior network. You can do so easily by simply passing in `image_embed`, `text_embed`, and optionally `text_encodings`
It is likely, when scaling up, that you would first preprocess your images and text into corresponding embeddings before training the prior network. You can do so easily by simply passing in `image_embed`, `text_embed`, and optionally `text_encodings` and `text_mask`
Working example below
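The readme's full working example is not reproduced in this hunk; as a rough sketch only (mock embeddings, and assuming the `diffusion_prior` built earlier accepts precomputed `image_embed` / `text_embed` keyword arguments, as the paragraph above states):
```python
import torch

# mock precomputed CLIP embeddings (batch of 4, dim 512 to match the clip defined above)
image_embed = torch.randn(4, 512).cuda()
text_embed = torch.randn(4, 512).cuda()

# reuses the diffusion_prior instance constructed earlier in this readme
loss = diffusion_prior(
    text_embed = text_embed,
    image_embed = image_embed
)
loss.backward()
```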
@@ -587,7 +582,6 @@ unet1 = Unet(
cond_dim = 128,
channels = 3,
dim_mults=(1, 2, 4, 8),
text_embed_dim = 512,
cond_on_text_encodings = True # set to True for any unets that need to be conditioned on text encodings (ex. first unet in cascade)
).cuda()
@@ -603,14 +597,13 @@ decoder = Decoder(
unet = (unet1, unet2),
image_sizes = (128, 256),
clip = clip,
timesteps = 1000,
sample_timesteps = (250, 27),
timesteps = 100,
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5
).cuda()
for unet_number in (1, 2):
loss = decoder(images, text = text, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
loss.backward()
# do above for many steps
@@ -628,98 +621,8 @@ images = dalle2(
# save your image (in this example, of size 256x256)
```
Alternatively, you can also use <a href="https://github.com/mlfoundations/open_clip">Open Clip</a>
```bash
$ pip install open-clip-torch
```
Ex. using the <a href="https://laion.ai/blog/large-openclip/">SOTA Open Clip</a> model trained by <a href="https://github.com/rom1504">Romain</a>
```python
from dalle2_pytorch import OpenClipAdapter
clip = OpenClipAdapter('ViT-H/14')
```
Now you'll just have to worry about training the Prior and the Decoder!
## Inpainting
Inpainting is also built into the `Decoder`. You simply have to pass in the `inpaint_image` and `inpaint_mask` (boolean tensor where `True` indicates which regions of the inpaint image to keep)
This repository uses the formulation put forth by <a href="https://arxiv.org/abs/2201.09865">Lugmayr et al. in Repaint</a>
```python
import torch
from dalle2_pytorch import Unet, Decoder, CLIP
# trained clip from step 1
clip = CLIP(
dim_text = 512,
dim_image = 512,
dim_latent = 512,
num_text_tokens = 49408,
text_enc_depth = 6,
text_seq_len = 256,
text_heads = 8,
visual_enc_depth = 6,
visual_image_size = 256,
visual_patch_size = 32,
visual_heads = 8
).cuda()
# 2 unets for the decoder (a la cascading DDPM)
unet = Unet(
dim = 16,
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults = (1, 1, 1, 1)
).cuda()
# decoder, which contains the unet(s) and clip
decoder = Decoder(
clip = clip,
unet = (unet,), # insert both unets in order of low resolution to highest resolution (you can have as many stages as you want here)
image_sizes = (256,), # resolutions, 256 for first unet, 512 for second. these must be unique and in ascending order (matches with the unets passed in)
timesteps = 1000,
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5
).cuda()
# mock images (get a lot of this)
images = torch.randn(4, 3, 256, 256).cuda()
# feed images into decoder, specifying which unet you want to train
# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme
loss = decoder(images, unet_number = 1)
loss.backward()
# do the above for many steps for both unets
mock_image_embed = torch.randn(1, 512).cuda()
# then to do inpainting
inpaint_image = torch.randn(1, 3, 256, 256).cuda() # (batch, channels, height, width)
inpaint_mask = torch.ones(1, 256, 256).bool().cuda() # (batch, height, width)
inpainted_images = decoder.sample(
image_embed = mock_image_embed,
inpaint_image = inpaint_image, # just pass in the inpaint image
inpaint_mask = inpaint_mask # and the mask
)
inpainted_images.shape # (1, 3, 256, 256)
```
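As a small illustrative sketch (not from the original readme, reusing `decoder`, `mock_image_embed` and `inpaint_image` from the example above): since `True` marks the pixels to keep, a mask that preserves only the top half of the image could be built like so
```python
# hypothetical mask: keep the top half of the inpaint image, regenerate the bottom half
keep_top_half = torch.zeros(1, 256, 256).bool().cuda()  # (batch, height, width)
keep_top_half[:, :128, :] = True                        # True = keep these pixels of inpaint_image

inpainted_images = decoder.sample(
    image_embed = mock_image_embed,
    inpaint_image = inpaint_image,
    inpaint_mask = keep_top_half
)
```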
## Experimental
### DALL-E2 with Latent Diffusion
@@ -876,23 +779,25 @@ unet1 = Unet(
text_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults=(1, 2, 4, 8),
cond_on_text_encodings = True,
dim_mults=(1, 2, 4, 8)
).cuda()
unet2 = Unet(
dim = 16,
image_embed_dim = 512,
text_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults = (1, 2, 4, 8, 16),
cond_on_text_encodings = True
).cuda()
decoder = Decoder(
unet = (unet1, unet2),
image_sizes = (128, 256),
clip = clip,
timesteps = 1000
timesteps = 1000,
condition_on_text_encodings = True
).cuda()
decoder_trainer = DecoderTrainer(
@@ -917,8 +822,8 @@ for unet_number in (1, 2):
# after much training
# you can sample from the exponentially moving averaged unets as so
mock_image_embed = torch.randn(32, 512).cuda()
images = decoder_trainer.sample(image_embed = mock_image_embed, text = text) # (32, 3, 256, 256)
mock_image_embed = torch.randn(4, 512).cuda()
images = decoder_trainer.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
```
### Diffusion Prior Training
@@ -1068,7 +973,7 @@ dataloader = create_image_embedding_dataloader(
)
for img, emb in dataloader:
print(img.shape) # torch.Size([32, 3, 256, 256])
print(emb["img"].shape) # torch.Size([32, 512])
print(emb.shape) # torch.Size([32, 512])
# Train decoder only as shown above
# Or create a dataset without a loader so you can configure it manually
@@ -1081,12 +986,26 @@ dataset = ImageEmbeddingDataset(
)
```
### Scripts
### Scripts (wip)
#### `train_diffusion_prior.py`
For detailed information on training the diffusion prior, please refer to the [dedicated readme](prior.md)
## CLI (wip)
```bash
$ dream 'sharing a sunset at the summit of mount everest with my dog'
```
Once built, images will be saved to the same directory the command is invoked from
<a href="https://github.com/lucidrains/big-sleep">template</a>
## Training CLI (wip)
<a href="https://github.com/lucidrains/stylegan2-pytorch">template</a>
## Todo
- [x] finish off gaussian diffusion class for latent embedding - allow for prediction of epsilon
@@ -1124,12 +1043,11 @@ For detailed information on training the diffusion prior, please refer to the [d
- [x] bring in skip-layer excitations (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training (doesnt work well)
- [x] test out grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697 (keeping, seems to be fine)
- [x] allow for unet to be able to condition non-cross attention style as well
- [x] speed up inference, read up on papers (ddim)
- [x] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
- [x] add the final combination of upsample feature maps, used in unet squared, seems to have an effect in local experiments
- [ ] consider elucidated dalle2 https://arxiv.org/abs/2206.00364
- [ ] add simple outpainting, text-guided 2x size the image for starters
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
- [ ] speed up inference, read up on papers (ddim or diffusion-gan, etc)
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
- [ ] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
## Citations
@@ -1247,65 +1165,4 @@ For detailed information on training the diffusion prior, please refer to the [d
}
```
```bibtex
@article{Lugmayr2022RePaintIU,
title = {RePaint: Inpainting using Denoising Diffusion Probabilistic Models},
author = {Andreas Lugmayr and Martin Danelljan and Andr{\'e}s Romero and Fisher Yu and Radu Timofte and Luc Van Gool},
journal = {ArXiv},
year = {2022},
volume = {abs/2201.09865}
}
```
```bibtex
@misc{chen2022analog,
title = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning},
author = {Ting Chen and Ruixiang Zhang and Geoffrey Hinton},
year = {2022},
eprint = {2208.04202},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@article{Qiao2019WeightS,
title = {Weight Standardization},
author = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Loddon Yuille},
journal = {ArXiv},
year = {2019},
volume = {abs/1903.10520}
}
```
```bibtex
@inproceedings{rogozhnikov2022einops,
title = {Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation},
author = {Alex Rogozhnikov},
booktitle = {International Conference on Learning Representations},
year = {2022},
url = {https://openreview.net/forum?id=oapKSVM2bcj}
}
```
```bibtex
@article{Sunkara2022NoMS,
title = {No More Strided Convolutions or Pooling: A New CNN Building Block for Low-Resolution Images and Small Objects},
author = {Raja Sunkara and Tie Luo},
journal = {ArXiv},
year = {2022},
volume = {abs/2208.03641}
}
```
```bibtex
@article{Salimans2022ProgressiveDF,
title = {Progressive Distillation for Fast Sampling of Diffusion Models},
author = {Tim Salimans and Jonathan Ho},
journal = {ArXiv},
year = {2022},
volume = {abs/2202.00512}
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>


@@ -30,7 +30,6 @@ Defines the configuration options for the decoder model. The unets defined above
| `loss_type` | No | `l2` | The loss function. Options are `l1`, `huber`, or `l2`. |
| `beta_schedule` | No | `cosine` | The noising schedule. Options are `cosine`, `linear`, `quadratic`, `jsd`, or `sigmoid`. |
| `learned_variance` | No | `True` | Whether to learn the variance. |
| `clip` | No | `None` | The clip model to use if embeddings are being generated on the fly. Takes keys `make` and `model` with defaults `openai` and `ViT-L/14`. |
Any parameter from the `Decoder` constructor can also be given here.
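As a rough sketch only (values borrowed from the test config further down in this diff, not a prescribed configuration), a minimal `decoder` section could be assembled in Python and serialized to JSON:
```python
import json

# sketch of a minimal "decoder" section; values follow the test config shown later in this diff
decoder_section = {
    "unets": [{
        "dim": 16,
        "image_embed_dim": 768,
        "cond_dim": 16,
        "channels": 3,
        "dim_mults": [1, 2, 4, 8]
    }],
    "clip": {"make": "openai", "model": "ViT-L/14"},  # only needed if embeddings are generated on the fly
    "image_sizes": [64],
    "timesteps": 1000,
    "loss_type": "l2",
    "beta_schedule": ["cosine"],
    "learned_variance": True
}

print(json.dumps({"decoder": decoder_section}, indent = 2))
```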
@@ -40,8 +39,7 @@ Settings for creation of the dataloaders.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `webdataset_base_url` | Yes | N/A | The url of a shard in the webdataset with the shard replaced with `{}`[^1]. |
| `img_embeddings_url` | No | `None` | The url of the folder containing image embeddings shards. Not required if embeddings are in webdataset or clip is being used. |
| `text_embeddings_url` | No | `None` | The url of the folder containing text embeddings shards. Not required if embeddings are in webdataset or clip is being used. |
| `embeddings_url` | No | N/A | The url of the folder containing embeddings shards. Not required if embeddings are in webdataset. |
| `num_workers` | No | `4` | The number of workers used in the dataloader. |
| `batch_size` | No | `64` | The batch size. |
| `start_shard` | No | `0` | Defines the start of the shard range the dataset will recall. |
@@ -69,12 +67,14 @@ Settings for controlling the training hyperparameters.
| `wd` | No | `0.01` | The weight decay. |
| `max_grad_norm`| No | `0.5` | The grad norm clipping. |
| `save_every_n_samples` | No | `100000` | Samples will be generated and a checkpoint will be saved every `save_every_n_samples` samples. |
| `cond_scale` | No | `1.0` | Conditioning scale to use for sampling. Can also be an array of values, one for each unet. |
| `device` | No | `cuda:0` | The device to train on. |
| `epoch_samples` | No | `None` | Limits the number of samples iterated through in each epoch. This must be set if resampling. None means no limit. |
| `validation_samples` | No | `None` | The number of samples to use for validation. None means the entire validation set. |
| `use_ema` | No | `True` | Whether to use exponential moving average models for sampling. |
| `ema_beta` | No | `0.99` | The ema coefficient. |
| `save_all` | No | `False` | If True, preserves a checkpoint for every epoch. |
| `save_latest` | No | `True` | If True, overwrites the `latest.pth` every time the model is saved. |
| `save_best` | No | `True` | If True, overwrites the `best.pth` every time the model has a lower validation loss than all previous models. |
| `unet_training_mask` | No | `None` | A boolean array of the same length as the number of unets. If false, the unet is frozen. A value of `None` trains all unets. |
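Likewise, a hedged sketch of a `train` section using the options above (values mirror the example configs elsewhere in this diff; `unet_training_mask` is shown for a hypothetical two-unet decoder):
```python
# sketch of a "train" config section, mirroring the options documented above
train_section = {
    "epochs": 20,
    "lr": 1e-4,
    "wd": 0.01,
    "max_grad_norm": 0.5,
    "save_every_n_samples": 100000,
    "device": "cuda:0",
    "epoch_samples": None,       # null in JSON; must be set if resampling
    "validation_samples": None,  # null in JSON; None means the entire validation set
    "use_ema": True,
    "ema_beta": 0.99,
    "save_all": False,
    "save_latest": True,
    "save_best": True,
    "unet_training_mask": [True, False]  # e.g. train only the first of two unets
}
```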
**<ins>Evaluate</ins>:**
@@ -106,13 +106,6 @@ Tracking is split up into three sections:
**Logging:**
All loggers have the following keys:
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `log_type` | Yes | N/A | The type of logger class to use. |
| `resume` | No | `False` | For loggers that have the option to resume an old run, resume it using manually input parameters. |
| `auto_resume` | No | `False` | If true, the logger will attempt to resume an old run using parameters from that previous run. |
If using `console` there is no further configuration than setting `log_type` to `console`.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
@@ -126,15 +119,10 @@ If using `wandb`
| `wandb_project` | Yes | N/A | The wandb project to save the run to. |
| `wandb_run_name` | No | `None` | The wandb run name. |
| `wandb_run_id` | No | `None` | The wandb run id. Used if resuming an old run. |
| `wandb_resume` | No | `False` | Whether to resume an old run. |
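As a hedged sketch combining the keys from the two tables above (patterned on the prior config later in this diff; the project name is hypothetical):
```python
# sketch of a wandb "log" section for the tracker config
log_section = {
    "log_type": "wandb",
    "wandb_entity": "<your wandb username>",
    "wandb_project": "my_dalle2_project",  # hypothetical project name
    "wandb_run_name": None,                # null in JSON
    "wandb_resume": False,
    "auto_resume": False,
    "verbose": True
}
```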
**Loading:**
All loaders have the following keys:
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `load_from` | Yes | N/A | The type of loader class to use. |
| `only_auto_resume` | No | `False` | If true, the loader will only load the model if the run is being auto resumed. |
If using `local`
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
@@ -161,10 +149,9 @@ All save locations have these configuration options
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `save_to` | Yes | N/A | Must be `local`, `huggingface`, or `wandb`. |
| `save_latest_to` | No | `None` | Sets the relative path to save the latest model to. |
| `save_best_to` | No | `None` | Sets the relative path to save the best model to every time the model has a lower validation loss than all previous models. |
| `save_meta_to` | No | `None` | The path to save metadata files in. This includes the config files used to start the training. |
| `save_type` | No | `checkpoint` | The type of save. `checkpoint` saves a checkpoint, `model` saves a model without any fluff (Saves with ema if ema is enabled). |
| `save_latest_to` | No | `latest.pth` | Sets the relative path to save the latest model to. |
| `save_best_to` | No | `best.pth` | Sets the relative path to save the best model to every time the model has a lower validation loss than all previous models. |
| `save_type` | No | `'checkpoint'` | The type of save. `'checkpoint'` saves a checkpoint, `'model'` saves a model without any fluff (Saves with ema if ema is enabled). |
If using `local`
| Option | Required | Default | Description |
@@ -176,6 +163,7 @@ If using `huggingface`
| ------ | -------- | ------- | ----------- |
| `save_to` | Yes | N/A | Must be `huggingface`. |
| `huggingface_repo` | Yes | N/A | The huggingface repository to save to. |
| `huggingface_base_path` | Yes | N/A | The base path that checkpoints will be saved under. |
| `token_path` | No | `None` | If logging in with the huggingface cli is not possible, point to a token file instead. |
If using `wandb`


@@ -20,7 +20,7 @@
},
"data": {
"webdataset_base_url": "pipe:s3cmd get s3://bucket/path/{}.tar -",
"img_embeddings_url": "s3://bucket/img_embeddings/path/",
"embeddings_url": "s3://bucket/embeddings/path/",
"num_workers": 4,
"batch_size": 64,
"start_shard": 0,
@@ -56,6 +56,9 @@
"use_ema": true,
"ema_beta": 0.99,
"amp": false,
"save_all": false,
"save_latest": true,
"save_best": true,
"unet_training_mask": [true]
},
"evaluate": {
@@ -93,15 +96,14 @@
},
"save": [{
"save_to": "wandb",
"save_latest_to": "latest.pth"
"save_to": "wandb"
}, {
"save_to": "huggingface",
"huggingface_repo": "Veldrovive/test_model",
"save_latest_to": "path/to/model_dir/latest.pth",
"save_best_to": "path/to/model_dir/best.pth",
"save_meta_to": "path/to/directory/for/assorted/files",
"save_all": true,
"save_latest": true,
"save_best": true,
"save_type": "model"
}]


@@ -1,100 +0,0 @@
{
"decoder": {
"unets": [
{
"dim": 16,
"image_embed_dim": 768,
"cond_dim": 16,
"channels": 3,
"dim_mults": [1, 2, 4, 8],
"attn_dim_head": 16,
"attn_heads": 4,
"self_attn": [false, true, true, true]
}
],
"clip": {
"make": "openai",
"model": "ViT-L/14"
},
"timesteps": 10,
"image_sizes": [64],
"channels": 3,
"loss_type": "l2",
"beta_schedule": ["cosine"],
"learned_variance": true
},
"data": {
"webdataset_base_url": "test_data/{}.tar",
"num_workers": 4,
"batch_size": 4,
"start_shard": 0,
"end_shard": 9,
"shard_width": 1,
"index_width": 1,
"splits": {
"train": 0.75,
"val": 0.15,
"test": 0.1
},
"shuffle_train": false,
"resample_train": true,
"preprocessing": {
"RandomResizedCrop": {
"size": [224, 224],
"scale": [0.75, 1.0],
"ratio": [1.0, 1.0]
},
"ToTensor": true
}
},
"train": {
"epochs": 1,
"lr": 1e-16,
"wd": 0.01,
"max_grad_norm": 0.5,
"save_every_n_samples": 100,
"n_sample_images": 1,
"device": "cpu",
"epoch_samples": 50,
"validation_samples": 5,
"use_ema": true,
"ema_beta": 0.99,
"amp": false,
"unet_training_mask": [true]
},
"evaluate": {
"n_evaluation_samples": 2,
"FID": {
"feature": 64
},
"IS": {
"feature": 64,
"splits": 10
},
"KID": {
"feature": 64,
"subset_size": 2
},
"LPIPS": {
"net_type": "vgg",
"reduction": "mean"
}
},
"tracker": {
"overwrite_data_path": true,
"log": {
"log_type": "console"
},
"load": {
"load_from": null
},
"save": [{
"save_to": "local",
"save_latest_to": "latest.pth"
}]
}
}


@@ -1,14 +1,18 @@
{
"prior": {
"clip": {
"make": "openai",
"model": "ViT-L/14"
"make": "x-clip",
"model": "ViT-L/14",
"base_model_kwargs": {
"dim_text": 768,
"dim_image": 768,
"dim_latent": 768
}
},
"net": {
"dim": 768,
"depth": 12,
"num_timesteps": 1000,
"max_text_len": 77,
"num_time_embeds": 1,
"num_image_embeds": 1,
"num_text_embeds": 1,
@@ -16,8 +20,8 @@
"heads": 12,
"ff_mult": 4,
"norm_out": true,
"attn_dropout": 0.05,
"ff_dropout": 0.05,
"attn_dropout": 0.0,
"ff_dropout": 0.0,
"final_proj": true,
"normformer": true,
"rotary_emb": true
@@ -26,7 +30,6 @@
"image_size": 224,
"image_channels": 3,
"timesteps": 1000,
"sample_timesteps": 64,
"cond_drop_prob": 0.1,
"loss_type": "l2",
"predict_x_start": true,
@@ -34,48 +37,34 @@
"condition_on_text_encodings": true
},
"data": {
"batch_size": 128,
"num_data_points": 100000,
"eval_every_seconds": 1600,
"image_url": "<path to your images>",
"meta_url": "<path to your metadata>",
"image_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/",
"text_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/",
"meta_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/laion2B-en-metadata/",
"batch_size": 256,
"splits": {
"train": 0.8,
"val": 0.1,
"test": 0.1
"train": 0.9,
"val": 1e-7,
"test": 0.0999999
}
},
"train": {
"epochs": 5,
"epochs": 1,
"lr": 1.1e-4,
"wd": 6.02e-2,
"max_grad_norm": 0.5,
"use_ema": true,
"ema_beta": 0.9999,
"ema_update_after_step": 50,
"warmup_steps": 50,
"amp": false,
"save_every_seconds": 3600,
"eval_timesteps": [64, 1000],
"random_seed": 84513
"save_every": 10000
},
"load": {
"source": null,
"resume": false
},
"tracker": {
"data_path": ".prior",
"overwrite_data_path": true,
"log": {
"log_type": "wandb",
"wandb_entity": "<your wandb username>",
"wandb_project": "prior_debugging",
"wandb_resume": false,
"verbose": true
},
"save": [
{
"save_to": "local",
"save_type": "checkpoint",
"save_latest_to": ".prior/latest_checkpoint.pth",
"save_best_to": ".prior/best_checkpoint.pth"
}
]
"tracker_type": "wandb",
"data_path": "./prior_checkpoints",
"wandb_entity": "laion",
"wandb_project": "diffusion-prior",
"verbose": true
}
}


@@ -1,6 +1,6 @@
from dalle2_pytorch.version import __version__
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter, OpenClipAdapter
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
from dalle2_pytorch.vqgan_vae import VQGanVAE

File diff suppressed because it is too large


@@ -1,7 +1,6 @@
import os
import webdataset as wds
import torch
from torch.utils.data import DataLoader
import numpy as np
import fsspec
import shutil
@@ -256,7 +255,7 @@ def create_image_embedding_dataloader(
)
if shuffle_num is not None and shuffle_num > 0:
ds.shuffle(1000)
return DataLoader(
return wds.WebLoader(
ds,
num_workers=num_workers,
batch_size=batch_size,


@@ -67,15 +67,6 @@ class PriorEmbeddingDataset(IterableDataset):
def __str__(self):
return f"<PriorEmbeddingDataset: start: {self.start}, stop: {self.stop}, len: {self.__len__()}>"
def set_start(self, start):
"""
Adjust the starting point within the reader, useful for resuming an epoch
"""
self.start = start
def get_start(self):
return self.start
def get_sample(self):
"""
pre-process data from either reader into a common format


@@ -1,18 +1,15 @@
import urllib.request
import os
import json
from pathlib import Path
import shutil
from itertools import zip_longest
from typing import Any, Optional, List, Union
from typing import Optional, List, Union
from pydantic import BaseModel
import torch
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
from dalle2_pytorch.utils import import_or_print_error
from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
from dalle2_pytorch.version import __version__
from packaging import version
# constants
@@ -23,6 +20,16 @@ DEFAULT_DATA_PATH = './.tracker-data'
def exists(val):
return val is not None
# load file functions
def load_wandb_file(run_path, file_path, **kwargs):
wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function')
file_reference = wandb.restore(file_path, run_path=run_path)
return file_reference.name
def load_local_file(file_path, **kwargs):
return file_path
class BaseLogger:
"""
An abstract class representing an object that can log data.
@@ -30,17 +37,14 @@ class BaseLogger:
data_path (str): A file path for storing temporary data.
verbose (bool): Whether or not to always print logs to the console.
"""
def __init__(self, data_path: str, resume: bool = False, auto_resume: bool = False, verbose: bool = False, **kwargs):
def __init__(self, data_path: str, verbose: bool = False, **kwargs):
self.data_path = Path(data_path)
self.resume = resume
self.auto_resume = auto_resume
self.verbose = verbose
def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
"""
Initializes the logger.
Errors if the logger is invalid.
full_config is the config file dict while extra_config is anything else from the script that is not defined in the config file.
"""
raise NotImplementedError
@@ -56,14 +60,6 @@ class BaseLogger:
def log_error(self, error_string, **kwargs) -> None:
raise NotImplementedError
def get_resume_data(self, **kwargs) -> dict:
"""
Sets tracker attributes that along with { "resume": True } will be used to resume training.
It is assumed that after init is called this data will be complete.
If the logger does not have any resume functionality, it should return an empty dict.
"""
raise NotImplementedError
class ConsoleLogger(BaseLogger):
def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
print("Logging to console")
@@ -80,9 +76,6 @@ class ConsoleLogger(BaseLogger):
def log_error(self, error_string, **kwargs) -> None:
print(error_string)
def get_resume_data(self, **kwargs) -> dict:
return {}
class WandbLogger(BaseLogger):
"""
Logs to a wandb run.
@@ -92,6 +85,7 @@ class WandbLogger(BaseLogger):
wandb_project (str): The wandb project to log to.
wandb_run_id (str): The wandb run id to resume.
wandb_run_name (str): The wandb run name to use.
wandb_resume (bool): Whether to resume a wandb run.
"""
def __init__(self,
data_path: str,
@@ -99,6 +93,7 @@ class WandbLogger(BaseLogger):
wandb_project: str,
wandb_run_id: Optional[str] = None,
wandb_run_name: Optional[str] = None,
wandb_resume: bool = False,
**kwargs
):
super().__init__(data_path, **kwargs)
@@ -106,6 +101,7 @@ class WandbLogger(BaseLogger):
self.project = wandb_project
self.run_id = wandb_run_id
self.run_name = wandb_run_name
self.resume = wandb_resume
def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
assert self.entity is not None, "wandb_entity must be specified for wandb logger"
@@ -153,14 +149,6 @@ class WandbLogger(BaseLogger):
print(error_string)
self.wandb.log({"error": error_string, **kwargs}, step=step)
def get_resume_data(self, **kwargs) -> dict:
# In order to resume, we need wandb_entity, wandb_project, and wandb_run_id
return {
"entity": self.entity,
"project": self.project,
"run_id": self.wandb.run.id
}
logger_type_map = {
'console': ConsoleLogger,
'wandb': WandbLogger,
@@ -180,9 +168,8 @@ class BaseLoader:
Parameters:
data_path (str): A file path for storing temporary data.
"""
def __init__(self, data_path: str, only_auto_resume: bool = False, **kwargs):
def __init__(self, data_path: str, **kwargs):
self.data_path = Path(data_path)
self.only_auto_resume = only_auto_resume
def init(self, logger: BaseLogger, **kwargs) -> None:
raise NotImplementedError
@@ -226,7 +213,7 @@ class LocalLoader(BaseLoader):
def init(self, logger: BaseLogger, **kwargs) -> None:
# Makes sure the file exists to be loaded
if not self.file_path.exists() and not self.only_auto_resume:
if not self.file_path.exists():
raise FileNotFoundError(f'Model not found at {self.file_path}')
def recall(self) -> dict:
@@ -275,9 +262,9 @@ def create_loader(loader_type: str, data_path: str, **kwargs) -> BaseLoader:
class BaseSaver:
def __init__(self,
data_path: str,
save_latest_to: Optional[Union[str, bool]] = None,
save_best_to: Optional[Union[str, bool]] = None,
save_meta_to: Optional[str] = None,
save_latest_to: Optional[Union[str, bool]] = 'latest.pth',
save_best_to: Optional[Union[str, bool]] = 'best.pth',
save_meta_to: str = './',
save_type: str = 'checkpoint',
**kwargs
):
@@ -287,10 +274,10 @@ class BaseSaver:
self.save_best_to = save_best_to
self.saving_best = save_best_to is not None and save_best_to is not False
self.save_meta_to = save_meta_to
self.saving_meta = save_meta_to is not None
self.save_type = save_type
assert save_type in ['checkpoint', 'model'], '`save_type` must be one of `checkpoint` or `model`'
assert self.saving_latest or self.saving_best or self.saving_meta, 'At least one saving option must be specified'
assert self.save_meta_to is not None, '`save_meta_to` must be provided'
assert self.saving_latest or self.saving_best, '`save_latest_to` or `save_best_to` must be provided'
def init(self, logger: BaseLogger, **kwargs) -> None:
raise NotImplementedError
@@ -317,10 +304,6 @@ class LocalSaver(BaseSaver):
def save_file(self, local_path: str, save_path: str, **kwargs) -> None:
# Copy the file to save_path
save_path_file_name = Path(save_path).name
# Make sure parent directory exists
save_path_parent = Path(save_path).parent
if not save_path_parent.exists():
save_path_parent.mkdir(parents=True)
print(f"Saving {save_path_file_name} {self.save_type} to local path {save_path}")
shutil.copy(local_path, save_path)
@@ -402,7 +385,11 @@ class Tracker:
def __init__(self, data_path: Optional[str] = DEFAULT_DATA_PATH, overwrite_data_path: bool = False, dummy_mode: bool = False):
self.data_path = Path(data_path)
if not dummy_mode:
if not overwrite_data_path:
if overwrite_data_path:
if self.data_path.exists():
shutil.rmtree(self.data_path)
self.data_path.mkdir(parents=True)
else:
assert not self.data_path.exists(), f'Data path {self.data_path} already exists. Set overwrite_data_path to True to overwrite.'
if not self.data_path.exists():
self.data_path.mkdir(parents=True)
@@ -411,51 +398,7 @@ class Tracker:
self.savers: List[BaseSaver]= []
self.dummy_mode = dummy_mode
def _load_auto_resume(self) -> bool:
# If the file does not exist, we return False. If autoresume is enabled we print a warning so that the user can know that this is the first run.
if not self.auto_resume_path.exists():
if self.logger.auto_resume:
print("Auto_resume is enabled but no auto_resume.json file exists. Assuming this is the first run.")
return False
# Now we know that the autoresume file exists, but if we are not auto resuming we should remove it so that we don't accidentally load it next time
if not self.logger.auto_resume:
print(f'Removing auto_resume.json because auto_resume is not enabled in the config')
self.auto_resume_path.unlink()
return False
# Otherwise we read the json into a dictionary that will override parts of logger.__dict__
with open(self.auto_resume_path, 'r') as f:
auto_resume_dict = json.load(f)
# Check if the logger is of the same type as the autoresume save
if auto_resume_dict["logger_type"] != self.logger.__class__.__name__:
raise Exception(f'The logger type in the auto_resume file is {auto_resume_dict["logger_type"]} but the current logger is {self.logger.__class__.__name__}. Either use the original logger type, set `auto_resume` to `False`, or delete your existing tracker-data folder.')
# Then we are ready to override the logger with the autoresume save
self.logger.__dict__["resume"] = True
print(f"Updating {self.logger.__dict__} with {auto_resume_dict}")
self.logger.__dict__.update(auto_resume_dict)
return True
def _save_auto_resume(self):
# Gets the autoresume dict from the logger and adds "logger_type" to it then saves it to the auto_resume file
auto_resume_dict = self.logger.get_resume_data()
auto_resume_dict['logger_type'] = self.logger.__class__.__name__
with open(self.auto_resume_path, 'w') as f:
json.dump(auto_resume_dict, f)
def init(self, full_config: BaseModel, extra_config: dict):
self.auto_resume_path = self.data_path / 'auto_resume.json'
# Check for resuming the run
self.did_auto_resume = self._load_auto_resume()
if self.did_auto_resume:
print(f'\n\nWARNING: RUN HAS BEEN AUTO-RESUMED WITH THE LOGGER TYPE {self.logger.__class__.__name__}.\nIf this was not your intention, stop this run and set `auto_resume` to `False` in the config.\n\n')
print(f"New logger config: {self.logger.__dict__}")
self.save_metadata = dict(
version = version.parse(__version__)
) # Data that will be saved alongside the checkpoint or model
self.blacklisted_checkpoint_metadata_keys = ['scaler', 'optimizer', 'model', 'version', 'step', 'steps'] # These keys would cause us to error if we try to save them as metadata
assert self.logger is not None, '`logger` must be set before `init` is called'
if self.dummy_mode:
# The only thing we need is a loader
@@ -463,17 +406,12 @@ class Tracker:
self.loader.init(self.logger)
return
assert len(self.savers) > 0, '`savers` must be set before `init` is called'
self.logger.init(full_config, extra_config)
if self.loader is not None:
self.loader.init(self.logger)
for saver in self.savers:
saver.init(self.logger)
if self.logger.auto_resume:
# Then we need to save the autoresume file. It is assumed after logger.init is called that the logger is ready to be saved.
self._save_auto_resume()
def add_logger(self, logger: BaseLogger):
self.logger = logger
@@ -504,15 +442,8 @@ class Tracker:
# Save the config under config_name in the root folder of data_path
shutil.copy(current_config_path, self.data_path / config_name)
for saver in self.savers:
if saver.saving_meta:
remote_path = Path(saver.save_meta_to) / config_name
saver.save_file(current_config_path, str(remote_path))
def add_save_metadata(self, state_dict_key: str, metadata: Any):
"""
Adds a new piece of metadata that will be saved along with the model or decoder.
"""
self.save_metadata[state_dict_key] = metadata
remote_path = Path(saver.save_meta_to) / config_name
saver.save_file(current_config_path, str(remote_path))
def _save_state_dict(self, trainer: Union[DiffusionPriorTrainer, DecoderTrainer], save_type: str, file_path: str, **kwargs) -> Path:
"""
@@ -522,38 +453,24 @@ class Tracker:
"""
assert save_type in ['checkpoint', 'model']
if save_type == 'checkpoint':
# Create a metadata dict without the blacklisted keys so we do not error when we create the state dict
metadata = {k: v for k, v in self.save_metadata.items() if k not in self.blacklisted_checkpoint_metadata_keys}
trainer.save(file_path, overwrite=True, **kwargs, **metadata)
trainer.save(file_path, overwrite=True, **kwargs)
elif save_type == 'model':
if isinstance(trainer, DiffusionPriorTrainer):
prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior
prior: DiffusionPrior = trainer.accelerator.unwrap_model(prior)
# Remove CLIP if it is part of the model
original_clip = prior.clip
prior.clip = None
model_state_dict = prior.state_dict()
prior.clip = original_clip
state_dict = trainer.unwrap_model(prior).state_dict()
torch.save(state_dict, file_path)
elif isinstance(trainer, DecoderTrainer):
decoder: Decoder = trainer.accelerator.unwrap_model(trainer.decoder)
# Remove CLIP if it is part of the model
original_clip = decoder.clip
decoder.clip = None
decoder = trainer.accelerator.unwrap_model(trainer.decoder)
if trainer.use_ema:
trainable_unets = decoder.unets
decoder.unets = trainer.unets # Swap EMA unets in
model_state_dict = decoder.state_dict()
state_dict = decoder.state_dict()
decoder.unets = trainable_unets # Swap back
else:
model_state_dict = decoder.state_dict()
decoder.clip = original_clip
state_dict = decoder.state_dict()
torch.save(state_dict, file_path)
else:
raise NotImplementedError('Saving this type of model with EMA mode enabled is not yet implemented. Actually, how did you get here?')
state_dict = {
**self.save_metadata,
'model': model_state_dict
}
torch.save(state_dict, file_path)
return Path(file_path)
def save(self, trainer, is_best: bool, is_latest: bool, **kwargs):
@@ -586,16 +503,11 @@ class Tracker:
self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs)
print(f'Error saving checkpoint: {e}')
@property
def can_recall(self):
# Defines whether a recall can be performed.
return self.loader is not None and (not self.loader.only_auto_resume or self.did_auto_resume)
def recall(self):
if self.can_recall:
if self.loader is not None:
return self.loader.recall()
else:
raise ValueError('Tried to recall, but no loader was set or auto-resume was not performed.')
raise ValueError('No loader specified')


@@ -1,16 +1,14 @@
import json
from torchvision import transforms as T
from pydantic import BaseModel, validator, model_validator
from typing import List, Optional, Union, Tuple, Dict, Any, TypeVar
from pydantic import BaseModel, validator, root_validator
from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
from x_clip import CLIP as XCLIP
from open_clip import list_pretrained
from coca_pytorch import CoCa
from dalle2_pytorch.dalle2_pytorch import (
CoCaAdapter,
OpenAIClipAdapter,
OpenClipAdapter,
Unet,
Decoder,
DiffusionPrior,
@@ -27,9 +25,11 @@ def exists(val):
def default(val, d):
return val if exists(val) else d
InnerType = TypeVar('InnerType')
ListOrTuple = Union[List[InnerType], Tuple[InnerType]]
SingularOrIterable = Union[InnerType, ListOrTuple[InnerType]]
def ListOrTuple(inner_type):
return Union[List[inner_type], Tuple[inner_type]]
def SingularOrIterable(inner_type):
return Union[inner_type, ListOrTuple(inner_type)]
# general pydantic classes
@@ -38,17 +38,15 @@ class TrainSplitConfig(BaseModel):
val: float = 0.15
test: float = 0.1
@model_validator(mode = 'after')
def validate_all(self, m):
actual_sum = sum([*dict(self).values()])
@root_validator
def validate_all(cls, fields):
actual_sum = sum([*fields.values()])
if actual_sum != 1.:
raise ValueError(f'{dict(self).keys()} must sum to 1.0. Found: {actual_sum}')
return self
raise ValueError(f'{fields.keys()} must sum to 1.0. Found: {actual_sum}')
return fields
class TrackerLogConfig(BaseModel):
log_type: str = 'console'
resume: bool = False # For logs that are saved to unique locations, resume a previous run
auto_resume: bool = False # If the process crashes and restarts, resume from the run that crashed
verbose: bool = False
class Config:
@@ -59,10 +57,8 @@ class TrackerLogConfig(BaseModel):
kwargs = self.dict()
return create_logger(self.log_type, data_path, **kwargs)
class TrackerLoadConfig(BaseModel):
load_from: Optional[str] = None
only_auto_resume: bool = False # Only attempt to load if the logger is auto-resuming
class Config:
extra = "allow"
@@ -90,7 +86,7 @@ class TrackerConfig(BaseModel):
data_path: str = '.tracker_data'
overwrite_data_path: bool = False
log: TrackerLogConfig
load: Optional[TrackerLoadConfig] = None
load: Optional[TrackerLoadConfig]
save: Union[List[TrackerSaveConfig], TrackerSaveConfig]
def create(self, full_config: BaseModel, extra_config: dict, dummy_mode: bool = False) -> Tracker:
@@ -115,15 +111,11 @@ class TrackerConfig(BaseModel):
class AdapterConfig(BaseModel):
make: str = "openai"
model: str = "ViT-L/14"
base_model_kwargs: Optional[Dict[str, Any]] = None
base_model_kwargs: Dict[str, Any] = None
def create(self):
if self.make == "openai":
return OpenAIClipAdapter(self.model)
elif self.make == "open_clip":
pretrained = dict(list_pretrained())
checkpoint = pretrained[self.model]
return OpenClipAdapter(name=self.model, pretrained=checkpoint)
elif self.make == "x-clip":
return XClipAdapter(XCLIP(**self.base_model_kwargs))
elif self.make == "coca":
@@ -134,15 +126,13 @@ class AdapterConfig(BaseModel):
class DiffusionPriorNetworkConfig(BaseModel):
dim: int
depth: int
max_text_len: Optional[int] = None
num_timesteps: Optional[int] = None
num_timesteps: int = None
num_time_embeds: int = 1
num_image_embeds: int = 1
num_text_embeds: int = 1
dim_head: int = 64
heads: int = 8
ff_mult: int = 4
norm_in: bool = False
norm_out: bool = True
attn_dropout: float = 0.
ff_dropout: float = 0.
@@ -150,21 +140,17 @@ class DiffusionPriorNetworkConfig(BaseModel):
normformer: bool = False
rotary_emb: bool = True
class Config:
extra = "allow"
def create(self):
kwargs = self.dict()
return DiffusionPriorNetwork(**kwargs)
class DiffusionPriorConfig(BaseModel):
clip: Optional[AdapterConfig] = None
clip: AdapterConfig = None
net: DiffusionPriorNetworkConfig
image_embed_dim: int
image_size: int
image_channels: int = 3
timesteps: int = 1000
sample_timesteps: Optional[int] = None
cond_drop_prob: float = 0.
loss_type: str = 'l2'
predict_x_start: bool = True
@@ -195,26 +181,23 @@ class DiffusionPriorTrainConfig(BaseModel):
use_ema: bool = True
ema_beta: float = 0.99
amp: bool = False
warmup_steps: Optional[int] = None # number of warmup steps
save_every_seconds: int = 3600 # how often to save
eval_timesteps: List[int] = [64] # which sampling timesteps to evaluate with
best_validation_loss: float = 1e9 # the current best validation loss observed
current_epoch: int = 0 # the current epoch
num_samples_seen: int = 0 # the current number of samples seen
random_seed: int = 0 # manual seed for torch
save_every: int = 10000 # what steps to save on
class DiffusionPriorDataConfig(BaseModel):
image_url: str # path to embeddings folder
meta_url: str # path to metadata (captions) for images
splits: TrainSplitConfig # define train, validation, test splits for your dataset
batch_size: int # per-gpu batch size used to train the model
num_data_points: int = 25e7 # total number of datapoints to train on
eval_every_seconds: int = 3600 # validation statistics will be performed this often
image_url: str # path to embeddings folder
meta_url: str # path to metadata (captions) for images
splits: TrainSplitConfig
batch_size: int = 64
class DiffusionPriorLoadConfig(BaseModel):
source: str = None
resume: bool = False
class TrainDiffusionPriorConfig(BaseModel):
prior: DiffusionPriorConfig
data: DiffusionPriorDataConfig
train: DiffusionPriorTrainConfig
load: DiffusionPriorLoadConfig
tracker: TrackerConfig
@classmethod
@@ -227,31 +210,29 @@ class TrainDiffusionPriorConfig(BaseModel):
class UnetConfig(BaseModel):
dim: int
dim_mults: ListOrTuple[int]
image_embed_dim: Optional[int] = None
text_embed_dim: Optional[int] = None
cond_on_text_encodings: Optional[bool] = None
cond_dim: Optional[int] = None
dim_mults: ListOrTuple(int)
image_embed_dim: int = None
text_embed_dim: int = None
cond_on_text_encodings: bool = None
cond_dim: int = None
channels: int = 3
self_attn: SingularOrIterable[bool] = False
self_attn: ListOrTuple(int)
attn_dim_head: int = 32
attn_heads: int = 16
init_cross_embed: bool = True
class Config:
extra = "allow"
class DecoderConfig(BaseModel):
unets: ListOrTuple[UnetConfig]
image_size: Optional[int] = None
image_sizes: ListOrTuple[int] = None
clip: Optional[AdapterConfig] = None # The clip model to use if embeddings are not provided
unets: ListOrTuple(UnetConfig)
image_size: int = None
image_sizes: ListOrTuple(int) = None
clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
channels: int = 3
timesteps: int = 1000
sample_timesteps: Optional[SingularOrIterable[Optional[int]]] = None
loss_type: str = 'l2'
beta_schedule: Optional[ListOrTuple[str]] = None # None means all cosine
learned_variance: SingularOrIterable[bool] = True
beta_schedule: ListOrTuple(str) = 'cosine'
learned_variance: bool = True
image_cond_drop_prob: float = 0.1
text_cond_drop_prob: float = 0.5
@@ -278,9 +259,9 @@ class DecoderConfig(BaseModel):
extra = "allow"
class DecoderDataConfig(BaseModel):
webdataset_base_url: str # path to a webdataset with jpg images
img_embeddings_url: Optional[str] = None # path to .npy files with embeddings
text_embeddings_url: Optional[str] = None # path to .npy files with embeddings
webdataset_base_url: str # path to a webdataset with jpg images
img_embeddings_url: Optional[str] # path to .npy files with embeddings
text_embeddings_url: Optional[str] # path to .npy files with embeddings
num_workers: int = 4
batch_size: int = 64
start_shard: int = 0
@@ -310,30 +291,33 @@ class DecoderDataConfig(BaseModel):
class DecoderTrainConfig(BaseModel):
epochs: int = 20
lr: SingularOrIterable[float] = 1e-4
wd: SingularOrIterable[float] = 0.01
warmup_steps: Optional[SingularOrIterable[int]] = None
lr: SingularOrIterable(float) = 1e-4
wd: SingularOrIterable(float) = 0.01
warmup_steps: Optional[SingularOrIterable(int)] = None
find_unused_parameters: bool = True
static_graph: bool = True
max_grad_norm: SingularOrIterable[float] = 0.5
max_grad_norm: SingularOrIterable(float) = 0.5
save_every_n_samples: int = 100000
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
cond_scale: Union[float, List[float]] = 1.0
device: str = 'cuda:0'
epoch_samples: Optional[int] = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
validation_samples: Optional[int] = None # Same as above but for validation.
save_immediately: bool = False
epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
validation_samples: int = None # Same as above but for validation.
use_ema: bool = True
ema_beta: float = 0.999
amp: bool = False
unet_training_mask: Optional[ListOrTuple[bool]] = None # If None, use all unets
unet_training_mask: ListOrTuple(bool) = None # If None, use all unets
class DecoderEvaluateConfig(BaseModel):
n_evaluation_samples: int = 1000
FID: Optional[Dict[str, Any]] = None
IS: Optional[Dict[str, Any]] = None
KID: Optional[Dict[str, Any]] = None
LPIPS: Optional[Dict[str, Any]] = None
FID: Dict[str, Any] = None
IS: Dict[str, Any] = None
KID: Dict[str, Any] = None
LPIPS: Dict[str, Any] = None
class DecoderLoadConfig(BaseModel):
source: str = None # Supports file and wandb
run_path: str = '' # Used only if source is wandb
file_path: str = '' # The local filepath if source is file. If source is wandb, the relative path to the model file in wandb.
resume: bool = False # If using wandb, whether to resume the run
class TrainDecoderConfig(BaseModel):
decoder: DecoderConfig
@@ -347,14 +331,11 @@ class TrainDecoderConfig(BaseModel):
def from_json_path(cls, json_path):
with open(json_path) as f:
config = json.load(f)
print(config)
return cls(**config)
@model_validator(mode = 'after')
def check_has_embeddings(self, m):
@root_validator
def check_has_embeddings(cls, values):
# Makes sure that enough information is provided to get the embeddings specified for training
values = dict(self)
data_config, decoder_config = values.get('data'), values.get('decoder')
if not exists(data_config) or not exists(decoder_config):
@@ -379,4 +360,4 @@ class TrainDecoderConfig(BaseModel):
if text_emb_url:
assert using_text_embeddings, "Text embeddings are being loaded, but text embeddings are not being conditioned on. This will slow down the dataloader for no reason."
return m
return values


@@ -9,7 +9,7 @@ from collections.abc import Iterable
import torch
import torch.nn.functional as F
from torch import nn
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
from torch.optim.lr_scheduler import LambdaLR
from torch.cuda.amp import autocast, GradScaler
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
@@ -21,7 +21,7 @@ import pytorch_warmup as warmup
from ema_pytorch import EMA
from accelerate import Accelerator, DistributedType
from accelerate import Accelerator
import numpy as np
@@ -76,7 +76,6 @@ def cast_torch_tensor(fn):
def inner(model, *args, **kwargs):
device = kwargs.pop('_device', next(model.parameters()).device)
cast_device = kwargs.pop('_cast_device', True)
cast_deepspeed_precision = kwargs.pop('_cast_deepspeed_precision', True)
kwargs_keys = kwargs.keys()
all_args = (*args, *kwargs.values())
@@ -86,21 +85,6 @@ def cast_torch_tensor(fn):
if cast_device:
all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
if cast_deepspeed_precision:
try:
accelerator = model.accelerator
if accelerator is not None and accelerator.distributed_type == DistributedType.DEEPSPEED:
cast_type_map = {
"fp16": torch.half,
"bf16": torch.bfloat16,
"no": torch.float
}
precision_type = cast_type_map[accelerator.mixed_precision]
all_args = tuple(map(lambda t: t.to(precision_type) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
except AttributeError:
# Then this model doesn't have an accelerator
pass
args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))
@@ -174,25 +158,26 @@ class DiffusionPriorTrainer(nn.Module):
def __init__(
self,
diffusion_prior,
accelerator = None,
use_ema = True,
lr = 3e-4,
wd = 1e-2,
eps = 1e-6,
max_grad_norm = None,
amp = False,
group_wd_params = True,
warmup_steps = None,
cosine_decay_max_steps = None,
device = None,
accelerator = None,
verbose = True,
**kwargs
):
super().__init__()
assert isinstance(diffusion_prior, DiffusionPrior)
assert not exists(accelerator) or isinstance(accelerator, Accelerator)
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
accelerator_kwargs, kwargs = groupby_prefix_and_trim('accelerator_', kwargs)
if not exists(accelerator):
accelerator = Accelerator(**accelerator_kwargs)
# verbosity
self.verbose = verbose
# assign some helpful member vars
@@ -201,31 +186,22 @@ class DiffusionPriorTrainer(nn.Module):
# setting the device
self.device = accelerator.device
diffusion_prior.to(self.device)
if not exists(accelerator) and not exists(device):
diffusion_prior_device = next(diffusion_prior.parameters()).device
self.print(f'accelerator not given, and device not specified: defaulting to device of diffusion prior parameters - {diffusion_prior_device}')
self.device = diffusion_prior_device
else:
self.device = accelerator.device if exists(accelerator) else device
# save model
self.diffusion_prior = diffusion_prior
# mixed precision checks
# optimizer and mixed precision stuff
if (
exists(self.accelerator)
and self.accelerator.distributed_type == DistributedType.DEEPSPEED
and self.diffusion_prior.clip is not None
):
# Then we need to make sure clip is using the correct precision or else deepspeed will error
cast_type_map = {
"fp16": torch.half,
"bf16": torch.bfloat16,
"no": torch.float
}
precision_type = cast_type_map[accelerator.mixed_precision]
assert precision_type == torch.float, "DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip"
self.diffusion_prior.clip.to(precision_type)
self.amp = amp
# optimizer stuff
self.scaler = GradScaler(enabled = amp)
self.optim_kwargs = dict(lr=lr, wd=wd, eps=eps, group_wd_params=group_wd_params)
@@ -235,23 +211,16 @@ class DiffusionPriorTrainer(nn.Module):
**kwargs
)
if exists(cosine_decay_max_steps):
self.scheduler = CosineAnnealingLR(self.optimizer, T_max = cosine_decay_max_steps)
else:
self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
self.warmup_scheduler = warmup.LinearWarmup(self.optimizer, warmup_period = warmup_steps) if exists(warmup_steps) else None
# distribute the model if using HFA
self.diffusion_prior, self.optimizer, self.scheduler = self.accelerator.prepare(self.diffusion_prior, self.optimizer, self.scheduler)
if exists(self.accelerator):
self.diffusion_prior, self.optimizer = self.accelerator.prepare(self.diffusion_prior, self.optimizer)
# exponential moving average stuff
self.use_ema = use_ema
if self.use_ema:
self.ema_diffusion_prior = EMA(self.accelerator.unwrap_model(self.diffusion_prior), **ema_kwargs)
self.ema_diffusion_prior = EMA(self.unwrap_model(self.diffusion_prior), **ema_kwargs)
# gradient clipping if needed
@@ -261,25 +230,67 @@ class DiffusionPriorTrainer(nn.Module):
self.register_buffer('step', torch.tensor([0], device = self.device))
# accelerator wrappers
def print(self, msg):
if not self.verbose:
return
if exists(self.accelerator):
self.accelerator.print(msg)
else:
print(msg)
def unwrap_model(self, model):
if exists(self.accelerator):
return self.accelerator.unwrap_model(model)
else:
return model
def wait_for_everyone(self):
if exists(self.accelerator):
self.accelerator.wait_for_everyone()
def is_main_process(self):
if exists(self.accelerator):
return self.accelerator.is_main_process
else:
return True
def clip_grad_norm_(self, *args):
if exists(self.accelerator):
return self.accelerator.clip_grad_norm_(*args)
else:
return torch.nn.utils.clip_grad_norm_(*args)
def backprop(self, x):
if exists(self.accelerator):
self.accelerator.backward(x)
else:
try:
x.backward()
except Exception as e:
self.print(f"Caught error in backprop call: {e}")
# utility
def save(self, path, overwrite = True, **kwargs):
# ensure we sync gradients before continuing
self.wait_for_everyone()
# only save on the main process
if self.accelerator.is_main_process:
print(f"Saving checkpoint at step: {self.step.item()}")
if self.is_main_process():
self.print(f"Saving checkpoint at step: {self.step.item()}")
path = Path(path)
assert not (path.exists() and not overwrite)
path.parent.mkdir(parents = True, exist_ok = True)
# FIXME: LambdaLR can't be saved due to pickling issues
save_obj = dict(
scaler = self.scaler.state_dict(),
optimizer = self.optimizer.state_dict(),
scheduler = self.scheduler.state_dict(),
warmup_scheduler = self.warmup_scheduler,
model = self.accelerator.unwrap_model(self.diffusion_prior).state_dict(),
model = self.unwrap_model(self.diffusion_prior).state_dict(), # unwrap the model from distribution if applicable
version = version.parse(__version__),
step = self.step,
step = self.step.item(),
**kwargs
)
@@ -292,14 +303,14 @@ class DiffusionPriorTrainer(nn.Module):
torch.save(save_obj, str(path))
def load(self, path_or_state, overwrite_lr = True, strict = True):
def load(self, path, overwrite_lr = True, strict = True):
"""
Load a checkpoint of a diffusion prior trainer.
Will load the entire trainer, including the optimizer and EMA.
Params:
- path_or_state (str | torch): a path to the DiffusionPriorTrainer checkpoint file
- path (str): a path to the DiffusionPriorTrainer checkpoint file
- overwrite_lr (bool): whether or not to overwrite the stored LR with the LR specified in the new trainer
- strict (bool): kwarg for `torch.nn.Module.load_state_dict`, will force an exact checkpoint match
@@ -308,58 +319,55 @@ class DiffusionPriorTrainer(nn.Module):
"""
# all processes need to load checkpoint. no restriction here
if isinstance(path_or_state, str):
path = Path(path_or_state)
assert path.exists()
loaded_obj = torch.load(str(path), map_location=self.device)
path = Path(path)
assert path.exists()
elif isinstance(path_or_state, dict):
loaded_obj = path_or_state
loaded_obj = torch.load(str(path), map_location=self.device)
if version.parse(__version__) != loaded_obj['version']:
print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {__version__}')
# unwrap the model when loading from checkpoint
self.accelerator.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
self.step.copy_(torch.ones_like(self.step, device=self.device) * loaded_obj['step'].to(self.device))
self.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
self.scaler.load_state_dict(loaded_obj['scaler'])
self.optimizer.load_state_dict(loaded_obj['optimizer'])
self.scheduler.load_state_dict(loaded_obj['scheduler'])
# set warmup step
if exists(self.warmup_scheduler):
self.warmup_scheduler.last_step = self.step.item()
# ensure new lr is used if different from old one
if overwrite_lr:
new_lr = self.optim_kwargs["lr"]
self.print(f"Overriding LR to be {new_lr}")
for group in self.optimizer.param_groups:
group["lr"] = new_lr if group["lr"] > 0.0 else 0.0
group["lr"] = new_lr
if self.use_ema:
assert 'ema' in loaded_obj
self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
# below might not be necessary, but I had a suspicion that this wasn't being loaded correctly
# below may not be necessary, but I had a suspicion that this wasn't being loaded correctly
self.ema_diffusion_prior.ema_model.load_state_dict(loaded_obj["ema_model"])
# sync and inform
self.wait_for_everyone()
self.print(f"Loaded model")
return loaded_obj
# model functionality
def update(self):
# only continue with updates until all ranks finish
self.wait_for_everyone()
if exists(self.max_grad_norm):
self.accelerator.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
self.optimizer.step()
self.optimizer.zero_grad()
self.scaler.unscale_(self.optimizer)
# utilize HFA clipping where applicable
self.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
# accelerator will occasionally skip optimizer steps in a "dynamic loss scaling strategy"
if not self.accelerator.optimizer_step_was_skipped:
sched_context = self.warmup_scheduler.dampening if exists(self.warmup_scheduler) else nullcontext
with sched_context():
self.scheduler.step()
self.scaler.step(self.optimizer)
self.scaler.update()
self.optimizer.zero_grad()
if self.use_ema:
self.ema_diffusion_prior.update()
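For reference, the non-accelerator branch of `update` reduces to the usual `GradScaler` sequence (unscale, clip, step, update). A minimal sketch with a toy model standing in for the prior (assumes CPU, so the scaler is left disabled):

```python
import torch
from torch import nn
from torch.cuda.amp import GradScaler

model = nn.Linear(16, 16)                       # toy stand-in for the diffusion prior
optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-4)
scaler = GradScaler(enabled = False)            # enable only when running on CUDA with amp
max_grad_norm = 0.5

loss = model(torch.randn(8, 16)).pow(2).mean()
scaler.scale(loss).backward()

scaler.unscale_(optimizer)                      # unscale before clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
scaler.step(optimizer)                          # with amp enabled, this skips the step on inf/nan gradients
scaler.update()
optimizer.zero_grad()
```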
@@ -389,7 +397,7 @@ class DiffusionPriorTrainer(nn.Module):
@cast_torch_tensor
@prior_sample_in_chunks
def embed_text(self, *args, **kwargs):
return self.accelerator.unwrap_model(self.diffusion_prior).clip.embed_text(*args, **kwargs)
return self.unwrap_model(self.diffusion_prior).clip.embed_text(*args, **kwargs)
@cast_torch_tensor
def forward(
@@ -401,14 +409,16 @@ class DiffusionPriorTrainer(nn.Module):
total_loss = 0.
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
with self.accelerator.autocast():
with autocast(enabled = self.amp):
loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
loss = loss * chunk_size_frac
total_loss += loss.item()
# backprop with accelerate if applicable
if self.training:
self.accelerator.backward(loss)
self.backprop(self.scaler.scale(loss))
return total_loss
@@ -435,13 +445,11 @@ class DecoderTrainer(nn.Module):
self,
decoder,
accelerator = None,
dataloaders = None,
use_ema = True,
lr = 1e-4,
wd = 1e-2,
eps = 1e-8,
warmup_steps = None,
cosine_decay_max_steps = None,
max_grad_norm = 0.5,
amp = False,
group_wd_params = True,
@@ -463,7 +471,7 @@ class DecoderTrainer(nn.Module):
# be able to finely customize learning rate, weight decay
# per unet
lr, wd, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps, cosine_decay_max_steps))
lr, wd, eps, warmup_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps))
assert all([unet_lr <= 1e-2 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
@@ -471,32 +479,24 @@ class DecoderTrainer(nn.Module):
schedulers = []
warmup_schedulers = []
for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps in zip(decoder.unets, lr, wd, eps, warmup_steps, cosine_decay_max_steps):
if isinstance(unet, nn.Identity):
optimizers.append(None)
schedulers.append(None)
warmup_schedulers.append(None)
else:
optimizer = get_optimizer(
unet.parameters(),
lr = unet_lr,
wd = unet_wd,
eps = unet_eps,
group_wd_params = group_wd_params,
**kwargs
)
for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps in zip(decoder.unets, lr, wd, eps, warmup_steps):
optimizer = get_optimizer(
unet.parameters(),
lr = unet_lr,
wd = unet_wd,
eps = unet_eps,
group_wd_params = group_wd_params,
**kwargs
)
optimizers.append(optimizer)
optimizers.append(optimizer)
if exists(unet_cosine_decay_max_steps):
scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps)
else:
scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
warmup_schedulers.append(warmup_scheduler)
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
warmup_schedulers.append(warmup_scheduler)
schedulers.append(scheduler)
schedulers.append(scheduler)
if self.use_ema:
self.ema_unets.append(EMA(unet, **ema_kwargs))
@@ -507,31 +507,11 @@ class DecoderTrainer(nn.Module):
self.register_buffer('steps', torch.tensor([0] * self.num_unets))
if self.accelerator.distributed_type == DistributedType.DEEPSPEED and decoder.clip is not None:
# Then we need to make sure clip is using the correct precision or else deepspeed will error
cast_type_map = {
"fp16": torch.half,
"bf16": torch.bfloat16,
"no": torch.float
}
precision_type = cast_type_map[accelerator.mixed_precision]
assert precision_type == torch.float, "DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip"
clip = decoder.clip
clip.to(precision_type)
decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
schedulers = list(self.accelerator.prepare(*schedulers))
self.decoder = decoder
# prepare dataloaders
train_loader = val_loader = None
if exists(dataloaders):
train_loader, val_loader = self.accelerator.prepare(dataloaders["train"], dataloaders["val"])
self.train_loader = train_loader
self.val_loader = val_loader
# store optimizers
for opt_ind, optimizer in zip(range(len(optimizers)), optimizers):
@@ -546,17 +526,6 @@ class DecoderTrainer(nn.Module):
self.warmup_schedulers = warmup_schedulers
def validate_and_return_unet_number(self, unet_number = None):
if self.num_unets == 1:
unet_number = default(unet_number, 1)
assert exists(unet_number) and 1 <= unet_number <= self.num_unets
return unet_number
def num_steps_taken(self, unet_number = None):
unet_number = self.validate_and_return_unet_number(unet_number)
return self.steps[unet_number - 1].item()
def save(self, path, overwrite = True, **kwargs):
path = Path(path)
assert not (path.exists() and not overwrite)
@@ -571,15 +540,8 @@ class DecoderTrainer(nn.Module):
for ind in range(0, self.num_unets):
optimizer_key = f'optim{ind}'
scheduler_key = f'sched{ind}'
optimizer = getattr(self, optimizer_key)
scheduler = getattr(self, scheduler_key)
optimizer_state_dict = optimizer.state_dict() if exists(optimizer) else None
scheduler_state_dict = scheduler.state_dict() if exists(scheduler) else None
save_obj = {**save_obj, optimizer_key: optimizer_state_dict, scheduler_key: scheduler_state_dict}
save_obj = {**save_obj, optimizer_key: self.accelerator.unwrap_model(optimizer).state_dict()}
if self.use_ema:
save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
@@ -600,17 +562,9 @@ class DecoderTrainer(nn.Module):
optimizer_key = f'optim{ind}'
optimizer = getattr(self, optimizer_key)
scheduler_key = f'sched{ind}'
scheduler = getattr(self, scheduler_key)
warmup_scheduler = self.warmup_schedulers[ind]
if exists(optimizer):
optimizer.load_state_dict(loaded_obj[optimizer_key])
if exists(scheduler):
scheduler.load_state_dict(loaded_obj[scheduler_key])
self.accelerator.unwrap_model(optimizer).load_state_dict(loaded_obj[optimizer_key])
if exists(warmup_scheduler):
warmup_scheduler.last_step = last_step
@@ -640,7 +594,10 @@ class DecoderTrainer(nn.Module):
self.steps += F.one_hot(unet_index_tensor, num_classes = len(self.steps))
def update(self, unet_number = None):
unet_number = self.validate_and_return_unet_number(unet_number)
if self.num_unets == 1:
unet_number = default(unet_number, 1)
assert exists(unet_number) and 1 <= unet_number <= self.num_unets
index = unet_number - 1
optimizer = getattr(self, f'optim{index}')
@@ -670,14 +627,8 @@ class DecoderTrainer(nn.Module):
def sample(self, *args, **kwargs):
distributed = self.accelerator.num_processes > 1
base_decoder = self.accelerator.unwrap_model(self.decoder)
was_training = base_decoder.training
base_decoder.eval()
if kwargs.pop('use_non_ema', False) or not self.use_ema:
out = base_decoder.sample(*args, **kwargs, distributed = distributed)
base_decoder.train(was_training)
return out
return base_decoder.sample(*args, **kwargs, distributed = distributed)
trainable_unets = self.accelerator.unwrap_model(self.decoder).unets
base_decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
@@ -690,7 +641,6 @@ class DecoderTrainer(nn.Module):
for ema in self.ema_unets:
ema.restore_ema_model_device()
base_decoder.train(was_training)
return output
@torch.no_grad()
@@ -711,32 +661,21 @@ class DecoderTrainer(nn.Module):
*args,
unet_number = None,
max_batch_size = None,
return_lowres_cond_image=False,
**kwargs
):
unet_number = self.validate_and_return_unet_number(unet_number)
if self.num_unets == 1:
unet_number = default(unet_number, 1)
total_loss = 0.
cond_images = []
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
with self.accelerator.autocast():
loss_obj = self.decoder(*chunked_args, unet_number = unet_number, return_lowres_cond_image=return_lowres_cond_image, **chunked_kwargs)
# loss_obj may be a tuple with loss and cond_image
if return_lowres_cond_image:
loss, cond_image = loss_obj
else:
loss = loss_obj
cond_image = None
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
loss = loss * chunk_size_frac
if cond_image is not None:
cond_images.append(cond_image)
total_loss += loss.item()
if self.training:
self.accelerator.backward(loss)
if return_lowres_cond_image:
return total_loss, torch.stack(cond_images)
else:
return total_loss
return total_loss
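The chunked forward above is effectively gradient accumulation: each chunk's loss is weighted by its fraction of the batch so the accumulated gradients match a single full-batch pass. A self-contained toy sketch of that idea:

```python
import torch
from torch import nn

model = nn.Linear(32, 1)
images = torch.randn(10, 32)
max_batch_size = 4

total_loss = 0.
for chunk in images.split(max_batch_size):
    chunk_size_frac = chunk.shape[0] / images.shape[0]
    loss = model(chunk).pow(2).mean() * chunk_size_frac  # weight by chunk fraction
    loss.backward()                                      # gradients accumulate across chunks
    total_loss += loss.item()

# total_loss now equals the full-batch mean loss, computed in pieces
```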

View File

@@ -1 +1 @@
__version__ = '1.15.6'
__version__ = '0.16.15'

View File

@@ -11,7 +11,8 @@ import torch.nn.functional as F
from torch.autograd import grad as torch_grad
import torchvision
from einops import rearrange, reduce, repeat, pack, unpack
from einops import rearrange, reduce, repeat
from einops_exts import rearrange_many
from einops.layers.torch import Rearrange
# constants
@@ -407,7 +408,7 @@ class Attention(nn.Module):
x = self.norm(x)
q, k, v = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = h)
q = q * self.scale
sim = einsum('b h i d, b h j d -> b h i j', q, k)

View File

@@ -26,17 +26,17 @@ setup(
install_requires=[
'accelerate',
'click',
'open-clip-torch>=2.0.0,<3.0.0',
'clip-anytorch>=2.5.2',
'clip-anytorch',
'coca-pytorch>=0.0.5',
'ema-pytorch>=0.0.7',
'einops>=0.7.0',
'einops>=0.4',
'einops-exts>=0.0.3',
'embedding-reader',
'kornia>=0.5.4',
'numpy',
'packaging',
'pillow',
'pydantic>=2',
'pydantic',
'pytorch-warmup',
'resize-right>=0.0.2',
'rotary-embedding-torch',

10 changed binary files are not shown.

View File

@@ -1,6 +1,5 @@
from pathlib import Path
from typing import List
from datetime import timedelta
from dalle2_pytorch.trainer import DecoderTrainer
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
@@ -12,12 +11,11 @@ from clip import tokenize
import torchvision
import torch
from torch import nn
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.kid import KernelInceptionDistance
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from accelerate import Accelerator, DistributedDataParallelKwargs, InitProcessGroupKwargs
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.utils import dataclasses as accelerate_dataclasses
import webdataset as wds
import click
@@ -134,7 +132,7 @@ def get_example_data(dataloader, device, n=5):
break
return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))
def generate_samples(trainer, example_data, clip=None, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend="", match_image_size=True):
def generate_samples(trainer, example_data, condition_on_text_encodings=False, text_prepend=""):
"""
Takes example data and generates images from the embeddings
Returns three lists: real images, generated images, and captions
@@ -144,9 +142,7 @@ def generate_samples(trainer, example_data, clip=None, start_unet=1, end_unet=No
if img_embeddings[0] is None:
# Generate image embeddings from clip
imgs_tensor = torch.stack(real_images)
assert clip is not None, "clip is None, but img_embeddings is None"
imgs_tensor.to(device=device)
img_embeddings, img_encoding = clip.embed_image(imgs_tensor)
img_embeddings, *_ = trainer.embed_image(imgs_tensor)
sample_params["image_embed"] = img_embeddings
else:
# Then we are using precomputed image embeddings
@@ -155,38 +151,34 @@ def generate_samples(trainer, example_data, clip=None, start_unet=1, end_unet=No
if condition_on_text_encodings:
if text_embeddings[0] is None:
# Generate text embeddings from text
assert clip is not None, "clip is None, but text_embeddings is None"
tokenized_texts = tokenize(txts, truncate=True).to(device=device)
text_embed, text_encodings = clip.embed_text(tokenized_texts)
sample_params["text_encodings"] = text_encodings
tokenized_texts = tokenize(txts, truncate=True)
sample_params["text"] = tokenized_texts
else:
# Then we are using precomputed text embeddings
text_embeddings = torch.stack(text_embeddings)
sample_params["text_encodings"] = text_embeddings
sample_params["start_at_unet_number"] = start_unet
sample_params["stop_at_unet_number"] = end_unet
if start_unet > 1:
# If we are only training upsamplers
sample_params["image"] = torch.stack(real_images)
if device is not None:
sample_params["_device"] = device
samples = trainer.sample(**sample_params, _cast_deepspeed_precision=False) # At sampling time we don't want to cast to FP16
samples = trainer.sample(**sample_params)
generated_images = list(samples)
captions = [text_prepend + txt for txt in txts]
if match_image_size:
generated_image_size = generated_images[0].shape[-1]
real_images = [resize_image_to(image, generated_image_size, clamp_range=(0, 1)) for image in real_images]
return real_images, generated_images, captions
def generate_grid_samples(trainer, examples, clip=None, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend=""):
def generate_grid_samples(trainer, examples, condition_on_text_encodings=False, text_prepend=""):
"""
Generates samples and uses torchvision to put them in a side by side grid for easy viewing
"""
real_images, generated_images, captions = generate_samples(trainer, examples, clip, start_unet, end_unet, condition_on_text_encodings, cond_scale, device, text_prepend)
real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings, text_prepend)
real_image_size = real_images[0].shape[-1]
generated_image_size = generated_images[0].shape[-1]
# training images may be larger than the generated ones
if real_image_size > generated_image_size:
real_images = [resize_image_to(image, generated_image_size) for image in real_images]
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
return grid_images, captions
def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, clip=None, condition_on_text_encodings=False, cond_scale=1.0, inference_device=None, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
def evaluate_trainer(trainer, dataloader, device, condition_on_text_encodings=False, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
"""
Computes evaluation metrics for the decoder
"""
@@ -196,7 +188,7 @@ def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, clip=Non
if len(examples) == 0:
print("No data to evaluate. Check that your dataloader has shards.")
return metrics
real_images, generated_images, captions = generate_samples(trainer, examples, clip, start_unet, end_unet, condition_on_text_encodings, cond_scale, inference_device)
real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings)
real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
# Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
@@ -229,8 +221,8 @@ def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, clip=Non
metrics["KID_std"] = kid_std.item()
if exists(LPIPS):
# Convert from [0, 1] to [-1, 1]
renorm_real_images = real_images.mul(2).sub(1).clamp(-1,1)
renorm_generated_images = generated_images.mul(2).sub(1).clamp(-1,1)
renorm_real_images = real_images.mul(2).sub(1)
renorm_generated_images = generated_images.mul(2).sub(1)
lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS, dist_sync_fn=null_sync)
lpips.to(device=device)
lpips.update(renorm_real_images, renorm_generated_images)
@@ -269,17 +261,14 @@ def train(
accelerator: Accelerator,
tracker: Tracker,
inference_device,
clip=None,
evaluate_config=None,
epoch_samples = None, # If the training dataset is resampling, we have to manually stop an epoch
validation_samples = None,
save_immediately=False,
epochs = 20,
n_sample_images = 5,
save_every_n_samples = 100000,
unet_training_mask=None,
condition_on_text_encodings=False,
cond_scale=1.0,
**kwargs
):
"""
@@ -287,25 +276,9 @@ def train(
"""
is_master = accelerator.process_index == 0
if not exists(unet_training_mask):
# Then the unet mask should be true for all unets in the decoder
unet_training_mask = [True] * len(decoder.unets)
assert len(unet_training_mask) == len(decoder.unets), f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {len(decoder.unets)}"
trainable_unet_numbers = [i+1 for i, trainable in enumerate(unet_training_mask) if trainable]
first_trainable_unet = trainable_unet_numbers[0]
last_trainable_unet = trainable_unet_numbers[-1]
def move_unets(unet_training_mask):
for i in range(len(decoder.unets)):
if not unet_training_mask[i]:
# Replace the unet from the module list with a nn.Identity(). This training script never uses unets that aren't being trained so this is fine.
decoder.unets[i] = nn.Identity().to(inference_device)
# Remove non-trainable unets
move_unets(unet_training_mask)
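The masking amounts to swapping non-trainable entries of the unet `ModuleList` for `nn.Identity`. A toy sketch with linear layers standing in for unets:

```python
import torch
from torch import nn

unets = nn.ModuleList([nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8)])
unet_training_mask = [False, True, False]   # only train the middle unet

for i, trainable in enumerate(unet_training_mask):
    if not trainable:
        unets[i] = nn.Identity()            # frozen unets are never used by this script

print(unets)                                # Identity, Linear, Identity
```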
trainer = DecoderTrainer(
decoder=decoder,
accelerator=accelerator,
dataloaders=dataloaders,
**kwargs
)
@@ -316,9 +289,9 @@ def train(
sample = 0
samples_seen = 0
val_sample = 0
step = lambda: int(trainer.num_steps_taken(unet_number=first_trainable_unet))
step = lambda: int(trainer.step.item())
if tracker.can_recall:
if tracker.loader is not None:
start_epoch, validation_losses, next_task, recalled_sample, samples_seen = recall_trainer(tracker, trainer)
if next_task == 'train':
sample = recalled_sample
@@ -328,6 +301,11 @@ def train(
accelerator.print(f"Starting training from task {next_task} at sample {sample} and validation sample {val_sample}")
trainer.to(device=inference_device)
if not exists(unet_training_mask):
# Then the unet mask should be true for all unets in the decoder
unet_training_mask = [True] * trainer.num_unets
assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"
accelerator.print(print_ribbon("Generating Example Data", repeat=40))
accelerator.print("This can take a while to load the shard lists...")
if is_master:
@@ -376,20 +354,15 @@ def train(
forward_params['image_embed'] = img_emb
else:
# Forward pass automatically generates embedding
assert clip is not None
img_embed, img_encoding = clip.embed_image(img)
forward_params['image_embed'] = img_embed
pass
if condition_on_text_encodings:
if has_text_embedding:
forward_params['text_encodings'] = text_emb
else:
# Then we need to pass the text instead
assert clip is not None
tokenized_texts = tokenize(txt, truncate=True).to(inference_device)
assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
text_embed, text_encodings = clip.embed_text(tokenized_texts)
forward_params['text_encodings'] = text_encodings
loss = trainer.forward(img, **forward_params, unet_number=unet, _device=inference_device)
tokenized_texts = tokenize(txt, truncate=True)
forward_params['text'] = tokenized_texts
loss = trainer.forward(img, **forward_params, unet_number=unet)
trainer.update(unet_number=unet)
unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet-1] = loss
@@ -402,10 +375,10 @@ def train(
unet_all_losses = accelerator.gather(unet_losses_tensor)
mask = unet_all_losses != 0
unet_average_loss = (unet_all_losses * mask).sum(dim=0) / mask.sum(dim=0)
loss_map = { f"Unet {index} Training Loss": loss.item() for index, loss in enumerate(unet_average_loss) if unet_training_mask[index] }
loss_map = { f"Unet {index} Training Loss": loss.item() for index, loss in enumerate(unet_average_loss) if loss != 0 }
# gather decay rate on each UNet
ema_decay_list = {f"Unet {index} EMA Decay": ema_unet.get_current_decay() for index, ema_unet in enumerate(trainer.ema_unets) if unet_training_mask[index]}
ema_decay_list = {f"Unet {index} EMA Decay": ema_unet.get_current_decay() for index, ema_unet in enumerate(trainer.ema_unets)}
log_data = {
"Epoch": epoch,
@@ -420,7 +393,7 @@ def train(
if is_master:
tracker.log(log_data, step=step())
if is_master and (last_snapshot + save_every_n_samples < sample or (save_immediately and i == 0)): # This will miss by some amount every time, but it's not a big deal... I hope
if is_master and last_snapshot + save_every_n_samples < sample: # This will miss by some amount every time, but it's not a big deal... I hope
# It is difficult to gather this kind of info on the accelerator, so we have to do it on the master
print("Saving snapshot")
last_snapshot = sample
@@ -428,7 +401,7 @@ def train(
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen)
if exists(n_sample_images) and n_sample_images > 0:
trainer.eval()
train_images, train_captions = generate_grid_samples(trainer, train_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
if epoch_samples is not None and sample >= epoch_samples:
@@ -446,7 +419,7 @@ def train(
timer = Timer()
accelerator.wait_for_everyone()
i = 0
for i, (img, emb, txt) in enumerate(dataloaders['val']): # Use the accelerate prepared loader
for i, (img, emb, txt) in enumerate(dataloaders["val"]):
val_sample_length_tensor[0] = len(img)
all_samples = accelerator.gather(val_sample_length_tensor)
total_samples = all_samples.sum().item()
@@ -471,20 +444,15 @@ def train(
forward_params['image_embed'] = img_emb.float()
else:
# Forward pass automatically generates embedding
assert clip is not None
img_embed, img_encoding = clip.embed_image(img)
forward_params['image_embed'] = img_embed
pass
if condition_on_text_encodings:
if has_text_embedding:
forward_params['text_encodings'] = text_emb.float()
else:
# Then we need to pass the text instead
assert clip is not None
tokenized_texts = tokenize(txt, truncate=True).to(device=inference_device)
assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
text_embed, text_encodings = clip.embed_text(tokenized_texts)
forward_params['text_encodings'] = text_encodings
loss = trainer.forward(img.float(), **forward_params, unet_number=unet, _device=inference_device)
tokenized_texts = tokenize(txt, truncate=True)
forward_params['text'] = tokenized_texts
loss = trainer.forward(img.float(), **forward_params, unet_number=unet)
average_val_loss_tensor[0, unet-1] += loss
if i % VALID_CALC_LOSS_EVERY_ITERS == 0:
@@ -511,7 +479,7 @@ def train(
if next_task == 'eval':
if exists(evaluate_config):
accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, first_trainable_unet, last_trainable_unet, clip=clip, inference_device=inference_device, **evaluate_config.model_dump(), condition_on_text_encodings=condition_on_text_encodings, cond_scale=cond_scale)
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings)
if is_master:
tracker.log(evaluation, step=step())
next_task = 'sample'
@@ -522,15 +490,15 @@ def train(
# Generate examples and save the model if we are the master
# Generate sample images
print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
test_images, test_captions = generate_grid_samples(trainer, test_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Test: ")
train_images, train_captions = generate_grid_samples(trainer, train_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
test_images, test_captions = generate_grid_samples(trainer, test_example_data, condition_on_text_encodings, "Test: ")
train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step())
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
print(print_ribbon(f"Starting Saving {epoch}", repeat=40))
is_best = False
if all_average_val_losses is not None:
average_loss = all_average_val_losses.mean(dim=0).sum() / sum(unet_training_mask)
average_loss = all_average_val_losses.mean(dim=0).item()
if len(validation_losses) == 0 or average_loss < min(validation_losses):
is_best = True
validation_losses.append(average_loss)
@@ -545,10 +513,8 @@ def create_tracker(accelerator: Accelerator, config: TrainDecoderConfig, config_
"NumProcesses": accelerator.num_processes,
"MixedPrecision": accelerator.mixed_precision
}
accelerator.wait_for_everyone() # If nodes arrive at this point at different times, they might try to autoresume the current run, which makes no sense and will cause errors
tracker: Tracker = tracker_config.create(config, accelerator_config, dummy_mode=dummy)
tracker.save_config(config_path, config_name='decoder_config.json')
tracker.add_save_metadata(state_dict_key='config', metadata=config.model_dump())
return tracker
def initialize_training(config: TrainDecoderConfig, config_path):
@@ -556,19 +522,8 @@ def initialize_training(config: TrainDecoderConfig, config_path):
torch.manual_seed(config.seed)
# Set up accelerator for configurable distributed training
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters, static_graph=config.train.static_graph)
init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=60*60))
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs, init_kwargs])
if accelerator.num_processes > 1:
# We are using distributed training and want to immediately ensure all can connect
accelerator.print("Waiting for all processes to connect...")
accelerator.wait_for_everyone()
accelerator.print("All processes online and connected")
# If we are in deepspeed fp16 mode, we must ensure learned variance is off
if accelerator.mixed_precision == "fp16" and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED and config.decoder.learned_variance:
raise ValueError("DeepSpeed fp16 mode does not support learned variance")
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
# Set up data
all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
@@ -577,7 +532,6 @@ def initialize_training(config: TrainDecoderConfig, config_path):
shards_per_process = len(all_shards) // world_size
assert shards_per_process > 0, "Not enough shards to split evenly"
my_shards = all_shards[rank * shards_per_process: (rank + 1) * shards_per_process]
dataloaders = create_dataloaders (
available_shards=my_shards,
img_preproc = config.data.img_preproc,
@@ -585,19 +539,14 @@ def initialize_training(config: TrainDecoderConfig, config_path):
val_prop = config.data.splits.val,
test_prop = config.data.splits.test,
n_sample_images=config.train.n_sample_images,
**config.data.model_dump(),
**config.data.dict(),
rank = rank,
seed = config.seed,
)
# If clip is in the model, we need to remove it for compatibility with deepspeed
clip = None
if config.decoder.clip is not None:
clip = config.decoder.clip.create() # Of course we keep it to use it during training, just not in the decoder as that causes issues
config.decoder.clip = None
# Create the decoder model and print basic info
decoder = config.decoder.create()
get_num_parameters = lambda model, only_training=False: sum(p.numel() for p in model.parameters() if (p.requires_grad or not only_training))
num_parameters = sum(p.numel() for p in decoder.parameters())
# Create and initialize the tracker if we are the master
tracker = create_tracker(accelerator, config, config_path, dummy = rank!=0)
@@ -606,7 +555,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
has_text_embeddings = config.data.text_embeddings_url is not None
conditioning_on_text = any([unet.cond_on_text_encodings for unet in config.decoder.unets])
has_clip_model = clip is not None
has_clip_model = config.decoder.clip is not None
data_source_string = ""
if has_img_embeddings:
@@ -626,17 +575,13 @@ def initialize_training(config: TrainDecoderConfig, config_path):
accelerator.print(print_ribbon("Loaded Config", repeat=40))
accelerator.print(f"Running training with {accelerator.num_processes} processes and {accelerator.distributed_type} distributed training")
accelerator.print(f"Training using {data_source_string}. {'conditioned on text' if conditioning_on_text else 'not conditioned on text'}")
accelerator.print(f"Number of parameters: {get_num_parameters(decoder)} total; {get_num_parameters(decoder, only_training=True)} training")
for i, unet in enumerate(decoder.unets):
accelerator.print(f"Unet {i} has {get_num_parameters(unet)} total; {get_num_parameters(unet, only_training=True)} training")
accelerator.print(f"Number of parameters: {num_parameters}")
train(dataloaders, decoder, accelerator,
clip=clip,
tracker=tracker,
inference_device=accelerator.device,
evaluate_config=config.evaluate,
condition_on_text_encodings=conditioning_on_text,
**config.train.model_dump(),
**config.train.dict(),
)
# Create a simple click command line interface to load the config and start the training

View File

@@ -1,23 +1,31 @@
# TODO: add start, num_data_points, eval_every and group to config
# TODO: switch back to repo's wandb
START = 0
NUM_DATA_POINTS = 250e6
EVAL_EVERY = 1000
GROUP = "distributed"
import os
import click
import wandb
import torch
from torch import nn
from typing import List
from accelerate import Accelerator
from accelerate.utils import set_seed
from torch.utils.data import DataLoader
from embedding_reader import EmbeddingReader
from accelerate.utils import dataclasses as accelerate_dataclasses
from dalle2_pytorch.utils import Timer
from dalle2_pytorch.trackers import Tracker
from dalle2_pytorch import DiffusionPriorTrainer
import numpy as np
from accelerate import Accelerator
from dalle2_pytorch.dataloaders import get_reader, make_splits
from dalle2_pytorch.utils import Timer
from dalle2_pytorch.train_configs import (
DiffusionPriorConfig,
DiffusionPriorTrainConfig,
TrainDiffusionPriorConfig,
)
from dalle2_pytorch.trackers import BaseTracker, WandbTracker
from dalle2_pytorch import DiffusionPriorTrainer
# helpers
@@ -30,19 +38,8 @@ def exists(val):
return val is not None
def all_between(values: list, lower_bound, upper_bound):
for value in values:
if value < lower_bound or value > upper_bound:
return False
return True
def make_model(
prior_config: DiffusionPriorConfig,
train_config: DiffusionPriorTrainConfig,
device: str = None,
accelerator: Accelerator = None,
prior_config, train_config, device: str = None, accelerator: Accelerator = None
):
# create model from config
diffusion_prior = prior_config.create()
@@ -57,214 +54,71 @@ def make_model(
use_ema=train_config.use_ema,
device=device,
accelerator=accelerator,
warmup_steps=train_config.warmup_steps,
)
return trainer
def create_tracker(
accelerator: Accelerator,
config: TrainDiffusionPriorConfig,
config_path: str,
dummy: bool = False,
) -> Tracker:
tracker_config = config.tracker
accelerator_config = {
"Distributed": accelerator.distributed_type
!= accelerate_dataclasses.DistributedType.NO,
"DistributedType": accelerator.distributed_type,
"NumProcesses": accelerator.num_processes,
"MixedPrecision": accelerator.mixed_precision,
}
tracker: Tracker = tracker_config.create(
config, accelerator_config, dummy_mode=dummy
)
tracker.save_config(config_path, config_name="prior_config.json")
return tracker
def pad_gather_reduce(trainer: DiffusionPriorTrainer, x, method="mean"):
"""
pad a value or tensor across all processes and gather
params:
- trainer: a trainer that carries an accelerator object
- x: a number or torch tensor to reduce
- method: "mean", "sum", "max", "min"
return:
- the reduced tensor after masking out 0's
- None if the gather resulted in an empty tensor
"""
assert method in [
"mean",
"sum",
"max",
"min",
], "This function has limited capabilities [sum, mean, max, min]"
assert x is not None, "Cannot reduce a None type object"
# wait for everyone to arrive here before gathering
if type(x) is not torch.Tensor:
x = torch.tensor([x])
# verify that the tensor is on the proper device
x = x.to(trainer.device)
# pad across processes
padded_x = trainer.accelerator.pad_across_processes(x, dim=0)
# gather across all processes
gathered_x = trainer.accelerator.gather(padded_x)
# mask out zeros
masked_x = gathered_x[gathered_x != 0]
# if the tensor is empty, warn and return None
if len(masked_x) == 0:
click.secho(
f"The call to this method resulted in an empty tensor after masking out zeros. The gathered tensor was this: {gathered_x} and the original value passed was: {x}.",
fg="red",
)
return None
if method == "mean":
return torch.mean(masked_x)
elif method == "sum":
return torch.sum(masked_x)
elif method == "max":
return torch.max(masked_x)
elif method == "min":
return torch.min(masked_x)
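After the gather, the reduction is just a masked reduce over the concatenated values. A single-process illustration with made-up numbers:

```python
import torch

gathered = torch.tensor([0.0, 1.5, 0.0, 2.5, 3.0])  # padded values from every rank (made up)
masked = gathered[gathered != 0]                     # drop the padding zeros

print(masked.mean())  # method="mean" -> tensor(2.3333)
print(masked.sum())   # method="sum"  -> tensor(7.)
```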
def save_trainer(
tracker: Tracker,
trainer: DiffusionPriorTrainer,
is_latest: bool,
is_best: bool,
epoch: int,
samples_seen: int,
best_validation_loss: float,
):
"""
Logs the model with an appropriate method depending on the tracker
"""
trainer.accelerator.wait_for_everyone()
if trainer.accelerator.is_main_process:
click.secho(
f"RANK:{trainer.accelerator.process_index} | Saving Model | Best={is_best} | Latest={is_latest}",
fg="magenta",
)
tracker.save(
trainer=trainer,
is_best=is_best,
is_latest=is_latest,
epoch=int(epoch),
samples_seen=int(samples_seen),
best_validation_loss=best_validation_loss,
)
def recall_trainer(tracker: Tracker, trainer: DiffusionPriorTrainer):
"""
Loads the model with an appropriate method depending on the tracker
"""
if trainer.accelerator.is_main_process:
click.secho(f"Loading model from {type(tracker.loader).__name__}", fg="yellow")
state_dict = tracker.recall()
trainer.load(state_dict, strict=True)
return (
int(state_dict.get("epoch", 0)),
state_dict.get("best_validation_loss", 0),
int(state_dict.get("samples_seen", 0)),
)
# eval functions
def report_validation_loss(
def eval_model(
trainer: DiffusionPriorTrainer,
dataloader: DataLoader,
text_conditioned: bool,
use_ema: bool,
tracker: Tracker,
split: str,
tracker_folder: str,
loss_type: str,
tracker_context: str,
tracker: BaseTracker = None,
use_ema: bool = True,
):
"""
Compute the validation loss on a given subset of data.
"""
trainer.eval()
if trainer.is_main_process():
click.secho(f"Measuring performance on {tracker_context}", fg="green", blink=True)
if trainer.accelerator.is_main_process:
click.secho(
f"Measuring performance on {use_ema}-{split} split",
fg="green",
blink=True,
)
with torch.no_grad():
total_loss = 0.0
total_samples = 0.0
total_loss = torch.zeros(1, dtype=torch.float, device=trainer.device)
for image_embeddings, text_data in dataloader:
image_embeddings = image_embeddings.to(trainer.device)
text_data = text_data.to(trainer.device)
for image_embeddings, text_data in dataloader:
image_embeddings = image_embeddings.to(trainer.device)
text_data = text_data.to(trainer.device)
batches = image_embeddings.shape[0]
input_args = dict(image_embed=image_embeddings)
input_args = dict(image_embed=image_embeddings)
if text_conditioned:
input_args = dict(**input_args, text=text_data)
else:
input_args = dict(**input_args, text_embed=text_data)
if text_conditioned:
input_args = dict(**input_args, text=text_data)
else:
input_args = dict(**input_args, text_embed=text_data)
if use_ema:
loss = trainer.ema_diffusion_prior(**input_args)
else:
loss = trainer(**input_args)
if use_ema:
loss = trainer.ema_diffusion_prior(**input_args)
else:
loss = trainer(**input_args)
total_loss += loss
total_loss += loss * batches
total_samples += batches
# compute the average loss across all processes
avg_loss = total_loss / total_samples
avg_loss = pad_gather_reduce(trainer, total_loss, method="mean")
stats = {f"{tracker_folder}/{loss_type}-loss": avg_loss}
stats = {f"{tracker_context}-{loss_type}": avg_loss}
trainer.print(stats)
# print and log results on main process
tracker.log(stats, step=trainer.step.item() + 1)
return avg_loss
if exists(tracker):
tracker.log(stats, step=trainer.step.item() + 1)
def report_cosine_sims(
trainer: DiffusionPriorTrainer,
dataloader: DataLoader,
text_conditioned: bool,
tracker: Tracker,
split: str,
timesteps: int,
tracker_folder: str,
tracker: BaseTracker,
tracker_context: str = "validation",
):
trainer.eval()
if trainer.accelerator.is_main_process:
click.secho(
f"Measuring Cosine-Similarity on {split} split with {timesteps} timesteps",
fg="green",
blink=True,
)
if trainer.is_main_process():
click.secho("Measuring Cosine-Similarity", fg="green", blink=True)
for test_image_embeddings, text_data in dataloader:
test_image_embeddings = test_image_embeddings.to(trainer.device)
@@ -272,8 +126,10 @@ def report_cosine_sims(
# if we are text conditioned, we produce an embedding from the tokenized text
if text_conditioned:
text_embedding, text_encodings = trainer.embed_text(text_data)
text_cond = dict(text_embed=text_embedding, text_encodings=text_encodings)
text_embedding, text_encodings, text_mask = trainer.embed_text(text_data)
text_cond = dict(
text_embed=text_embedding, text_encodings=text_encodings, mask=text_mask
)
else:
text_embedding = text_data
text_cond = dict(text_embed=text_embedding)
@@ -290,11 +146,15 @@ def report_cosine_sims(
if text_conditioned:
text_encodings_shuffled = text_encodings[rolled_idx]
text_mask_shuffled = text_mask[rolled_idx]
else:
text_encodings_shuffled = None
text_mask_shuffled = None
text_cond_shuffled = dict(
text_embed=text_embed_shuffled, text_encodings=text_encodings_shuffled
text_embed=text_embed_shuffled,
text_encodings=text_encodings_shuffled,
mask=text_mask_shuffled,
)
# prepare the text embedding
@@ -307,9 +167,7 @@ def report_cosine_sims(
# predict on the unshuffled text embeddings
predicted_image_embeddings = trainer.p_sample_loop(
test_image_embeddings.shape,
text_cond,
timesteps=timesteps,
test_image_embeddings.shape, text_cond
)
predicted_image_embeddings = (
@@ -319,9 +177,7 @@ def report_cosine_sims(
# predict on the shuffled embeddings
predicted_unrelated_embeddings = trainer.p_sample_loop(
test_image_embeddings.shape,
text_cond_shuffled,
timesteps=timesteps,
test_image_embeddings.shape, text_cond_shuffled
)
predicted_unrelated_embeddings = (
@@ -330,97 +186,32 @@ def report_cosine_sims(
)
# calculate similarities
orig_sim = pad_gather_reduce(
trainer, cos(text_embed, test_image_embeddings), method="mean"
original_similarity = cos(text_embed, test_image_embeddings).cpu().numpy()
predicted_similarity = cos(text_embed, predicted_image_embeddings).cpu().numpy()
unrelated_similarity = (
cos(text_embed, predicted_unrelated_embeddings).cpu().numpy()
)
pred_sim = pad_gather_reduce(
trainer, cos(text_embed, predicted_image_embeddings), method="mean"
)
unrel_sim = pad_gather_reduce(
trainer, cos(text_embed, predicted_unrelated_embeddings), method="mean"
)
pred_img_sim = pad_gather_reduce(
trainer,
cos(test_image_embeddings, predicted_image_embeddings),
method="mean",
predicted_img_similarity = (
cos(test_image_embeddings, predicted_image_embeddings).cpu().numpy()
)
stats = {
f"{tracker_folder}/baseline similarity [steps={timesteps}]": orig_sim,
f"{tracker_folder}/similarity with text [steps={timesteps}]": pred_sim,
f"{tracker_folder}/similarity with original image [steps={timesteps}]": pred_img_sim,
f"{tracker_folder}/similarity with unrelated caption [steps={timesteps}]": unrel_sim,
f"{tracker_folder}/difference from baseline similarity [steps={timesteps}]": pred_sim
- orig_sim,
f"{tracker_context}/baseline similarity": np.mean(original_similarity),
f"{tracker_context}/similarity with text": np.mean(predicted_similarity),
f"{tracker_context}/similarity with original image": np.mean(
predicted_img_similarity
),
f"{tracker_context}/similarity with unrelated caption": np.mean(unrelated_similarity),
f"{tracker_context}/difference from baseline similarity": np.mean(
predicted_similarity - original_similarity
),
}
tracker.log(stats, step=trainer.step.item() + 1)
for k, v in stats.items():
trainer.print(f"{tracker_context}/{k}: {v}")
def eval_model(
trainer: DiffusionPriorTrainer,
dataloader: DataLoader,
text_conditioned: bool,
split: str,
tracker: Tracker,
use_ema: bool,
report_cosine: bool,
report_loss: bool,
timesteps: List[int],
loss_type: str = None,
):
"""
Run evaluation on a model and track metrics
returns: loss if requested
"""
trainer.eval()
use_ema = "ema" if use_ema else "online"
tracker_folder = f"metrics/{use_ema}-{split}"
# determine if valid timesteps are passed
min_timesteps = trainer.accelerator.unwrap_model(
trainer.diffusion_prior
).sample_timesteps
max_timesteps = trainer.accelerator.unwrap_model(
trainer.diffusion_prior
).noise_scheduler.num_timesteps
assert all_between(
timesteps, lower_bound=min_timesteps, upper_bound=max_timesteps
), f"all timesteps values must be between {min_timesteps} and {max_timesteps}: got {timesteps}"
# measure cosine metrics across various eta and timesteps
if report_cosine:
for timestep in timesteps:
report_cosine_sims(
trainer,
dataloader=dataloader,
text_conditioned=text_conditioned,
tracker=tracker,
split=split,
timesteps=timestep,
tracker_folder=tracker_folder,
)
# measure loss on a separate split of data
if report_loss:
loss = report_validation_loss(
trainer=trainer,
dataloader=dataloader,
text_conditioned=text_conditioned,
use_ema=use_ema,
tracker=tracker,
split=split,
tracker_folder=tracker_folder,
loss_type=loss_type,
)
return loss
if exists(tracker):
tracker.log(stats, step=trainer.step.item() + 1)
# training script
@@ -428,327 +219,182 @@ def eval_model(
def train(
trainer: DiffusionPriorTrainer,
tracker: Tracker,
train_loader: DataLoader,
eval_loader: DataLoader,
test_loader: DataLoader,
config: DiffusionPriorTrainConfig,
):
# init timers
save_timer = Timer() # when to save
samples_timer = Timer() # samples/sec
validation_profiler = Timer() # how long is validation taking
validation_countdown = Timer() # when to perform evaluation
# distributed tracking with wandb
if trainer.accelerator.num_processes > 1:
os.environ["WANDB_START_METHOD"] = "thread"
# keep track of best validation loss
tracker = wandb.init(
name=f"RANK:{trainer.device}",
entity=config.tracker.wandb_entity,
project=config.tracker.wandb_project,
config=config.dict(),
group=GROUP,
)
best_validation_loss = config.train.best_validation_loss
samples_seen = config.train.num_samples_seen
# sync after tracker init
trainer.wait_for_everyone()
# init a timer
timer = Timer()
# do training
for img, txt in train_loader:
trainer.train()
current_step = trainer.step.item() + 1
start_epoch = config.train.current_epoch
# place data on device
img = img.to(trainer.device)
txt = txt.to(trainer.device)
for epoch in range(start_epoch, config.train.epochs):
# if we finished out an old epoch, reset the distribution to be a full epoch
tracker.log({"tracking/epoch": epoch}, step=trainer.step.item())
# pass to model
loss = trainer(text=txt, image_embed=img)
if train_loader.dataset.get_start() > 0 and epoch == start_epoch+1:
if trainer.accelerator.is_main_process:
click.secho(f"Finished resumed epoch...resetting dataloader.")
train_loader.dataset.set_start(0)
# display & log loss (will only print from main process)
trainer.print(f"Step {current_step}: Loss {loss}")
for img, txt in train_loader:
# setup things every step
# perform backprop & apply EMA updates
trainer.update()
trainer.train()
current_step = trainer.step.item()
samples_timer.reset()
# track samples/sec/rank
samples_per_sec = img.shape[0] / timer.elapsed()
# place data on device
# samples seen
samples_seen = (
config.data.batch_size * trainer.accelerator.num_processes * current_step
)
img = img.to(trainer.device)
txt = txt.to(trainer.device)
# ema decay
ema_decay = trainer.ema_diffusion_prior.get_current_decay()
# pass to model
# Log on all processes for debugging
tracker.log(
{
"tracking/samples-sec": samples_per_sec,
"tracking/samples-seen": samples_seen,
"tracking/ema-decay": ema_decay,
"metrics/training-loss": loss,
},
step=current_step,
)
loss = trainer(text=txt, image_embed=img)
# perform backprop & apply EMA updates
trainer.update()
# gather info about training step
all_loss = pad_gather_reduce(trainer, loss, method="mean")
num_samples = pad_gather_reduce(trainer, len(txt), method="sum")
samples_per_sec = num_samples / samples_timer.elapsed()
samples_seen += num_samples
ema_decay = trainer.ema_diffusion_prior.get_current_decay()
# log
tracker.log(
{
"tracking/samples-sec": samples_per_sec,
"tracking/samples-seen": samples_seen,
"tracking/ema-decay": ema_decay,
f"tracking/training-{config.prior.loss_type}": all_loss,
},
step=current_step,
# Metric Tracking & Checkpointing (outside of timer's scope)
if current_step % EVAL_EVERY == 0:
eval_model(
trainer=trainer,
dataloader=eval_loader,
text_conditioned=config.prior.condition_on_text_encodings,
loss_type=config.prior.loss_type,
tracker_context="metrics/online-model-validation",
tracker=tracker,
use_ema=False,
)
# Metric Tracking @ Timed Intervals
eval_delta = pad_gather_reduce(
trainer, validation_countdown.elapsed(), method="min"
eval_model(
trainer=trainer,
dataloader=eval_loader,
text_conditioned=config.prior.condition_on_text_encodings,
loss_type=config.prior.loss_type,
tracker_context="metrics/ema-model-validation",
tracker=tracker,
use_ema=True,
)
if eval_delta != None and eval_delta > config.data.eval_every_seconds:
# begin timing how long this takes
report_cosine_sims(
trainer=trainer,
dataloader=eval_loader,
text_conditioned=config.prior.condition_on_text_encodings,
tracker=tracker,
tracker_context="metrics",
)
validation_profiler.reset()
if current_step % config.train.save_every == 0:
trainer.save(f"{config.tracker.data_path}/chkpt_step_{current_step}.pth")
# package kwargs for evaluation
eval_kwargs = {
"trainer": trainer,
"tracker": tracker,
"text_conditioned": config.prior.condition_on_text_encodings,
"timesteps": config.train.eval_timesteps,
}
# ONLINE MODEL : COSINE : LOSS : VALIDATION SPLIT
eval_model(
dataloader=eval_loader,
loss_type=config.prior.loss_type,
split="validation",
use_ema=False,
report_cosine=False,
report_loss=True,
**eval_kwargs,
)
# EMA MODEL : COSINE : LOSS : VALIDATION DATA
ema_val_loss = eval_model(
dataloader=eval_loader,
loss_type=config.prior.loss_type,
split="validation",
use_ema=True,
report_cosine=True,
report_loss=True,
**eval_kwargs,
)
tracker.log(
{
"tracking/validation length (minutes)": validation_profiler.elapsed()
/ 60
}
)
# check if the ema validation is the lowest seen yet
if ema_val_loss < best_validation_loss:
best_validation_loss = ema_val_loss
# go save the model as best
save_trainer(
trainer=trainer,
tracker=tracker,
is_best=True,
is_latest=False,
samples_seen=samples_seen,
epoch=epoch,
best_validation_loss=best_validation_loss,
)
# reset timer for validation
validation_countdown.reset()
elif eval_delta is None:
click.secho(
f"Error occured reading the eval time on rank: {trainer.device}",
fg="yellow",
)
# save as latest model on schedule
save_delta = pad_gather_reduce(trainer, save_timer.elapsed(), method="min")
if save_delta != None and save_delta >= config.train.save_every_seconds:
save_trainer(
trainer=trainer,
tracker=tracker,
is_best=False,
is_latest=True,
samples_seen=samples_seen,
epoch=epoch,
best_validation_loss=best_validation_loss,
)
save_timer.reset()
elif save_delta is None:
click.secho(
f"Error occured reading the save time on rank: {trainer.device}",
fg="yellow",
)
# reset timer for next round
timer.reset()
# evaluate on test data
if trainer.accelerator.is_main_process:
click.secho(f"Starting Test", fg="red")
# save one last time as latest before beginning validation
save_trainer(
tracker=tracker,
trainer=trainer,
is_best=False,
is_latest=True,
samples_seen=samples_seen,
epoch=epoch,
best_validation_loss=best_validation_loss,
)
test_loss = eval_model(
eval_model(
trainer=trainer,
dataloader=test_loader,
text_conditioned=config.prior.condition_on_text_encodings,
split="test",
tracker=tracker,
use_ema=True,
report_cosine=False,
report_loss=True,
timesteps=config.train.eval_timesteps,
loss_type=config.prior.loss_type,
tracker_context="test",
tracker=tracker,
)
if test_loss < best_validation_loss:
best_validation_loss = test_loss
# go save the model as best
save_trainer(
trainer=trainer,
tracker=tracker,
is_best=True,
is_latest=False,
samples_seen=samples_seen,
epoch=epoch,
best_validation_loss=test_loss,
)
report_cosine_sims(
trainer,
test_loader,
config.prior.condition_on_text_encodings,
tracker,
tracker_context="test",
)
def initialize_training(config_file, accelerator):
def initialize_training(config, accelerator=None):
"""
Parse the configuration file, and prepare everything necessary for training
"""
# load the configuration file
if accelerator.is_main_process:
click.secho(f"Loading configuration from {config_file}", fg="green")
config = TrainDiffusionPriorConfig.from_json_path(config_file)
# seed
set_seed(config.train.random_seed)
# get a device
device = accelerator.device
if accelerator:
device = accelerator.device
click.secho(f"Accelerating on: {device}", fg="yellow")
else:
if torch.cuda.is_available():
click.secho("GPU detected, defaulting to cuda:0", fg="yellow")
device = "cuda:0"
else:
click.secho("No GPU detected...using cpu", fg="yellow")
device = "cpu"
# make the trainer (will automatically distribute if possible & configured)
trainer: DiffusionPriorTrainer = make_model(
config.prior, config.train, device, accelerator
).to(device)
# create a tracker
tracker = create_tracker(
accelerator, config, config_file, dummy=accelerator.process_index != 0
)
trainer = make_model(config.prior, config.train, device, accelerator).to(device)
# reload from checkpoint
if tracker.can_recall:
current_epoch, best_validation_loss, samples_seen = recall_trainer(
tracker=tracker, trainer=trainer
)
# display best values
if trainer.accelerator.is_main_process:
click.secho(f"Current Epoch: {current_epoch} | Best Val Loss: {best_validation_loss} | Samples Seen: {samples_seen}", fg="yellow")
# update config to reflect recalled values
config.train.num_samples_seen = samples_seen
config.train.current_epoch = current_epoch
config.train.best_validation_loss = best_validation_loss
if config.load.resume == True:
click.secho(f"Loading checkpoint: {config.load.source}", fg="cyan")
trainer.load(config.load.source)
# fetch and prepare data
if trainer.accelerator.is_main_process:
click.secho("Grabbing data...", fg="blue", blink=True)
if trainer.is_main_process():
click.secho("Grabbing data from source", fg="blue", blink=True)
trainer.accelerator.wait_for_everyone()
img_reader = get_reader(
text_conditioned=trainer.text_conditioned,
img_url=config.data.image_url,
meta_url=config.data.meta_url,
)
# calculate start point within epoch
trainer.accelerator.wait_for_everyone()
train_loader, eval_loader, test_loader = make_splits(
text_conditioned=trainer.text_conditioned,
batch_size=config.data.batch_size,
num_data_points=config.data.num_data_points,
num_data_points=NUM_DATA_POINTS,
train_split=config.data.splits.train,
eval_split=config.data.splits.val,
image_reader=img_reader,
rank=accelerator.state.process_index,
world_size=accelerator.state.num_processes,
start=0,
rank=accelerator.state.process_index if exists(accelerator) else 0,
world_size=accelerator.state.num_processes if exists(accelerator) else 1,
start=START,
)
# update the start point to finish out the epoch on a resumed run
if tracker.can_recall:
samples_seen = config.train.num_samples_seen
length = (
config.data.num_data_points
if samples_seen <= img_reader.count
else img_reader.count
)
scaled_samples = length * config.train.current_epoch
start_point = (
scaled_samples - samples_seen if scaled_samples > samples_seen else samples_seen
)
if trainer.accelerator.is_main_process:
click.secho(f"Resuming at sample: {start_point}", fg="yellow")
train_loader.dataset.set_start(start_point)
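A made-up numeric walk-through of the resume arithmetic above (purely illustrative values):

```python
num_data_points = 1_000_000        # config.data.num_data_points
current_epoch   = 2                # recalled from the checkpoint
samples_seen    = 1_500_000        # recalled from the checkpoint

scaled_samples = num_data_points * current_epoch   # 2_000_000
start_point = (
    scaled_samples - samples_seen if scaled_samples > samples_seen else samples_seen
)
print(start_point)  # 500_000 in this made-up case
```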
# wait for everyone to load data before continuing
trainer.wait_for_everyone()
# start training
if trainer.accelerator.is_main_process:
click.secho(
f"Beginning Prior Training : Distributed={accelerator.state.distributed_type != accelerate_dataclasses.DistributedType.NO}",
fg="yellow",
)
train(
trainer=trainer,
tracker=tracker,
train_loader=train_loader,
eval_loader=eval_loader,
test_loader=test_loader,
@@ -757,13 +403,23 @@ def initialize_training(config_file, accelerator):
@click.command()
@click.option("--config_file", default="configs/train_prior_config.example.json")
def main(config_file):
# start HFA
accelerator = Accelerator()
@click.option("--hfa", default=True)
@click.option("--config_path", default="configs/prior.json")
def main(hfa, config_path):
# start HFA if requested
if hfa:
accelerator = Accelerator()
else:
accelerator = None
# setup training
initialize_training(config_file, accelerator)
# load the configuration file on main process
if not exists(accelerator) or accelerator.is_main_process:
click.secho(f"Loading configuration from {config_path}", fg="green")
config = TrainDiffusionPriorConfig.from_json_path(config_path)
# send config to get processed
initialize_training(config, accelerator)
if __name__ == "__main__":