Compare commits


1 Commit

Author: Romain Beaumont
SHA1: 3a1dea7d97
Message: Fix decoder test by fixing the resizing output size
Date: 2022-07-09 11:36:22 +02:00
14 changed files with 677 additions and 1718 deletions

.github/FUNDING.yml (vendored, 2 lines changed)

@@ -1 +1 @@
github: [nousr, Veldrovive, lucidrains] github: [lucidrains]

README.md (135 lines changed)

@@ -45,7 +45,6 @@ This library would not have gotten to this working state without the help of
- <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management - <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management
- <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying of critical bugs - <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying of critical bugs
- <a href="https://github.com/marunine">Marunine</a> for identifying issues with resizing of the low resolution conditioner, when training the upsampler, in addition to various other bug fixes - <a href="https://github.com/marunine">Marunine</a> for identifying issues with resizing of the low resolution conditioner, when training the upsampler, in addition to various other bug fixes
- <a href="https://github.com/malumadev">MalumaDev</a> for proposing the use of pixel shuffle upsampler for fixing checkboard artifacts
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice - <a href="https://github.com/crowsonkb">Katherine</a> for her advice
- <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship - <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship
- <a href="https://huggingface.co">🤗 Huggingface</a> and in particular <a href="https://github.com/sgugger">Sylvain</a> for the <a href="https://github.com/huggingface/accelerate">Accelerate</a> library - <a href="https://huggingface.co">🤗 Huggingface</a> and in particular <a href="https://github.com/sgugger">Sylvain</a> for the <a href="https://github.com/huggingface/accelerate">Accelerate</a> library
@@ -356,8 +355,7 @@ prior_network = DiffusionPriorNetwork(
diffusion_prior = DiffusionPrior( diffusion_prior = DiffusionPrior(
net = prior_network, net = prior_network,
clip = clip, clip = clip,
timesteps = 1000, timesteps = 100,
sample_timesteps = 64,
cond_drop_prob = 0.2 cond_drop_prob = 0.2
).cuda() ).cuda()
@@ -371,7 +369,6 @@ loss.backward()
unet1 = Unet( unet1 = Unet(
dim = 128, dim = 128,
image_embed_dim = 512, image_embed_dim = 512,
text_embed_dim = 512,
cond_dim = 128, cond_dim = 128,
channels = 3, channels = 3,
dim_mults=(1, 2, 4, 8), dim_mults=(1, 2, 4, 8),
@@ -396,7 +393,7 @@ decoder = Decoder(
).cuda() ).cuda()
for unet_number in (1, 2): for unet_number in (1, 2):
loss = decoder(images, text = text, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
loss.backward() loss.backward()
# do above for many steps # do above for many steps
@@ -422,7 +419,7 @@ For the layperson, no worries, training will all be automated into a CLI tool, a
## Training on Preprocessed CLIP Embeddings ## Training on Preprocessed CLIP Embeddings
It is likely, when scaling up, that you would first preprocess your images and text into corresponding embeddings before training the prior network. You can do so easily by simply passing in `image_embed`, `text_embed`, and optionally `text_encodings` It is likely, when scaling up, that you would first preprocess your images and text into corresponding embeddings before training the prior network. You can do so easily by simply passing in `image_embed`, `text_embed`, and optionally `text_encodings` and `text_mask`
Working example below Working example below
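(The working example itself falls outside the hunks shown here.) As a rough sketch of the idea, assuming a `DiffusionPrior` instance named `diffusion_prior` built as in the earlier snippets, precomputed CLIP embeddings are passed in directly, so neither CLIP nor the raw images and text are needed at training time:

```python
import torch

# mock precomputed CLIP embeddings (in practice these would be loaded from disk)
text_embed = torch.randn(4, 512).cuda()
image_embed = torch.randn(4, 512).cuda()

# train the prior directly on the embeddings, bypassing CLIP
loss = diffusion_prior(text_embed = text_embed, image_embed = image_embed)
loss.backward()
```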
@@ -586,7 +583,6 @@ unet1 = Unet(
cond_dim = 128, cond_dim = 128,
channels = 3, channels = 3,
dim_mults=(1, 2, 4, 8), dim_mults=(1, 2, 4, 8),
text_embed_dim = 512,
cond_on_text_encodings = True # set to True for any unets that need to be conditioned on text encodings (ex. first unet in cascade) cond_on_text_encodings = True # set to True for any unets that need to be conditioned on text encodings (ex. first unet in cascade)
).cuda() ).cuda()
@@ -602,8 +598,7 @@ decoder = Decoder(
unet = (unet1, unet2), unet = (unet1, unet2),
image_sizes = (128, 256), image_sizes = (128, 256),
clip = clip, clip = clip,
timesteps = 1000, timesteps = 100,
sample_timesteps = (250, 27),
image_cond_drop_prob = 0.1, image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5 text_cond_drop_prob = 0.5
).cuda() ).cuda()
@@ -629,82 +624,6 @@ images = dalle2(
Now you'll just have to worry about training the Prior and the Decoder! Now you'll just have to worry about training the Prior and the Decoder!
## Inpainting
Inpainting is also built into the `Decoder`. You simply have to pass in the `inpaint_image` and `inpaint_mask` (boolean tensor where `True` indicates which regions of the inpaint image to keep)
This repository uses the formulation put forth by <a href="https://arxiv.org/abs/2201.09865">Lugmayr et al. in Repaint</a>
```python
import torch
from dalle2_pytorch import Unet, Decoder, CLIP
# trained clip from step 1
clip = CLIP(
dim_text = 512,
dim_image = 512,
dim_latent = 512,
num_text_tokens = 49408,
text_enc_depth = 6,
text_seq_len = 256,
text_heads = 8,
visual_enc_depth = 6,
visual_image_size = 256,
visual_patch_size = 32,
visual_heads = 8
).cuda()
# 2 unets for the decoder (a la cascading DDPM)
unet = Unet(
dim = 16,
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults = (1, 1, 1, 1)
).cuda()
# decoder, which contains the unet(s) and clip
decoder = Decoder(
clip = clip,
unet = (unet,), # insert both unets in order of low resolution to highest resolution (you can have as many stages as you want here)
image_sizes = (256,), # resolutions, 256 for first unet, 512 for second. these must be unique and in ascending order (matches with the unets passed in)
timesteps = 1000,
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5
).cuda()
# mock images (get a lot of this)
images = torch.randn(4, 3, 256, 256).cuda()
# feed images into decoder, specifying which unet you want to train
# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme
loss = decoder(images, unet_number = 1)
loss.backward()
# do the above for many steps for both unets
mock_image_embed = torch.randn(1, 512).cuda()
# then to do inpainting
inpaint_image = torch.randn(1, 3, 256, 256).cuda() # (batch, channels, height, width)
inpaint_mask = torch.ones(1, 256, 256).bool().cuda() # (batch, height, width)
inpainted_images = decoder.sample(
image_embed = mock_image_embed,
inpaint_image = inpaint_image, # just pass in the inpaint image
inpaint_mask = inpaint_mask # and the mask
)
inpainted_images.shape # (1, 3, 256, 256)
```
## Experimental ## Experimental
### DALL-E2 with Latent Diffusion ### DALL-E2 with Latent Diffusion
@@ -861,23 +780,25 @@ unet1 = Unet(
text_embed_dim = 512, text_embed_dim = 512,
cond_dim = 128, cond_dim = 128,
channels = 3, channels = 3,
dim_mults=(1, 2, 4, 8), dim_mults=(1, 2, 4, 8)
cond_on_text_encodings = True,
).cuda() ).cuda()
unet2 = Unet( unet2 = Unet(
dim = 16, dim = 16,
image_embed_dim = 512, image_embed_dim = 512,
text_embed_dim = 512,
cond_dim = 128, cond_dim = 128,
channels = 3, channels = 3,
dim_mults = (1, 2, 4, 8, 16), dim_mults = (1, 2, 4, 8, 16),
cond_on_text_encodings = True
).cuda() ).cuda()
decoder = Decoder( decoder = Decoder(
unet = (unet1, unet2), unet = (unet1, unet2),
image_sizes = (128, 256), image_sizes = (128, 256),
clip = clip, clip = clip,
timesteps = 1000 timesteps = 1000,
condition_on_text_encodings = True
).cuda() ).cuda()
decoder_trainer = DecoderTrainer( decoder_trainer = DecoderTrainer(
@@ -902,8 +823,8 @@ for unet_number in (1, 2):
# after much training # after much training
# you can sample from the exponentially moving averaged unets as so # you can sample from the exponentially moving averaged unets as so
mock_image_embed = torch.randn(32, 512).cuda() mock_image_embed = torch.randn(4, 512).cuda()
images = decoder_trainer.sample(image_embed = mock_image_embed, text = text) # (4, 3, 256, 256) images = decoder_trainer.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
``` ```
### Diffusion Prior Training ### Diffusion Prior Training
@@ -1066,12 +987,26 @@ dataset = ImageEmbeddingDataset(
) )
``` ```
### Scripts ### Scripts (wip)
#### `train_diffusion_prior.py` #### `train_diffusion_prior.py`
For detailed information on training the diffusion prior, please refer to the [dedicated readme](prior.md) For detailed information on training the diffusion prior, please refer to the [dedicated readme](prior.md)
## CLI (wip)
```bash
$ dream 'sharing a sunset at the summit of mount everest with my dog'
```
Once built, images will be saved to the same directory the command is invoked
<a href="https://github.com/lucidrains/big-sleep">template</a>
## Training CLI (wip)
<a href="https://github.com/lucidrains/stylegan2-pytorch">template</a>
## Todo ## Todo
- [x] finish off gaussian diffusion class for latent embedding - allow for prediction of epsilon - [x] finish off gaussian diffusion class for latent embedding - allow for prediction of epsilon
@@ -1109,11 +1044,11 @@ For detailed information on training the diffusion prior, please refer to the [d
- [x] bring in skip-layer excitations (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training (doesnt work well) - [x] bring in skip-layer excitations (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training (doesnt work well)
- [x] test out grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697 (keeping, seems to be fine) - [x] test out grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697 (keeping, seems to be fine)
- [x] allow for unet to be able to condition non-cross attention style as well - [x] allow for unet to be able to condition non-cross attention style as well
- [x] speed up inference, read up on papers (ddim) - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
- [x] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865 - [ ] speed up inference, read up on papers (ddim or diffusion-gan, etc)
- [x] add the final combination of upsample feature maps, used in unet squared, seems to have an effect in local experiments - [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
- [ ] consider elucidated dalle2 https://arxiv.org/abs/2206.00364
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2 - [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
- [ ] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
## Citations ## Citations
@@ -1231,14 +1166,4 @@ For detailed information on training the diffusion prior, please refer to the [d
} }
``` ```
```bibtex
@article{Lugmayr2022RePaintIU,
title = {RePaint: Inpainting using Denoising Diffusion Probabilistic Models},
author = {Andreas Lugmayr and Martin Danelljan and Andr{\'e}s Romero and Fisher Yu and Radu Timofte and Luc Van Gool},
journal = {ArXiv},
year = {2022},
volume = {abs/2201.09865}
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a> *Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>


@@ -69,12 +69,14 @@ Settings for controlling the training hyperparameters.
| `wd` | No | `0.01` | The weight decay. | | `wd` | No | `0.01` | The weight decay. |
| `max_grad_norm`| No | `0.5` | The grad norm clipping. | | `max_grad_norm`| No | `0.5` | The grad norm clipping. |
| `save_every_n_samples` | No | `100000` | Samples will be generated and a checkpoint will be saved every `save_every_n_samples` samples. | | `save_every_n_samples` | No | `100000` | Samples will be generated and a checkpoint will be saved every `save_every_n_samples` samples. |
| `cond_scale` | No | `1.0` | Conditioning scale to use for sampling. Can also be an array of values, one for each unet. |
| `device` | No | `cuda:0` | The device to train on. | | `device` | No | `cuda:0` | The device to train on. |
| `epoch_samples` | No | `None` | Limits the number of samples iterated through in each epoch. This must be set if resampling. None means no limit. | | `epoch_samples` | No | `None` | Limits the number of samples iterated through in each epoch. This must be set if resampling. None means no limit. |
| `validation_samples` | No | `None` | The number of samples to use for validation. None mean the entire validation set. | | `validation_samples` | No | `None` | The number of samples to use for validation. None mean the entire validation set. |
| `use_ema` | No | `True` | Whether to use exponential moving average models for sampling. | | `use_ema` | No | `True` | Whether to use exponential moving average models for sampling. |
| `ema_beta` | No | `0.99` | The ema coefficient. | | `ema_beta` | No | `0.99` | The ema coefficient. |
| `save_all` | No | `False` | If True, preserves a checkpoint for every epoch. |
| `save_latest` | No | `True` | If True, overwrites the `latest.pth` every time the model is saved. |
| `save_best` | No | `True` | If True, overwrites the `best.pth` every time the model has a lower validation loss than all previous models. |
| `unet_training_mask` | No | `None` | A boolean array of the same length as the number of unets. If false, the unet is frozen. A value of `None` trains all unets. | | `unet_training_mask` | No | `None` | A boolean array of the same length as the number of unets. If false, the unet is frozen. A value of `None` trains all unets. |
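For orientation, the options documented above are parsed into the pydantic config classes in `dalle2_pytorch.train_configs`. A minimal sketch of loading and inspecting a decoder training config, assuming the `from_json_path` helper on `TrainDecoderConfig` and an illustrative config path:

```python
from dalle2_pytorch.train_configs import TrainDecoderConfig

# parse the JSON config into typed pydantic models (path is illustrative)
config = TrainDecoderConfig.from_json_path('configs/train_decoder_config.example.json')

# the training options documented in the table above are then available as attributes
print(config.train.lr, config.train.use_ema, config.train.ema_beta)
```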
**<ins>Evaluate</ins>:** **<ins>Evaluate</ins>:**
@@ -161,10 +163,9 @@ All save locations have these configuration options
| Option | Required | Default | Description | | Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- | | ------ | -------- | ------- | ----------- |
| `save_to` | Yes | N/A | Must be `local`, `huggingface`, or `wandb`. | | `save_to` | Yes | N/A | Must be `local`, `huggingface`, or `wandb`. |
| `save_latest_to` | No | `None` | Sets the relative path to save the latest model to. | | `save_latest_to` | No | `latest.pth` | Sets the relative path to save the latest model to. |
| `save_best_to` | No | `None` | Sets the relative path to save the best model to every time the model has a lower validation loss than all previous models. | | `save_best_to` | No | `best.pth` | Sets the relative path to save the best model to every time the model has a lower validation loss than all previous models. |
| `save_meta_to` | No | `None` | The path to save metadata files in. This includes the config files used to start the training. | | `save_type` | No | `'checkpoint'` | The type of save. `'checkpoint'` saves a checkpoint, `'model'` saves a model without any fluff (Saves with ema if ema is enabled). |
| `save_type` | No | `checkpoint` | The type of save. `checkpoint` saves a checkpoint, `model` saves a model without any fluff (Saves with ema if ema is enabled). |
If using `local` If using `local`
| Option | Required | Default | Description | | Option | Required | Default | Description |
@@ -176,6 +177,7 @@ If using `huggingface`
| ------ | -------- | ------- | ----------- | | ------ | -------- | ------- | ----------- |
| `save_to` | Yes | N/A | Must be `huggingface`. | | `save_to` | Yes | N/A | Must be `huggingface`. |
| `huggingface_repo` | Yes | N/A | The huggingface repository to save to. | | `huggingface_repo` | Yes | N/A | The huggingface repository to save to. |
| `huggingface_base_path` | Yes | N/A | The base path that checkpoints will be saved under. |
| `token_path` | No | `None` | If logging in with the huggingface cli is not possible, point to a token file instead. | | `token_path` | No | `None` | If logging in with the huggingface cli is not possible, point to a token file instead. |
If using `wandb` If using `wandb`


@@ -56,6 +56,9 @@
"use_ema": true, "use_ema": true,
"ema_beta": 0.99, "ema_beta": 0.99,
"amp": false, "amp": false,
"save_all": false,
"save_latest": true,
"save_best": true,
"unet_training_mask": [true] "unet_training_mask": [true]
}, },
"evaluate": { "evaluate": {
@@ -93,15 +96,14 @@
}, },
"save": [{ "save": [{
"save_to": "wandb", "save_to": "wandb"
"save_latest_to": "latest.pth"
}, { }, {
"save_to": "huggingface", "save_to": "huggingface",
"huggingface_repo": "Veldrovive/test_model", "huggingface_repo": "Veldrovive/test_model",
"save_latest_to": "path/to/model_dir/latest.pth", "save_all": true,
"save_best_to": "path/to/model_dir/best.pth", "save_latest": true,
"save_meta_to": "path/to/directory/for/assorted/files", "save_best": true,
"save_type": "model" "save_type": "model"
}] }]


@@ -61,6 +61,9 @@
"use_ema": true, "use_ema": true,
"ema_beta": 0.99, "ema_beta": 0.99,
"amp": false, "amp": false,
"save_all": false,
"save_latest": true,
"save_best": true,
"unet_training_mask": [true] "unet_training_mask": [true]
}, },
"evaluate": { "evaluate": {
@@ -93,8 +96,7 @@
}, },
"save": [{ "save": [{
"save_to": "local", "save_to": "local"
"save_latest_to": "latest.pth"
}] }]
} }
} }


@@ -1,14 +1,18 @@
{ {
"prior": { "prior": {
"clip": { "clip": {
"make": "openai", "make": "x-clip",
"model": "ViT-L/14" "model": "ViT-L/14",
"base_model_kwargs": {
"dim_text": 768,
"dim_image": 768,
"dim_latent": 768
}
}, },
"net": { "net": {
"dim": 768, "dim": 768,
"depth": 12, "depth": 12,
"num_timesteps": 1000, "num_timesteps": 1000,
"max_text_len": 77,
"num_time_embeds": 1, "num_time_embeds": 1,
"num_image_embeds": 1, "num_image_embeds": 1,
"num_text_embeds": 1, "num_text_embeds": 1,
@@ -16,8 +20,8 @@
"heads": 12, "heads": 12,
"ff_mult": 4, "ff_mult": 4,
"norm_out": true, "norm_out": true,
"attn_dropout": 0.05, "attn_dropout": 0.0,
"ff_dropout": 0.05, "ff_dropout": 0.0,
"final_proj": true, "final_proj": true,
"normformer": true, "normformer": true,
"rotary_emb": true "rotary_emb": true
@@ -26,7 +30,6 @@
"image_size": 224, "image_size": 224,
"image_channels": 3, "image_channels": 3,
"timesteps": 1000, "timesteps": 1000,
"sample_timesteps": 64,
"cond_drop_prob": 0.1, "cond_drop_prob": 0.1,
"loss_type": "l2", "loss_type": "l2",
"predict_x_start": true, "predict_x_start": true,
@@ -34,48 +37,34 @@
"condition_on_text_encodings": true "condition_on_text_encodings": true
}, },
"data": { "data": {
"batch_size": 128, "image_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/",
"num_data_points": 100000, "text_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/",
"eval_every_seconds": 1600, "meta_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/laion2B-en-metadata/",
"image_url": "<path to your images>", "batch_size": 256,
"meta_url": "<path to your metadata>",
"splits": { "splits": {
"train": 0.8, "train": 0.9,
"val": 0.1, "val": 1e-7,
"test": 0.1 "test": 0.0999999
} }
}, },
"train": { "train": {
"epochs": 5, "epochs": 1,
"lr": 1.1e-4, "lr": 1.1e-4,
"wd": 6.02e-2, "wd": 6.02e-2,
"max_grad_norm": 0.5, "max_grad_norm": 0.5,
"use_ema": true, "use_ema": true,
"ema_beta": 0.9999,
"ema_update_after_step": 50,
"warmup_steps": 50,
"amp": false, "amp": false,
"save_every_seconds": 3600, "save_every": 10000
"eval_timesteps": [64, 1000], },
"random_seed": 84513 "load": {
"source": null,
"resume": false
}, },
"tracker": { "tracker": {
"data_path": ".prior", "tracker_type": "wandb",
"overwrite_data_path": true, "data_path": "./prior_checkpoints",
"log": { "wandb_entity": "laion",
"log_type": "wandb", "wandb_project": "diffusion-prior",
"wandb_entity": "<your wandb username>", "verbose": true
"wandb_project": "prior_debugging",
"wandb_resume": false,
"verbose": true
},
"save": [
{
"save_to": "local",
"save_type": "checkpoint",
"save_latest_to": ".prior/latest_checkpoint.pth",
"save_best_to": ".prior/best_checkpoint.pth"
}
]
} }
} }

File diff suppressed because it is too large.


@@ -67,15 +67,6 @@ class PriorEmbeddingDataset(IterableDataset):
def __str__(self): def __str__(self):
return f"<PriorEmbeddingDataset: start: {self.start}, stop: {self.stop}, len: {self.__len__()}>" return f"<PriorEmbeddingDataset: start: {self.start}, stop: {self.stop}, len: {self.__len__()}>"
def set_start(self, start):
"""
Adjust the starting point within the reader, useful for resuming an epoch
"""
self.start = start
def get_start(self):
return self.start
def get_sample(self): def get_sample(self):
""" """
pre-proocess data from either reader into a common format pre-proocess data from either reader into a common format


@@ -4,15 +4,13 @@ import json
from pathlib import Path from pathlib import Path
import shutil import shutil
from itertools import zip_longest from itertools import zip_longest
from typing import Any, Optional, List, Union from typing import Optional, List, Union
from pydantic import BaseModel from pydantic import BaseModel
import torch import torch
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
from dalle2_pytorch.utils import import_or_print_error from dalle2_pytorch.utils import import_or_print_error
from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
from dalle2_pytorch.version import __version__
from packaging import version
# constants # constants
@@ -23,6 +21,16 @@ DEFAULT_DATA_PATH = './.tracker-data'
def exists(val): def exists(val):
return val is not None return val is not None
# load file functions
def load_wandb_file(run_path, file_path, **kwargs):
wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function')
file_reference = wandb.restore(file_path, run_path=run_path)
return file_reference.name
def load_local_file(file_path, **kwargs):
return file_path
class BaseLogger: class BaseLogger:
""" """
An abstract class representing an object that can log data. An abstract class representing an object that can log data.
@@ -226,7 +234,7 @@ class LocalLoader(BaseLoader):
def init(self, logger: BaseLogger, **kwargs) -> None: def init(self, logger: BaseLogger, **kwargs) -> None:
# Makes sure the file exists to be loaded # Makes sure the file exists to be loaded
if not self.file_path.exists() and not self.only_auto_resume: if not self.file_path.exists():
raise FileNotFoundError(f'Model not found at {self.file_path}') raise FileNotFoundError(f'Model not found at {self.file_path}')
def recall(self) -> dict: def recall(self) -> dict:
@@ -275,9 +283,9 @@ def create_loader(loader_type: str, data_path: str, **kwargs) -> BaseLoader:
class BaseSaver: class BaseSaver:
def __init__(self, def __init__(self,
data_path: str, data_path: str,
save_latest_to: Optional[Union[str, bool]] = None, save_latest_to: Optional[Union[str, bool]] = 'latest.pth',
save_best_to: Optional[Union[str, bool]] = None, save_best_to: Optional[Union[str, bool]] = 'best.pth',
save_meta_to: Optional[str] = None, save_meta_to: str = './',
save_type: str = 'checkpoint', save_type: str = 'checkpoint',
**kwargs **kwargs
): ):
@@ -287,10 +295,10 @@ class BaseSaver:
self.save_best_to = save_best_to self.save_best_to = save_best_to
self.saving_best = save_best_to is not None and save_best_to is not False self.saving_best = save_best_to is not None and save_best_to is not False
self.save_meta_to = save_meta_to self.save_meta_to = save_meta_to
self.saving_meta = save_meta_to is not None
self.save_type = save_type self.save_type = save_type
assert save_type in ['checkpoint', 'model'], '`save_type` must be one of `checkpoint` or `model`' assert save_type in ['checkpoint', 'model'], '`save_type` must be one of `checkpoint` or `model`'
assert self.saving_latest or self.saving_best or self.saving_meta, 'At least one saving option must be specified' assert self.save_meta_to is not None, '`save_meta_to` must be provided'
assert self.saving_latest or self.saving_best, '`save_latest_to` or `save_best_to` must be provided'
def init(self, logger: BaseLogger, **kwargs) -> None: def init(self, logger: BaseLogger, **kwargs) -> None:
raise NotImplementedError raise NotImplementedError
@@ -451,11 +459,6 @@ class Tracker:
print(f'\n\nWARNING: RUN HAS BEEN AUTO-RESUMED WITH THE LOGGER TYPE {self.logger.__class__.__name__}.\nIf this was not your intention, stop this run and set `auto_resume` to `False` in the config.\n\n') print(f'\n\nWARNING: RUN HAS BEEN AUTO-RESUMED WITH THE LOGGER TYPE {self.logger.__class__.__name__}.\nIf this was not your intention, stop this run and set `auto_resume` to `False` in the config.\n\n')
print(f"New logger config: {self.logger.__dict__}") print(f"New logger config: {self.logger.__dict__}")
self.save_metadata = dict(
version = version.parse(__version__)
) # Data that will be saved alongside the checkpoint or model
self.blacklisted_checkpoint_metadata_keys = ['scaler', 'optimizer', 'model', 'version', 'step', 'steps'] # These keys would cause us to error if we try to save them as metadata
assert self.logger is not None, '`logger` must be set before `init` is called' assert self.logger is not None, '`logger` must be set before `init` is called'
if self.dummy_mode: if self.dummy_mode:
# The only thing we need is a loader # The only thing we need is a loader
@@ -504,15 +507,8 @@ class Tracker:
# Save the config under config_name in the root folder of data_path # Save the config under config_name in the root folder of data_path
shutil.copy(current_config_path, self.data_path / config_name) shutil.copy(current_config_path, self.data_path / config_name)
for saver in self.savers: for saver in self.savers:
if saver.saving_meta: remote_path = Path(saver.save_meta_to) / config_name
remote_path = Path(saver.save_meta_to) / config_name saver.save_file(current_config_path, str(remote_path))
saver.save_file(current_config_path, str(remote_path))
def add_save_metadata(self, state_dict_key: str, metadata: Any):
"""
Adds a new piece of metadata that will be saved along with the model or decoder.
"""
self.save_metadata[state_dict_key] = metadata
def _save_state_dict(self, trainer: Union[DiffusionPriorTrainer, DecoderTrainer], save_type: str, file_path: str, **kwargs) -> Path: def _save_state_dict(self, trainer: Union[DiffusionPriorTrainer, DecoderTrainer], save_type: str, file_path: str, **kwargs) -> Path:
""" """
@@ -522,38 +518,24 @@ class Tracker:
""" """
assert save_type in ['checkpoint', 'model'] assert save_type in ['checkpoint', 'model']
if save_type == 'checkpoint': if save_type == 'checkpoint':
# Create a metadata dict without the blacklisted keys so we do not error when we create the state dict trainer.save(file_path, overwrite=True, **kwargs)
metadata = {k: v for k, v in self.save_metadata.items() if k not in self.blacklisted_checkpoint_metadata_keys}
trainer.save(file_path, overwrite=True, **kwargs, **metadata)
elif save_type == 'model': elif save_type == 'model':
if isinstance(trainer, DiffusionPriorTrainer): if isinstance(trainer, DiffusionPriorTrainer):
prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior
prior: DiffusionPrior = trainer.accelerator.unwrap_model(prior) state_dict = trainer.unwrap_model(prior).state_dict()
# Remove CLIP if it is part of the model torch.save(state_dict, file_path)
original_clip = prior.clip
prior.clip = None
model_state_dict = prior.state_dict()
prior.clip = original_clip
elif isinstance(trainer, DecoderTrainer): elif isinstance(trainer, DecoderTrainer):
decoder: Decoder = trainer.accelerator.unwrap_model(trainer.decoder) decoder = trainer.accelerator.unwrap_model(trainer.decoder)
# Remove CLIP if it is part of the model
original_clip = decoder.clip
decoder.clip = None
if trainer.use_ema: if trainer.use_ema:
trainable_unets = decoder.unets trainable_unets = decoder.unets
decoder.unets = trainer.unets # Swap EMA unets in decoder.unets = trainer.unets # Swap EMA unets in
model_state_dict = decoder.state_dict() state_dict = decoder.state_dict()
decoder.unets = trainable_unets # Swap back decoder.unets = trainable_unets # Swap back
else: else:
model_state_dict = decoder.state_dict() state_dict = decoder.state_dict()
decoder.clip = original_clip torch.save(state_dict, file_path)
else: else:
raise NotImplementedError('Saving this type of model with EMA mode enabled is not yet implemented. Actually, how did you get here?') raise NotImplementedError('Saving this type of model with EMA mode enabled is not yet implemented. Actually, how did you get here?')
state_dict = {
**self.save_metadata,
'model': model_state_dict
}
torch.save(state_dict, file_path)
return Path(file_path) return Path(file_path)
def save(self, trainer, is_best: bool, is_latest: bool, **kwargs): def save(self, trainer, is_best: bool, is_latest: bool, **kwargs):


@@ -1,7 +1,7 @@
import json import json
from torchvision import transforms as T from torchvision import transforms as T
from pydantic import BaseModel, validator, root_validator from pydantic import BaseModel, validator, root_validator
from typing import List, Optional, Union, Tuple, Dict, Any, TypeVar from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
from x_clip import CLIP as XCLIP from x_clip import CLIP as XCLIP
from coca_pytorch import CoCa from coca_pytorch import CoCa
@@ -25,9 +25,11 @@ def exists(val):
def default(val, d): def default(val, d):
return val if exists(val) else d return val if exists(val) else d
InnerType = TypeVar('InnerType') def ListOrTuple(inner_type):
ListOrTuple = Union[List[InnerType], Tuple[InnerType]] return Union[List[inner_type], Tuple[inner_type]]
SingularOrIterable = Union[InnerType, ListOrTuple[InnerType]]
def SingularOrIterable(inner_type):
return Union[inner_type, ListOrTuple(inner_type)]
# general pydantic classes # general pydantic classes
@@ -127,7 +129,6 @@ class AdapterConfig(BaseModel):
class DiffusionPriorNetworkConfig(BaseModel): class DiffusionPriorNetworkConfig(BaseModel):
dim: int dim: int
depth: int depth: int
max_text_len: int = None
num_timesteps: int = None num_timesteps: int = None
num_time_embeds: int = 1 num_time_embeds: int = 1
num_image_embeds: int = 1 num_image_embeds: int = 1
@@ -135,7 +136,6 @@ class DiffusionPriorNetworkConfig(BaseModel):
dim_head: int = 64 dim_head: int = 64
heads: int = 8 heads: int = 8
ff_mult: int = 4 ff_mult: int = 4
norm_in: bool = False
norm_out: bool = True norm_out: bool = True
attn_dropout: float = 0. attn_dropout: float = 0.
ff_dropout: float = 0. ff_dropout: float = 0.
@@ -143,9 +143,6 @@ class DiffusionPriorNetworkConfig(BaseModel):
normformer: bool = False normformer: bool = False
rotary_emb: bool = True rotary_emb: bool = True
class Config:
extra = "allow"
def create(self): def create(self):
kwargs = self.dict() kwargs = self.dict()
return DiffusionPriorNetwork(**kwargs) return DiffusionPriorNetwork(**kwargs)
@@ -157,7 +154,6 @@ class DiffusionPriorConfig(BaseModel):
image_size: int image_size: int
image_channels: int = 3 image_channels: int = 3
timesteps: int = 1000 timesteps: int = 1000
sample_timesteps: Optional[int] = None
cond_drop_prob: float = 0. cond_drop_prob: float = 0.
loss_type: str = 'l2' loss_type: str = 'l2'
predict_x_start: bool = True predict_x_start: bool = True
@@ -188,26 +184,23 @@ class DiffusionPriorTrainConfig(BaseModel):
use_ema: bool = True use_ema: bool = True
ema_beta: float = 0.99 ema_beta: float = 0.99
amp: bool = False amp: bool = False
warmup_steps: int = None # number of warmup steps save_every: int = 10000 # what steps to save on
save_every_seconds: int = 3600 # how often to save
eval_timesteps: List[int] = [64] # which sampling timesteps to evaluate with
best_validation_loss: float = 1e9 # the current best valudation loss observed
current_epoch: int = 0 # the current epoch
num_samples_seen: int = 0 # the current number of samples seen
random_seed: int = 0 # manual seed for torch
class DiffusionPriorDataConfig(BaseModel): class DiffusionPriorDataConfig(BaseModel):
image_url: str # path to embeddings folder image_url: str # path to embeddings folder
meta_url: str # path to metadata (captions) for images meta_url: str # path to metadata (captions) for images
splits: TrainSplitConfig # define train, validation, test splits for your dataset splits: TrainSplitConfig
batch_size: int # per-gpu batch size used to train the model batch_size: int = 64
num_data_points: int = 25e7 # total number of datapoints to train on
eval_every_seconds: int = 3600 # validation statistics will be performed this often class DiffusionPriorLoadConfig(BaseModel):
source: str = None
resume: bool = False
class TrainDiffusionPriorConfig(BaseModel): class TrainDiffusionPriorConfig(BaseModel):
prior: DiffusionPriorConfig prior: DiffusionPriorConfig
data: DiffusionPriorDataConfig data: DiffusionPriorDataConfig
train: DiffusionPriorTrainConfig train: DiffusionPriorTrainConfig
load: DiffusionPriorLoadConfig
tracker: TrackerConfig tracker: TrackerConfig
@classmethod @classmethod
@@ -220,31 +213,29 @@ class TrainDiffusionPriorConfig(BaseModel):
class UnetConfig(BaseModel): class UnetConfig(BaseModel):
dim: int dim: int
dim_mults: ListOrTuple[int] dim_mults: ListOrTuple(int)
image_embed_dim: int = None image_embed_dim: int = None
text_embed_dim: int = None text_embed_dim: int = None
cond_on_text_encodings: bool = None cond_on_text_encodings: bool = None
cond_dim: int = None cond_dim: int = None
channels: int = 3 channels: int = 3
self_attn: ListOrTuple[int] self_attn: ListOrTuple(int)
attn_dim_head: int = 32 attn_dim_head: int = 32
attn_heads: int = 16 attn_heads: int = 16
init_cross_embed: bool = True
class Config: class Config:
extra = "allow" extra = "allow"
class DecoderConfig(BaseModel): class DecoderConfig(BaseModel):
unets: ListOrTuple[UnetConfig] unets: ListOrTuple(UnetConfig)
image_size: int = None image_size: int = None
image_sizes: ListOrTuple[int] = None image_sizes: ListOrTuple(int) = None
clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
channels: int = 3 channels: int = 3
timesteps: int = 1000 timesteps: int = 1000
sample_timesteps: Optional[SingularOrIterable[int]] = None
loss_type: str = 'l2' loss_type: str = 'l2'
beta_schedule: ListOrTuple[str] = None # None means all cosine beta_schedule: ListOrTuple(str) = 'cosine'
learned_variance: SingularOrIterable[bool] = True learned_variance: bool = True
image_cond_drop_prob: float = 0.1 image_cond_drop_prob: float = 0.1
text_cond_drop_prob: float = 0.5 text_cond_drop_prob: float = 0.5
@@ -303,22 +294,20 @@ class DecoderDataConfig(BaseModel):
class DecoderTrainConfig(BaseModel): class DecoderTrainConfig(BaseModel):
epochs: int = 20 epochs: int = 20
lr: SingularOrIterable[float] = 1e-4 lr: SingularOrIterable(float) = 1e-4
wd: SingularOrIterable[float] = 0.01 wd: SingularOrIterable(float) = 0.01
warmup_steps: Optional[SingularOrIterable[int]] = None warmup_steps: Optional[SingularOrIterable(int)] = None
find_unused_parameters: bool = True find_unused_parameters: bool = True
max_grad_norm: SingularOrIterable[float] = 0.5 max_grad_norm: SingularOrIterable(float) = 0.5
save_every_n_samples: int = 100000 save_every_n_samples: int = 100000
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
cond_scale: Union[float, List[float]] = 1.0
device: str = 'cuda:0' device: str = 'cuda:0'
epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite. epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
validation_samples: int = None # Same as above but for validation. validation_samples: int = None # Same as above but for validation.
save_immediately: bool = False
use_ema: bool = True use_ema: bool = True
ema_beta: float = 0.999 ema_beta: float = 0.999
amp: bool = False amp: bool = False
unet_training_mask: ListOrTuple[bool] = None # If None, use all unets unet_training_mask: ListOrTuple(bool) = None # If None, use all unets
class DecoderEvaluateConfig(BaseModel): class DecoderEvaluateConfig(BaseModel):
n_evaluation_samples: int = 1000 n_evaluation_samples: int = 1000
@@ -327,6 +316,12 @@ class DecoderEvaluateConfig(BaseModel):
KID: Dict[str, Any] = None KID: Dict[str, Any] = None
LPIPS: Dict[str, Any] = None LPIPS: Dict[str, Any] = None
class DecoderLoadConfig(BaseModel):
source: str = None # Supports file and wandb
run_path: str = '' # Used only if source is wandb
file_path: str = '' # The local filepath if source is file. If source is wandb, the relative path to the model file in wandb.
resume: bool = False # If using wandb, whether to resume the run
class TrainDecoderConfig(BaseModel): class TrainDecoderConfig(BaseModel):
decoder: DecoderConfig decoder: DecoderConfig
data: DecoderDataConfig data: DecoderDataConfig


@@ -174,24 +174,26 @@ class DiffusionPriorTrainer(nn.Module):
def __init__( def __init__(
self, self,
diffusion_prior, diffusion_prior,
accelerator = None,
use_ema = True, use_ema = True,
lr = 3e-4, lr = 3e-4,
wd = 1e-2, wd = 1e-2,
eps = 1e-6, eps = 1e-6,
max_grad_norm = None, max_grad_norm = None,
amp = False,
group_wd_params = True, group_wd_params = True,
warmup_steps = 1, device = None,
accelerator = None,
verbose = True,
**kwargs **kwargs
): ):
super().__init__() super().__init__()
assert isinstance(diffusion_prior, DiffusionPrior) assert isinstance(diffusion_prior, DiffusionPrior)
assert not exists(accelerator) or isinstance(accelerator, Accelerator)
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs) ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
accelerator_kwargs, kwargs = groupby_prefix_and_trim('accelerator_', kwargs)
if not exists(accelerator): # verbosity
accelerator = Accelerator(**accelerator_kwargs)
self.verbose = verbose
# assign some helpful member vars # assign some helpful member vars
@@ -200,31 +202,23 @@ class DiffusionPriorTrainer(nn.Module):
# setting the device # setting the device
self.device = accelerator.device if not exists(accelerator) and not exists(device):
diffusion_prior.to(self.device) diffusion_prior_device = next(diffusion_prior.parameters()).device
self.print(f'accelerator not given, and device not specified: defaulting to device of diffusion prior parameters - {diffusion_prior_device}')
self.device = diffusion_prior_device
else:
self.device = accelerator.device if exists(accelerator) else device
diffusion_prior.to(self.device)
# save model # save model
self.diffusion_prior = diffusion_prior self.diffusion_prior = diffusion_prior
# mixed precision checks # optimizer and mixed precision stuff
if ( self.amp = amp
exists(self.accelerator)
and self.accelerator.distributed_type == DistributedType.DEEPSPEED
and self.diffusion_prior.clip is not None
):
# Then we need to make sure clip is using the correct precision or else deepspeed will error
cast_type_map = {
"fp16": torch.half,
"bf16": torch.bfloat16,
"no": torch.float
}
precision_type = cast_type_map[accelerator.mixed_precision]
assert precision_type == torch.float, "DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip"
self.diffusion_prior.clip.to(precision_type)
# optimizer stuff self.scaler = GradScaler(enabled = amp)
self.optim_kwargs = dict(lr=lr, wd=wd, eps=eps, group_wd_params=group_wd_params) self.optim_kwargs = dict(lr=lr, wd=wd, eps=eps, group_wd_params=group_wd_params)
@@ -234,20 +228,16 @@ class DiffusionPriorTrainer(nn.Module):
**kwargs **kwargs
) )
self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
self.warmup_scheduler = warmup.LinearWarmup(self.optimizer, warmup_period = warmup_steps) if exists(warmup_steps) else None
# distribute the model if using HFA # distribute the model if using HFA
if exists(self.accelerator):
self.diffusion_prior, self.optimizer, self.scheduler = self.accelerator.prepare(self.diffusion_prior, self.optimizer, self.scheduler) self.diffusion_prior, self.optimizer = self.accelerator.prepare(self.diffusion_prior, self.optimizer)
# exponential moving average stuff # exponential moving average stuff
self.use_ema = use_ema self.use_ema = use_ema
if self.use_ema: if self.use_ema:
self.ema_diffusion_prior = EMA(self.accelerator.unwrap_model(self.diffusion_prior), **ema_kwargs) self.ema_diffusion_prior = EMA(self.unwrap_model(self.diffusion_prior), **ema_kwargs)
# gradient clipping if needed # gradient clipping if needed
@@ -257,24 +247,67 @@ class DiffusionPriorTrainer(nn.Module):
self.register_buffer('step', torch.tensor([0], device = self.device)) self.register_buffer('step', torch.tensor([0], device = self.device))
# accelerator wrappers
def print(self, msg):
if not self.verbose:
return
if exists(self.accelerator):
self.accelerator.print(msg)
else:
print(msg)
def unwrap_model(self, model):
if exists(self.accelerator):
return self.accelerator.unwrap_model(model)
else:
return model
def wait_for_everyone(self):
if exists(self.accelerator):
self.accelerator.wait_for_everyone()
def is_main_process(self):
if exists(self.accelerator):
return self.accelerator.is_main_process
else:
return True
def clip_grad_norm_(self, *args):
if exists(self.accelerator):
return self.accelerator.clip_grad_norm_(*args)
else:
return torch.nn.utils.clip_grad_norm_(*args)
def backprop(self, x):
if exists(self.accelerator):
self.accelerator.backward(x)
else:
try:
x.backward()
except Exception as e:
self.print(f"Caught error in backprop call: {e}")
# utility # utility
def save(self, path, overwrite = True, **kwargs): def save(self, path, overwrite = True, **kwargs):
# ensure we sync gradients before continuing
self.wait_for_everyone()
# only save on the main process # only save on the main process
if self.accelerator.is_main_process: if self.is_main_process():
print(f"Saving checkpoint at step: {self.step.item()}") self.print(f"Saving checkpoint at step: {self.step.item()}")
path = Path(path) path = Path(path)
assert not (path.exists() and not overwrite) assert not (path.exists() and not overwrite)
path.parent.mkdir(parents = True, exist_ok = True) path.parent.mkdir(parents = True, exist_ok = True)
# FIXME: LambdaLR can't be saved due to pickling issues
save_obj = dict( save_obj = dict(
scaler = self.scaler.state_dict(),
optimizer = self.optimizer.state_dict(), optimizer = self.optimizer.state_dict(),
warmup_scheduler = self.warmup_scheduler, model = self.unwrap_model(self.diffusion_prior).state_dict(), # unwrap the model from distribution if applicable
model = self.accelerator.unwrap_model(self.diffusion_prior).state_dict(),
version = version.parse(__version__), version = version.parse(__version__),
step = self.step, step = self.step.item(),
**kwargs **kwargs
) )
@@ -287,14 +320,14 @@ class DiffusionPriorTrainer(nn.Module):
torch.save(save_obj, str(path)) torch.save(save_obj, str(path))
def load(self, path_or_state, overwrite_lr = True, strict = True): def load(self, path, overwrite_lr = True, strict = True):
""" """
Load a checkpoint of a diffusion prior trainer. Load a checkpoint of a diffusion prior trainer.
Will load the entire trainer, including the optimizer and EMA. Will load the entire trainer, including the optimizer and EMA.
Params: Params:
- path_or_state (str | torch): a path to the DiffusionPriorTrainer checkpoint file - path (str): a path to the DiffusionPriorTrainer checkpoint file
- overwrite_lr (bool): wether or not to overwrite the stored LR with the LR specified in the new trainer - overwrite_lr (bool): wether or not to overwrite the stored LR with the LR specified in the new trainer
- strict (bool): kwarg for `torch.nn.Module.load_state_dict`, will force an exact checkpoint match - strict (bool): kwarg for `torch.nn.Module.load_state_dict`, will force an exact checkpoint match
@@ -303,56 +336,56 @@ class DiffusionPriorTrainer(nn.Module):
""" """
# all processes need to load checkpoint. no restriction here # all processes need to load checkpoint. no restriction here
if isinstance(path_or_state, str): path = Path(path)
path = Path(path_or_state) assert path.exists()
assert path.exists()
loaded_obj = torch.load(str(path), map_location=self.device)
elif isinstance(path_or_state, dict): loaded_obj = torch.load(str(path), map_location=self.device)
loaded_obj = path_or_state
if version.parse(__version__) != loaded_obj['version']: if version.parse(__version__) != loaded_obj['version']:
print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {__version__}') print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {__version__}')
# unwrap the model when loading from checkpoint # unwrap the model when loading from checkpoint
self.accelerator.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict) self.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
self.step.copy_(torch.ones_like(self.step, device=self.device) * loaded_obj['step'].to(self.device)) self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
self.scaler.load_state_dict(loaded_obj['scaler'])
self.optimizer.load_state_dict(loaded_obj['optimizer']) self.optimizer.load_state_dict(loaded_obj['optimizer'])
# set warmupstep
if exists(self.warmup_scheduler):
self.warmup_scheduler.last_step = self.step.item()
# ensure new lr is used if different from old one
if overwrite_lr: if overwrite_lr:
new_lr = self.optim_kwargs["lr"] new_lr = self.optim_kwargs["lr"]
self.print(f"Overriding LR to be {new_lr}")
for group in self.optimizer.param_groups: for group in self.optimizer.param_groups:
group["lr"] = new_lr if group["lr"] > 0.0 else 0.0 group["lr"] = new_lr
if self.use_ema: if self.use_ema:
assert 'ema' in loaded_obj assert 'ema' in loaded_obj
self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict) self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
# below might not be necessary, but I had a suspicion that this wasn't being loaded correctly # below not be necessary, but I had a suspicion that this wasn't being loaded correctly
self.ema_diffusion_prior.ema_model.load_state_dict(loaded_obj["ema_model"]) self.ema_diffusion_prior.ema_model.load_state_dict(loaded_obj["ema_model"])
# sync and inform
self.wait_for_everyone()
self.print(f"Loaded model")
return loaded_obj return loaded_obj
# model functionality # model functionality
def update(self): def update(self):
# only continue with updates until all ranks finish
self.wait_for_everyone()
if exists(self.max_grad_norm): if exists(self.max_grad_norm):
self.accelerator.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm) self.scaler.unscale_(self.optimizer)
# utilize HFA clipping where applicable
self.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
self.optimizer.step() self.scaler.step(self.optimizer)
self.scaler.update()
self.optimizer.zero_grad() self.optimizer.zero_grad()
# accelerator will ocassionally skip optimizer steps in a "dynamic loss scaling strategy"
if not self.accelerator.optimizer_step_was_skipped:
with self.warmup_scheduler.dampening():
self.scheduler.step()
if self.use_ema: if self.use_ema:
self.ema_diffusion_prior.update() self.ema_diffusion_prior.update()
@@ -381,7 +414,7 @@ class DiffusionPriorTrainer(nn.Module):
@cast_torch_tensor @cast_torch_tensor
@prior_sample_in_chunks @prior_sample_in_chunks
def embed_text(self, *args, **kwargs): def embed_text(self, *args, **kwargs):
return self.accelerator.unwrap_model(self.diffusion_prior).clip.embed_text(*args, **kwargs) return self.unwrap_model(self.diffusion_prior).clip.embed_text(*args, **kwargs)
@cast_torch_tensor @cast_torch_tensor
def forward( def forward(
@@ -393,14 +426,16 @@ class DiffusionPriorTrainer(nn.Module):
total_loss = 0. total_loss = 0.
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs): for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
with self.accelerator.autocast(): with autocast(enabled = self.amp):
loss = self.diffusion_prior(*chunked_args, **chunked_kwargs) loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
loss = loss * chunk_size_frac loss = loss * chunk_size_frac
total_loss += loss.item() total_loss += loss.item()
# backprop with accelerate if applicable
if self.training: if self.training:
self.accelerator.backward(loss) self.backprop(self.scaler.scale(loss))
return total_loss return total_loss
@@ -463,27 +498,23 @@ class DecoderTrainer(nn.Module):
warmup_schedulers = [] warmup_schedulers = []
for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps in zip(decoder.unets, lr, wd, eps, warmup_steps): for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps in zip(decoder.unets, lr, wd, eps, warmup_steps):
if isinstance(unet, nn.Identity): optimizer = get_optimizer(
optimizers.append(None) unet.parameters(),
schedulers.append(None) lr = unet_lr,
warmup_schedulers.append(None) wd = unet_wd,
else: eps = unet_eps,
optimizer = get_optimizer( group_wd_params = group_wd_params,
unet.parameters(), **kwargs
lr = unet_lr, )
wd = unet_wd,
eps = unet_eps,
group_wd_params = group_wd_params,
**kwargs
)
optimizers.append(optimizer) optimizers.append(optimizer)
scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
warmup_schedulers.append(warmup_scheduler)
schedulers.append(scheduler) warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
warmup_schedulers.append(warmup_scheduler)
schedulers.append(scheduler)
if self.use_ema: if self.use_ema:
self.ema_unets.append(EMA(unet, **ema_kwargs)) self.ema_unets.append(EMA(unet, **ema_kwargs))
@@ -505,19 +536,11 @@ class DecoderTrainer(nn.Module):
assert precision_type == torch.float, "DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip" assert precision_type == torch.float, "DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip"
clip = decoder.clip clip = decoder.clip
clip.to(precision_type) clip.to(precision_type)
decoder, train_loader, val_loader, *optimizers = list(self.accelerator.prepare(decoder, dataloaders["train"], dataloaders["val"], *optimizers))
decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
self.decoder = decoder
# prepare dataloaders
train_loader = val_loader = None
if exists(dataloaders):
train_loader, val_loader = self.accelerator.prepare(dataloaders["train"], dataloaders["val"])
self.train_loader = train_loader self.train_loader = train_loader
self.val_loader = val_loader self.val_loader = val_loader
self.decoder = decoder
# store optimizers # store optimizers
@@ -559,8 +582,7 @@ class DecoderTrainer(nn.Module):
for ind in range(0, self.num_unets): for ind in range(0, self.num_unets):
optimizer_key = f'optim{ind}' optimizer_key = f'optim{ind}'
optimizer = getattr(self, optimizer_key) optimizer = getattr(self, optimizer_key)
state_dict = optimizer.state_dict() if optimizer is not None else None save_obj = {**save_obj, optimizer_key: self.accelerator.unwrap_model(optimizer).state_dict()}
save_obj = {**save_obj, optimizer_key: state_dict}
if self.use_ema: if self.use_ema:
save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()} save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
@@ -582,8 +604,8 @@ class DecoderTrainer(nn.Module):
optimizer_key = f'optim{ind}' optimizer_key = f'optim{ind}'
optimizer = getattr(self, optimizer_key) optimizer = getattr(self, optimizer_key)
warmup_scheduler = self.warmup_schedulers[ind] warmup_scheduler = self.warmup_schedulers[ind]
if optimizer is not None:
optimizer.load_state_dict(loaded_obj[optimizer_key]) self.accelerator.unwrap_model(optimizer).load_state_dict(loaded_obj[optimizer_key])
if exists(warmup_scheduler): if exists(warmup_scheduler):
warmup_scheduler.last_step = last_step warmup_scheduler.last_step = last_step
@@ -643,14 +665,8 @@ class DecoderTrainer(nn.Module):
def sample(self, *args, **kwargs): def sample(self, *args, **kwargs):
distributed = self.accelerator.num_processes > 1 distributed = self.accelerator.num_processes > 1
base_decoder = self.accelerator.unwrap_model(self.decoder) base_decoder = self.accelerator.unwrap_model(self.decoder)
was_training = base_decoder.training
base_decoder.eval()
if kwargs.pop('use_non_ema', False) or not self.use_ema: if kwargs.pop('use_non_ema', False) or not self.use_ema:
out = base_decoder.sample(*args, **kwargs, distributed = distributed) return base_decoder.sample(*args, **kwargs, distributed = distributed)
base_decoder.train(was_training)
return out
trainable_unets = self.accelerator.unwrap_model(self.decoder).unets trainable_unets = self.accelerator.unwrap_model(self.decoder).unets
base_decoder.unets = self.unets # swap in exponential moving averaged unets for sampling base_decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
@@ -663,7 +679,6 @@ class DecoderTrainer(nn.Module):
for ema in self.ema_unets: for ema in self.ema_unets:
ema.restore_ema_model_device() ema.restore_ema_model_device()
base_decoder.train(was_training)
return output return output
@torch.no_grad() @torch.no_grad()
@@ -684,32 +699,23 @@ class DecoderTrainer(nn.Module):
*args, *args,
unet_number = None, unet_number = None,
max_batch_size = None, max_batch_size = None,
return_lowres_cond_image=False,
**kwargs **kwargs
): ):
unet_number = self.validate_and_return_unet_number(unet_number) unet_number = self.validate_and_return_unet_number(unet_number)
total_loss = 0. total_loss = 0.
cond_images = []
using_amp = self.accelerator.mixed_precision != 'no'
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs): for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
with self.accelerator.autocast(): with self.accelerator.autocast():
loss_obj = self.decoder(*chunked_args, unet_number = unet_number, return_lowres_cond_image=return_lowres_cond_image, **chunked_kwargs) loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
# loss_obj may be a tuple with loss and cond_image
if return_lowres_cond_image:
loss, cond_image = loss_obj
else:
loss = loss_obj
cond_image = None
loss = loss * chunk_size_frac loss = loss * chunk_size_frac
if cond_image is not None:
cond_images.append(cond_image)
total_loss += loss.item() total_loss += loss.item()
if self.training: if self.training:
self.accelerator.backward(loss) self.accelerator.backward(loss)
if return_lowres_cond_image: return total_loss
return total_loss, torch.stack(cond_images)
else:
return total_loss


@@ -1 +1 @@
__version__ = '1.4.3' __version__ = '0.18.0'


@@ -1,6 +1,5 @@
from pathlib import Path from pathlib import Path
from typing import List from typing import List
from datetime import timedelta
from dalle2_pytorch.trainer import DecoderTrainer from dalle2_pytorch.trainer import DecoderTrainer
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
@@ -12,12 +11,11 @@ from clip import tokenize
import torchvision import torchvision
import torch import torch
from torch import nn
from torchmetrics.image.fid import FrechetInceptionDistance from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.kid import KernelInceptionDistance from torchmetrics.image.kid import KernelInceptionDistance
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from accelerate import Accelerator, DistributedDataParallelKwargs, InitProcessGroupKwargs from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.utils import dataclasses as accelerate_dataclasses from accelerate.utils import dataclasses as accelerate_dataclasses
import webdataset as wds import webdataset as wds
import click import click
@@ -134,7 +132,7 @@ def get_example_data(dataloader, device, n=5):
break break
return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n])) return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))
def generate_samples(trainer, example_data, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend="", match_image_size=True): def generate_samples(trainer, example_data, condition_on_text_encodings=False, text_prepend="", match_image_size=True):
""" """
Takes example data and generates images from the embeddings Takes example data and generates images from the embeddings
Returns three lists: real images, generated images, and captions Returns three lists: real images, generated images, and captions
@@ -159,13 +157,6 @@ def generate_samples(trainer, example_data, start_unet=1, end_unet=None, conditi
# Then we are using precomputed text embeddings # Then we are using precomputed text embeddings
text_embeddings = torch.stack(text_embeddings) text_embeddings = torch.stack(text_embeddings)
sample_params["text_encodings"] = text_embeddings sample_params["text_encodings"] = text_embeddings
sample_params["start_at_unet_number"] = start_unet
sample_params["stop_at_unet_number"] = end_unet
if start_unet > 1:
# If we are only training upsamplers
sample_params["image"] = torch.stack(real_images)
if device is not None:
sample_params["_device"] = device
samples = trainer.sample(**sample_params) samples = trainer.sample(**sample_params)
generated_images = list(samples) generated_images = list(samples)
captions = [text_prepend + txt for txt in txts] captions = [text_prepend + txt for txt in txts]
@@ -174,15 +165,15 @@ def generate_samples(trainer, example_data, start_unet=1, end_unet=None, conditi
real_images = [resize_image_to(image, generated_image_size, clamp_range=(0, 1)) for image in real_images] real_images = [resize_image_to(image, generated_image_size, clamp_range=(0, 1)) for image in real_images]
return real_images, generated_images, captions return real_images, generated_images, captions
def generate_grid_samples(trainer, examples, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend=""): def generate_grid_samples(trainer, examples, condition_on_text_encodings=False, text_prepend=""):
""" """
Generates samples and uses torchvision to put them in a side by side grid for easy viewing Generates samples and uses torchvision to put them in a side by side grid for easy viewing
""" """
real_images, generated_images, captions = generate_samples(trainer, examples, start_unet, end_unet, condition_on_text_encodings, cond_scale, device, text_prepend) real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings, text_prepend)
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)] grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
return grid_images, captions return grid_images, captions
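`generate_grid_samples` pairs each original image with its generated counterpart using `torchvision.utils.make_grid`, producing one small grid per example for the tracker. A minimal sketch of that pairing, assuming image tensors of matching shape in `[0, 1]`:

```python
import torchvision

def side_by_side_grids(real_images, generated_images):
    # one (original | generated) grid per example, ready for tracker.log_images
    return [
        torchvision.utils.make_grid([real, fake])
        for real, fake in zip(real_images, generated_images)
    ]
```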
def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, condition_on_text_encodings=False, cond_scale=1.0, inference_device=None, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None): def evaluate_trainer(trainer, dataloader, device, condition_on_text_encodings=False, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
""" """
Computes evaluation metrics for the decoder Computes evaluation metrics for the decoder
""" """
@@ -192,7 +183,7 @@ def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, conditi
if len(examples) == 0: if len(examples) == 0:
print("No data to evaluate. Check that your dataloader has shards.") print("No data to evaluate. Check that your dataloader has shards.")
return metrics return metrics
real_images, generated_images, captions = generate_samples(trainer, examples, start_unet, end_unet, condition_on_text_encodings, cond_scale, inference_device) real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings)
real_images = torch.stack(real_images).to(device=device, dtype=torch.float) real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float) generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
# Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8 # Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
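As the comment above notes, `evaluate_trainer` converts images from float `[0, 1]` to uint8 `[0, 255]` before updating the torchmetrics metrics. A small sketch of that conversion for FID, which by default expects uint8 inputs (the helper name is illustrative):

```python
import torch
from torchmetrics.image.fid import FrechetInceptionDistance

def update_fid(fid: FrechetInceptionDistance, real_images: torch.Tensor, generated_images: torch.Tensor):
    # torchmetrics' default FID feature extractor expects uint8 images in [0, 255]
    real_uint8 = (real_images.clamp(0, 1) * 255).to(torch.uint8)
    fake_uint8 = (generated_images.clamp(0, 1) * 255).to(torch.uint8)
    fid.update(real_uint8, real=True)
    fid.update(fake_uint8, real=False)
```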
@@ -268,13 +259,11 @@ def train(
evaluate_config=None, evaluate_config=None,
epoch_samples = None, # If the training dataset is resampling, we have to manually stop an epoch epoch_samples = None, # If the training dataset is resampling, we have to manually stop an epoch
validation_samples = None, validation_samples = None,
save_immediately=False,
epochs = 20, epochs = 20,
n_sample_images = 5, n_sample_images = 5,
save_every_n_samples = 100000, save_every_n_samples = 100000,
unet_training_mask=None, unet_training_mask=None,
condition_on_text_encodings=False, condition_on_text_encodings=False,
cond_scale=1.0,
**kwargs **kwargs
): ):
""" """
@@ -282,21 +271,6 @@ def train(
""" """
is_master = accelerator.process_index == 0 is_master = accelerator.process_index == 0
if not exists(unet_training_mask):
# Then the unet mask should be true for all unets in the decoder
unet_training_mask = [True] * len(decoder.unets)
assert len(unet_training_mask) == len(decoder.unets), f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {len(decoder.unets)}"
trainable_unet_numbers = [i+1 for i, trainable in enumerate(unet_training_mask) if trainable]
first_trainable_unet = trainable_unet_numbers[0]
last_trainable_unet = trainable_unet_numbers[-1]
def move_unets(unet_training_mask):
for i in range(len(decoder.unets)):
if not unet_training_mask[i]:
# Replace the unet from the module list with a nn.Identity(). This training script never uses unets that aren't being trained so this is fine.
decoder.unets[i] = nn.Identity().to(inference_device)
# Remove non-trainable unets
move_unets(unet_training_mask)
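The `unet_training_mask` handling above derives which unet numbers are trainable (the trainer API is 1-indexed); one side of the diff also swaps untrained unets for `nn.Identity` placeholders. A minimal sketch of deriving the first and last trainable unet from such a mask:

```python
def trainable_unet_range(unet_training_mask):
    # unet numbers are 1-indexed in the trainer API
    trainable = [i + 1 for i, is_trainable in enumerate(unet_training_mask) if is_trainable]
    assert len(trainable) > 0, "at least one unet must be trainable"
    return trainable[0], trainable[-1]

# e.g. training only the upsampler of a two-unet decoder
first_unet, last_unet = trainable_unet_range([False, True])  # -> (2, 2)
```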
trainer = DecoderTrainer( trainer = DecoderTrainer(
decoder=decoder, decoder=decoder,
accelerator=accelerator, accelerator=accelerator,
@@ -311,7 +285,6 @@ def train(
sample = 0 sample = 0
samples_seen = 0 samples_seen = 0
val_sample = 0 val_sample = 0
step = lambda: int(trainer.num_steps_taken(unet_number=first_trainable_unet))
if tracker.can_recall: if tracker.can_recall:
start_epoch, validation_losses, next_task, recalled_sample, samples_seen = recall_trainer(tracker, trainer) start_epoch, validation_losses, next_task, recalled_sample, samples_seen = recall_trainer(tracker, trainer)
@@ -323,6 +296,13 @@ def train(
accelerator.print(f"Starting training from task {next_task} at sample {sample} and validation sample {val_sample}") accelerator.print(f"Starting training from task {next_task} at sample {sample} and validation sample {val_sample}")
trainer.to(device=inference_device) trainer.to(device=inference_device)
if not exists(unet_training_mask):
# Then the unet mask should be true for all unets in the decoder
unet_training_mask = [True] * trainer.num_unets
first_training_unet = min(index for index, mask in enumerate(unet_training_mask) if mask)
step = lambda: int(trainer.num_steps_taken(unet_number=first_training_unet+1))
assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"
accelerator.print(print_ribbon("Generating Example Data", repeat=40)) accelerator.print(print_ribbon("Generating Example Data", repeat=40))
accelerator.print("This can take a while to load the shard lists...") accelerator.print("This can take a while to load the shard lists...")
if is_master: if is_master:
@@ -343,7 +323,7 @@ def train(
last_snapshot = sample last_snapshot = sample
if next_task == 'train': if next_task == 'train':
for i, (img, emb, txt) in enumerate(dataloaders["train"]): for i, (img, emb, txt) in enumerate(trainer.train_loader):
# We want to count the total number of samples across all processes # We want to count the total number of samples across all processes
sample_length_tensor[0] = len(img) sample_length_tensor[0] = len(img)
all_samples = accelerator.gather(sample_length_tensor) # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this. all_samples = accelerator.gather(sample_length_tensor) # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
@@ -378,9 +358,8 @@ def train(
else: else:
# Then we need to pass the text instead # Then we need to pass the text instead
tokenized_texts = tokenize(txt, truncate=True) tokenized_texts = tokenize(txt, truncate=True)
assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
forward_params['text'] = tokenized_texts forward_params['text'] = tokenized_texts
loss = trainer.forward(img, **forward_params, unet_number=unet, _device=inference_device) loss = trainer.forward(img, **forward_params, unet_number=unet)
trainer.update(unet_number=unet) trainer.update(unet_number=unet)
unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet-1] = loss unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet-1] = loss
@@ -393,10 +372,10 @@ def train(
unet_all_losses = accelerator.gather(unet_losses_tensor) unet_all_losses = accelerator.gather(unet_losses_tensor)
mask = unet_all_losses != 0 mask = unet_all_losses != 0
unet_average_loss = (unet_all_losses * mask).sum(dim=0) / mask.sum(dim=0) unet_average_loss = (unet_all_losses * mask).sum(dim=0) / mask.sum(dim=0)
loss_map = { f"Unet {index} Training Loss": loss.item() for index, loss in enumerate(unet_average_loss) if unet_training_mask[index] } loss_map = { f"Unet {index} Training Loss": loss.item() for index, loss in enumerate(unet_average_loss) if loss != 0 }
# gather decay rate on each UNet # gather decay rate on each UNet
ema_decay_list = {f"Unet {index} EMA Decay": ema_unet.get_current_decay() for index, ema_unet in enumerate(trainer.ema_unets) if unet_training_mask[index]} ema_decay_list = {f"Unet {index} EMA Decay": ema_unet.get_current_decay() for index, ema_unet in enumerate(trainer.ema_unets)}
log_data = { log_data = {
"Epoch": epoch, "Epoch": epoch,
@@ -411,7 +390,7 @@ def train(
if is_master: if is_master:
tracker.log(log_data, step=step()) tracker.log(log_data, step=step())
if is_master and (last_snapshot + save_every_n_samples < sample or (save_immediately and i == 0)): # This will miss by some amount every time, but it's not a big deal... I hope if is_master and last_snapshot + save_every_n_samples < sample: # This will miss by some amount every time, but it's not a big deal... I hope
# It is difficult to gather this kind of info on the accelerator, so we have to do it on the master # It is difficult to gather this kind of info on the accelerator, so we have to do it on the master
print("Saving snapshot") print("Saving snapshot")
last_snapshot = sample last_snapshot = sample
@@ -419,7 +398,7 @@ def train(
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen) save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen)
if exists(n_sample_images) and n_sample_images > 0: if exists(n_sample_images) and n_sample_images > 0:
trainer.eval() trainer.eval()
train_images, train_captions = generate_grid_samples(trainer, train_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ") train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step()) tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
if epoch_samples is not None and sample >= epoch_samples: if epoch_samples is not None and sample >= epoch_samples:
@@ -437,7 +416,7 @@ def train(
timer = Timer() timer = Timer()
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
i = 0 i = 0
for i, (img, emb, txt) in enumerate(dataloaders['val']): # Use the accelerate prepared loader for i, (img, emb, txt) in enumerate(trainer.val_loader): # Use the accelerate prepared loader
val_sample_length_tensor[0] = len(img) val_sample_length_tensor[0] = len(img)
all_samples = accelerator.gather(val_sample_length_tensor) all_samples = accelerator.gather(val_sample_length_tensor)
total_samples = all_samples.sum().item() total_samples = all_samples.sum().item()
@@ -469,9 +448,8 @@ def train(
else: else:
# Then we need to pass the text instead # Then we need to pass the text instead
tokenized_texts = tokenize(txt, truncate=True) tokenized_texts = tokenize(txt, truncate=True)
assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
forward_params['text'] = tokenized_texts forward_params['text'] = tokenized_texts
loss = trainer.forward(img.float(), **forward_params, unet_number=unet, _device=inference_device) loss = trainer.forward(img.float(), **forward_params, unet_number=unet)
average_val_loss_tensor[0, unet-1] += loss average_val_loss_tensor[0, unet-1] += loss
if i % VALID_CALC_LOSS_EVERY_ITERS == 0: if i % VALID_CALC_LOSS_EVERY_ITERS == 0:
@@ -498,7 +476,7 @@ def train(
if next_task == 'eval': if next_task == 'eval':
if exists(evaluate_config): if exists(evaluate_config):
accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40)) accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, first_trainable_unet, last_trainable_unet, inference_device=inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings, cond_scale=cond_scale) evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings)
if is_master: if is_master:
tracker.log(evaluation, step=step()) tracker.log(evaluation, step=step())
next_task = 'sample' next_task = 'sample'
@@ -509,15 +487,15 @@ def train(
# Generate examples and save the model if we are the master # Generate examples and save the model if we are the master
# Generate sample images # Generate sample images
print(print_ribbon(f"Sampling Set {epoch}", repeat=40)) print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
test_images, test_captions = generate_grid_samples(trainer, test_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Test: ") test_images, test_captions = generate_grid_samples(trainer, test_example_data, condition_on_text_encodings, "Test: ")
train_images, train_captions = generate_grid_samples(trainer, train_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ") train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step()) tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step())
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step()) tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
print(print_ribbon(f"Starting Saving {epoch}", repeat=40)) print(print_ribbon(f"Starting Saving {epoch}", repeat=40))
is_best = False is_best = False
if all_average_val_losses is not None: if all_average_val_losses is not None:
average_loss = all_average_val_losses.mean(dim=0).sum() / sum(unet_training_mask) average_loss = all_average_val_losses.mean(dim=0).item()
if len(validation_losses) == 0 or average_loss < min(validation_losses): if len(validation_losses) == 0 or average_loss < min(validation_losses):
is_best = True is_best = True
validation_losses.append(average_loss) validation_losses.append(average_loss)
@@ -534,7 +512,6 @@ def create_tracker(accelerator: Accelerator, config: TrainDecoderConfig, config_
} }
tracker: Tracker = tracker_config.create(config, accelerator_config, dummy_mode=dummy) tracker: Tracker = tracker_config.create(config, accelerator_config, dummy_mode=dummy)
tracker.save_config(config_path, config_name='decoder_config.json') tracker.save_config(config_path, config_name='decoder_config.json')
tracker.add_save_metadata(state_dict_key='config', metadata=config.dict())
return tracker return tracker
def initialize_training(config: TrainDecoderConfig, config_path): def initialize_training(config: TrainDecoderConfig, config_path):
@@ -543,8 +520,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
# Set up accelerator for configurable distributed training # Set up accelerator for configurable distributed training
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters) ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=60*60)) accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs, init_kwargs])
if accelerator.num_processes > 1: if accelerator.num_processes > 1:
# We are using distributed training and want to immediately ensure all can connect # We are using distributed training and want to immediately ensure all can connect
@@ -581,7 +557,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
# Create the decoder model and print basic info # Create the decoder model and print basic info
decoder = config.decoder.create() decoder = config.decoder.create()
get_num_parameters = lambda model, only_training=False: sum(p.numel() for p in model.parameters() if (p.requires_grad or not only_training)) num_parameters = sum(p.numel() for p in decoder.parameters())
# Create and initialize the tracker if we are the master # Create and initialize the tracker if we are the master
tracker = create_tracker(accelerator, config, config_path, dummy = rank!=0) tracker = create_tracker(accelerator, config, config_path, dummy = rank!=0)
@@ -610,10 +586,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
accelerator.print(print_ribbon("Loaded Config", repeat=40)) accelerator.print(print_ribbon("Loaded Config", repeat=40))
accelerator.print(f"Running training with {accelerator.num_processes} processes and {accelerator.distributed_type} distributed training") accelerator.print(f"Running training with {accelerator.num_processes} processes and {accelerator.distributed_type} distributed training")
accelerator.print(f"Training using {data_source_string}. {'conditioned on text' if conditioning_on_text else 'not conditioned on text'}") accelerator.print(f"Training using {data_source_string}. {'conditioned on text' if conditioning_on_text else 'not conditioned on text'}")
accelerator.print(f"Number of parameters: {get_num_parameters(decoder)} total; {get_num_parameters(decoder, only_training=True)} training") accelerator.print(f"Number of parameters: {num_parameters}")
for i, unet in enumerate(decoder.unets):
accelerator.print(f"Unet {i} has {get_num_parameters(unet)} total; {get_num_parameters(unet, only_training=True)} training")
train(dataloaders, decoder, accelerator, train(dataloaders, decoder, accelerator,
tracker=tracker, tracker=tracker,
inference_device=accelerator.device, inference_device=accelerator.device,


@@ -1,23 +1,31 @@
# TODO: add start, num_data_points, eval_every and group to config
# TODO: switch back to repo's wandb
START = 0
NUM_DATA_POINTS = 250e6
EVAL_EVERY = 1000
GROUP = "distributed"
import os
import click import click
import wandb
import torch import torch
from torch import nn from torch import nn
from typing import List
from accelerate import Accelerator
from accelerate.utils import set_seed
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from embedding_reader import EmbeddingReader
from accelerate.utils import dataclasses as accelerate_dataclasses
from dalle2_pytorch.utils import Timer import numpy as np
from dalle2_pytorch.trackers import Tracker
from dalle2_pytorch import DiffusionPriorTrainer from accelerate import Accelerator
from dalle2_pytorch.dataloaders import get_reader, make_splits from dalle2_pytorch.dataloaders import get_reader, make_splits
from dalle2_pytorch.utils import Timer
from dalle2_pytorch.train_configs import ( from dalle2_pytorch.train_configs import (
DiffusionPriorConfig,
DiffusionPriorTrainConfig, DiffusionPriorTrainConfig,
TrainDiffusionPriorConfig, TrainDiffusionPriorConfig,
) )
from dalle2_pytorch.trackers import BaseTracker, WandbTracker
from dalle2_pytorch import DiffusionPriorTrainer
# helpers # helpers
@@ -30,19 +38,8 @@ def exists(val):
return val is not None return val is not None
def all_between(values: list, lower_bound, upper_bound):
for value in values:
if value < lower_bound or value > upper_bound:
return False
return True
def make_model( def make_model(
prior_config: DiffusionPriorConfig, prior_config, train_config, device: str = None, accelerator: Accelerator = None
train_config: DiffusionPriorTrainConfig,
device: str = None,
accelerator: Accelerator = None,
): ):
# create model from config # create model from config
diffusion_prior = prior_config.create() diffusion_prior = prior_config.create()
@@ -57,214 +54,71 @@ def make_model(
use_ema=train_config.use_ema, use_ema=train_config.use_ema,
device=device, device=device,
accelerator=accelerator, accelerator=accelerator,
warmup_steps=train_config.warmup_steps,
) )
return trainer return trainer
def create_tracker(
accelerator: Accelerator,
config: TrainDiffusionPriorConfig,
config_path: str,
dummy: bool = False,
) -> Tracker:
tracker_config = config.tracker
accelerator_config = {
"Distributed": accelerator.distributed_type
!= accelerate_dataclasses.DistributedType.NO,
"DistributedType": accelerator.distributed_type,
"NumProcesses": accelerator.num_processes,
"MixedPrecision": accelerator.mixed_precision,
}
tracker: Tracker = tracker_config.create(
config, accelerator_config, dummy_mode=dummy
)
tracker.save_config(config_path, config_name="prior_config.json")
return tracker
def pad_gather_reduce(trainer: DiffusionPriorTrainer, x, method="mean"):
"""
pad a value or tensor across all processes and gather
params:
- trainer: a trainer that carries an accelerator object
- x: a number or torch tensor to reduce
- method: "mean", "sum", "max", "min"
return:
- the average tensor after masking out 0's
- None if the gather resulted in an empty tensor
"""
assert method in [
"mean",
"sum",
"max",
"min",
], "This function has limited capabilities [sum, mean, max, min]"
assert x is not None, "Cannot reduce a None type object"
# wait for everyone to arrive here before gathering
if type(x) is not torch.Tensor:
x = torch.tensor([x])
# verify that the tensor is on the proper device
x = x.to(trainer.device)
# pad across processes
padded_x = trainer.accelerator.pad_across_processes(x, dim=0)
# gather across all processes
gathered_x = trainer.accelerator.gather(padded_x)
# mask out zeros
masked_x = gathered_x[gathered_x != 0]
# if the tensor is empty, warn and return None
if len(masked_x) == 0:
click.secho(
f"The call to this method resulted in an empty tensor after masking out zeros. The gathered tensor was this: {gathered_x} and the original value passed was: {x}.",
fg="red",
)
return None
if method == "mean":
return torch.mean(masked_x)
elif method == "sum":
return torch.sum(masked_x)
elif method == "max":
return torch.max(masked_x)
elif method == "min":
return torch.min(masked_x)
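`pad_gather_reduce` pads a local value across processes, gathers it, masks out the zero padding, and applies the requested reduction; the prior training loop uses it to mean-reduce losses, sum-reduce sample counts, and min-reduce elapsed timers so every rank makes the same evaluation and checkpoint decisions. A tiny self-contained sketch of the mask-then-reduce step (the values are made up; note that genuine zeros are dropped along with the padding):

```python
import torch

# hypothetical gathered tensor: two ranks contributed losses, two contributed padding zeros
gathered = torch.tensor([0.42, 0.0, 0.57, 0.0])
masked = gathered[gathered != 0]   # drop padded entries (and, as a side effect, any true zeros)
print(masked.mean())               # tensor(0.4950)
```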
def save_trainer(
tracker: Tracker,
trainer: DiffusionPriorTrainer,
is_latest: bool,
is_best: bool,
epoch: int,
samples_seen: int,
best_validation_loss: float,
):
"""
Logs the model with an appropriate method depending on the tracker
"""
trainer.accelerator.wait_for_everyone()
if trainer.accelerator.is_main_process:
click.secho(
f"RANK:{trainer.accelerator.process_index} | Saving Model | Best={is_best} | Latest={is_latest}",
fg="magenta",
)
tracker.save(
trainer=trainer,
is_best=is_best,
is_latest=is_latest,
epoch=int(epoch),
samples_seen=int(samples_seen),
best_validation_loss=best_validation_loss,
)
def recall_trainer(tracker: Tracker, trainer: DiffusionPriorTrainer):
"""
Loads the model with an appropriate method depending on the tracker
"""
if trainer.accelerator.is_main_process:
click.secho(f"Loading model from {type(tracker.loader).__name__}", fg="yellow")
state_dict = tracker.recall()
trainer.load(state_dict, strict=True)
return (
int(state_dict.get("epoch", 0)),
state_dict.get("best_validation_loss", 0),
int(state_dict.get("samples_seen", 0)),
)
# eval functions # eval functions
def report_validation_loss( def eval_model(
trainer: DiffusionPriorTrainer, trainer: DiffusionPriorTrainer,
dataloader: DataLoader, dataloader: DataLoader,
text_conditioned: bool, text_conditioned: bool,
use_ema: bool,
tracker: Tracker,
split: str,
tracker_folder: str,
loss_type: str, loss_type: str,
tracker_context: str,
tracker: BaseTracker = None,
use_ema: bool = True,
): ):
""" trainer.eval()
Compute the validation loss on a given subset of data. if trainer.is_main_process():
""" click.secho(f"Measuring performance on {tracker_context}", fg="green", blink=True)
if trainer.accelerator.is_main_process: with torch.no_grad():
click.secho( total_loss = 0.0
f"Measuring performance on {use_ema}-{split} split", total_samples = 0.0
fg="green",
blink=True,
)
total_loss = torch.zeros(1, dtype=torch.float, device=trainer.device) for image_embeddings, text_data in dataloader:
image_embeddings = image_embeddings.to(trainer.device)
text_data = text_data.to(trainer.device)
for image_embeddings, text_data in dataloader: batches = image_embeddings.shape[0]
image_embeddings = image_embeddings.to(trainer.device)
text_data = text_data.to(trainer.device)
input_args = dict(image_embed=image_embeddings) input_args = dict(image_embed=image_embeddings)
if text_conditioned: if text_conditioned:
input_args = dict(**input_args, text=text_data) input_args = dict(**input_args, text=text_data)
else: else:
input_args = dict(**input_args, text_embed=text_data) input_args = dict(**input_args, text_embed=text_data)
if use_ema: if use_ema:
loss = trainer.ema_diffusion_prior(**input_args) loss = trainer.ema_diffusion_prior(**input_args)
else: else:
loss = trainer(**input_args) loss = trainer(**input_args)
total_loss += loss total_loss += loss * batches
total_samples += batches
# compute the average loss across all processes avg_loss = total_loss / total_samples
avg_loss = pad_gather_reduce(trainer, total_loss, method="mean") stats = {f"{tracker_context}-{loss_type}": avg_loss}
stats = {f"{tracker_folder}/{loss_type}-loss": avg_loss} trainer.print(stats)
# print and log results on main process if exists(tracker):
tracker.log(stats, step=trainer.step.item() + 1) tracker.log(stats, step=trainer.step.item() + 1)
return avg_loss
def report_cosine_sims( def report_cosine_sims(
trainer: DiffusionPriorTrainer, trainer: DiffusionPriorTrainer,
dataloader: DataLoader, dataloader: DataLoader,
text_conditioned: bool, text_conditioned: bool,
tracker: Tracker, tracker: BaseTracker,
split: str, tracker_context: str = "validation",
timesteps: int,
tracker_folder: str,
): ):
trainer.eval() trainer.eval()
if trainer.accelerator.is_main_process: if trainer.is_main_process():
click.secho( click.secho("Measuring Cosine-Similarity", fg="green", blink=True)
f"Measuring Cosine-Similarity on {split} split with {timesteps} timesteps",
fg="green",
blink=True,
)
for test_image_embeddings, text_data in dataloader: for test_image_embeddings, text_data in dataloader:
test_image_embeddings = test_image_embeddings.to(trainer.device) test_image_embeddings = test_image_embeddings.to(trainer.device)
@@ -272,8 +126,10 @@ def report_cosine_sims(
# we are text conditioned, we produce an embedding from the tokenized text # we are text conditioned, we produce an embedding from the tokenized text
if text_conditioned: if text_conditioned:
text_embedding, text_encodings = trainer.embed_text(text_data) text_embedding, text_encodings, text_mask = trainer.embed_text(text_data)
text_cond = dict(text_embed=text_embedding, text_encodings=text_encodings) text_cond = dict(
text_embed=text_embedding, text_encodings=text_encodings, mask=text_mask
)
else: else:
text_embedding = text_data text_embedding = text_data
text_cond = dict(text_embed=text_embedding) text_cond = dict(text_embed=text_embedding)
@@ -290,11 +146,15 @@ def report_cosine_sims(
if text_conditioned: if text_conditioned:
text_encodings_shuffled = text_encodings[rolled_idx] text_encodings_shuffled = text_encodings[rolled_idx]
text_mask_shuffled = text_mask[rolled_idx]
else: else:
text_encodings_shuffled = None text_encodings_shuffled = None
text_mask_shuffled = None
text_cond_shuffled = dict( text_cond_shuffled = dict(
text_embed=text_embed_shuffled, text_encodings=text_encodings_shuffled text_embed=text_embed_shuffled,
text_encodings=text_encodings_shuffled,
mask=text_mask_shuffled,
) )
# prepare the text embedding # prepare the text embedding
@@ -307,9 +167,7 @@ def report_cosine_sims(
# predict on the unshuffled text embeddings # predict on the unshuffled text embeddings
predicted_image_embeddings = trainer.p_sample_loop( predicted_image_embeddings = trainer.p_sample_loop(
test_image_embeddings.shape, test_image_embeddings.shape, text_cond
text_cond,
timesteps=timesteps,
) )
predicted_image_embeddings = ( predicted_image_embeddings = (
@@ -319,9 +177,7 @@ def report_cosine_sims(
# predict on the shuffled embeddings # predict on the shuffled embeddings
predicted_unrelated_embeddings = trainer.p_sample_loop( predicted_unrelated_embeddings = trainer.p_sample_loop(
test_image_embeddings.shape, test_image_embeddings.shape, text_cond_shuffled
text_cond_shuffled,
timesteps=timesteps,
) )
predicted_unrelated_embeddings = ( predicted_unrelated_embeddings = (
@@ -330,97 +186,32 @@ def report_cosine_sims(
) )
# calculate similarities # calculate similarities
orig_sim = pad_gather_reduce( original_similarity = cos(text_embed, test_image_embeddings).cpu().numpy()
trainer, cos(text_embed, test_image_embeddings), method="mean" predicted_similarity = cos(text_embed, predicted_image_embeddings).cpu().numpy()
unrelated_similarity = (
cos(text_embed, predicted_unrelated_embeddings).cpu().numpy()
) )
pred_sim = pad_gather_reduce( predicted_img_similarity = (
trainer, cos(text_embed, predicted_image_embeddings), method="mean" cos(test_image_embeddings, predicted_image_embeddings).cpu().numpy()
)
unrel_sim = pad_gather_reduce(
trainer, cos(text_embed, predicted_unrelated_embeddings), method="mean"
)
pred_img_sim = pad_gather_reduce(
trainer,
cos(test_image_embeddings, predicted_image_embeddings),
method="mean",
) )
stats = { stats = {
f"{tracker_folder}/baseline similarity [steps={timesteps}]": orig_sim, f"{tracker_context}/baseline similarity": np.mean(original_similarity),
f"{tracker_folder}/similarity with text [steps={timesteps}]": pred_sim, f"{tracker_context}/similarity with text": np.mean(predicted_similarity),
f"{tracker_folder}/similarity with original image [steps={timesteps}]": pred_img_sim, f"{tracker_context}/similarity with original image": np.mean(
f"{tracker_folder}/similarity with unrelated caption [steps={timesteps}]": unrel_sim, predicted_img_similarity
f"{tracker_folder}/difference from baseline similarity [steps={timesteps}]": pred_sim ),
- orig_sim, f"{tracker_context}/similarity with unrelated caption": np.mean(unrelated_similarity),
f"{tracker_context}/difference from baseline similarity": np.mean(
predicted_similarity - original_similarity
),
} }
tracker.log(stats, step=trainer.step.item() + 1) for k, v in stats.items():
trainer.print(f"{tracker_context}/{k}: {v}")
if exists(tracker):
def eval_model( tracker.log(stats, step=trainer.step.item() + 1)
trainer: DiffusionPriorTrainer,
dataloader: DataLoader,
text_conditioned: bool,
split: str,
tracker: Tracker,
use_ema: bool,
report_cosine: bool,
report_loss: bool,
timesteps: List[int],
loss_type: str = None,
):
"""
Run evaluation on a model and track metrics
returns: loss if requested
"""
trainer.eval()
use_ema = "ema" if use_ema else "online"
tracker_folder = f"metrics/{use_ema}-{split}"
# determine if valid timesteps are passed
min_timesteps = trainer.accelerator.unwrap_model(
trainer.diffusion_prior
).sample_timesteps
max_timesteps = trainer.accelerator.unwrap_model(
trainer.diffusion_prior
).noise_scheduler.num_timesteps
assert all_between(
timesteps, lower_bound=min_timesteps, upper_bound=max_timesteps
), f"all timesteps values must be between {min_timesteps} and {max_timesteps}: got {timesteps}"
# measure cosine metrics across various eta and timesteps
if report_cosine:
for timestep in timesteps:
report_cosine_sims(
trainer,
dataloader=dataloader,
text_conditioned=text_conditioned,
tracker=tracker,
split=split,
timesteps=timestep,
tracker_folder=tracker_folder,
)
# measure loss on a separate split of data
if report_loss:
loss = report_validation_loss(
trainer=trainer,
dataloader=dataloader,
text_conditioned=text_conditioned,
use_ema=use_ema,
tracker=tracker,
split=split,
tracker_folder=tracker_folder,
loss_type=loss_type,
)
return loss
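`report_cosine_sims` above scores the prior by cosine similarity: text against the real image embedding (the baseline), text against the prior's prediction, the real image against the prediction, and text against a prediction conditioned on captions rolled by one position (an "unrelated caption" control). A self-contained sketch of that bookkeeping, with random tensors standing in for CLIP embeddings and for the sampled predictions:

```python
import torch
from torch import nn

cos = nn.CosineSimilarity(dim=1)

# hypothetical stand-ins: in the script these come from CLIP and from p_sample_loop
text_embed = torch.randn(8, 512)
test_image_embed = torch.randn(8, 512)
predicted_image_embed = torch.randn(8, 512)       # prior conditioned on the matching text
predicted_unrelated_embed = torch.randn(8, 512)   # prior conditioned on text rolled by one position

stats = {
    "baseline similarity": cos(text_embed, test_image_embed).mean().item(),
    "similarity with text": cos(text_embed, predicted_image_embed).mean().item(),
    "similarity with original image": cos(test_image_embed, predicted_image_embed).mean().item(),
    "similarity with unrelated caption": cos(text_embed, predicted_unrelated_embed).mean().item(),
}
print(stats)
```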
# training script # training script
@@ -428,327 +219,182 @@ def eval_model(
def train( def train(
trainer: DiffusionPriorTrainer, trainer: DiffusionPriorTrainer,
tracker: Tracker,
train_loader: DataLoader, train_loader: DataLoader,
eval_loader: DataLoader, eval_loader: DataLoader,
test_loader: DataLoader, test_loader: DataLoader,
config: DiffusionPriorTrainConfig, config: DiffusionPriorTrainConfig,
): ):
# init timers # distributed tracking with wandb
save_timer = Timer() # when to save if trainer.accelerator.num_processes > 1:
samples_timer = Timer() # samples/sec os.environ["WANDB_START_METHOD"] = "thread"
validation_profiler = Timer() # how long is validation taking
validation_countdown = Timer() # when to perform evaluation
# keep track of best validation loss tracker = wandb.init(
name=f"RANK:{trainer.device}",
entity=config.tracker.wandb_entity,
project=config.tracker.wandb_project,
config=config.dict(),
group=GROUP,
)
best_validation_loss = config.train.best_validation_loss # sync after tracker init
samples_seen = config.train.num_samples_seen trainer.wait_for_everyone()
# init a timer
timer = Timer()
# do training # do training
for img, txt in train_loader:
trainer.train()
current_step = trainer.step.item() + 1
start_epoch = config.train.current_epoch # place data on device
img = img.to(trainer.device)
txt = txt.to(trainer.device)
for epoch in range(start_epoch, config.train.epochs): # pass to model
# if we finished out an old epoch, reset the distribution to be a full epoch loss = trainer(text=txt, image_embed=img)
tracker.log({"tracking/epoch": epoch}, step=trainer.step.item())
if train_loader.dataset.get_start() > 0 and epoch == start_epoch+1: # display & log loss (will only print from main process)
if trainer.accelerator.is_main_process: trainer.print(f"Step {current_step}: Loss {loss}")
click.secho(f"Finished resumed epoch...resetting dataloader.")
train_loader.dataset.set_start(0)
for img, txt in train_loader: # perform backprop & apply EMA updates
# setup things every step trainer.update()
trainer.train() # track samples/sec/rank
current_step = trainer.step.item() samples_per_sec = img.shape[0] / timer.elapsed()
samples_timer.reset()
# place data on device # samples seen
samples_seen = (
config.data.batch_size * trainer.accelerator.num_processes * current_step
)
img = img.to(trainer.device) # ema decay
txt = txt.to(trainer.device) ema_decay = trainer.ema_diffusion_prior.get_current_decay()
# pass to model # Log on all processes for debugging
tracker.log(
{
"tracking/samples-sec": samples_per_sec,
"tracking/samples-seen": samples_seen,
"tracking/ema-decay": ema_decay,
"metrics/training-loss": loss,
},
step=current_step,
)
loss = trainer(text=txt, image_embed=img) # Metric Tracking & Checkpointing (outside of timer's scope)
if current_step % EVAL_EVERY == 0:
# perform backprop & apply EMA updates eval_model(
trainer=trainer,
trainer.update() dataloader=eval_loader,
text_conditioned=config.prior.condition_on_text_encodings,
# gather info about training step loss_type=config.prior.loss_type,
tracker_context="metrics/online-model-validation",
all_loss = pad_gather_reduce(trainer, loss, method="mean") tracker=tracker,
num_samples = pad_gather_reduce(trainer, len(txt), method="sum") use_ema=False,
samples_per_sec = num_samples / samples_timer.elapsed()
samples_seen += num_samples
ema_decay = trainer.ema_diffusion_prior.get_current_decay()
# log
tracker.log(
{
"tracking/samples-sec": samples_per_sec,
"tracking/samples-seen": samples_seen,
"tracking/ema-decay": ema_decay,
f"tracking/training-{config.prior.loss_type}": all_loss,
},
step=current_step,
) )
# Metric Tracking @ Timed Intervals eval_model(
trainer=trainer,
eval_delta = pad_gather_reduce( dataloader=eval_loader,
trainer, validation_countdown.elapsed(), method="min" text_conditioned=config.prior.condition_on_text_encodings,
loss_type=config.prior.loss_type,
tracker_context="metrics/ema-model-validation",
tracker=tracker,
use_ema=True,
) )
if eval_delta != None and eval_delta > config.data.eval_every_seconds: report_cosine_sims(
# begin timing how long this takes trainer=trainer,
dataloader=eval_loader,
text_conditioned=config.prior.condition_on_text_encodings,
tracker=tracker,
tracker_context="metrics",
)
validation_profiler.reset() if current_step % config.train.save_every == 0:
trainer.save(f"{config.tracker.data_path}/chkpt_step_{current_step}.pth")
# package kwargs for evaluation # reset timer for next round
timer.reset()
eval_kwargs = {
"trainer": trainer,
"tracker": tracker,
"text_conditioned": config.prior.condition_on_text_encodings,
"timesteps": config.train.eval_timesteps,
}
# ONLINE MODEL : COSINE : LOSS : VALIDATION SPLIT
eval_model(
dataloader=eval_loader,
loss_type=config.prior.loss_type,
split="validation",
use_ema=False,
report_cosine=False,
report_loss=True,
**eval_kwargs,
)
# EMA MODEL : COSINE : LOSS : VALIDATION DATA
ema_val_loss = eval_model(
dataloader=eval_loader,
loss_type=config.prior.loss_type,
split="validation",
use_ema=True,
report_cosine=True,
report_loss=True,
**eval_kwargs,
)
tracker.log(
{
"tracking/validation length (minutes)": validation_profiler.elapsed()
/ 60
}
)
# check if the ema validation is the lowest seen yet
if ema_val_loss < best_validation_loss:
best_validation_loss = ema_val_loss
# go save the model as best
save_trainer(
trainer=trainer,
tracker=tracker,
is_best=True,
is_latest=False,
samples_seen=samples_seen,
epoch=epoch,
best_validation_loss=best_validation_loss,
)
# reset timer for validation
validation_countdown.reset()
elif eval_delta is None:
click.secho(
f"Error occured reading the eval time on rank: {trainer.device}",
fg="yellow",
)
# save as latest model on schedule
save_delta = pad_gather_reduce(trainer, save_timer.elapsed(), method="min")
if save_delta != None and save_delta >= config.train.save_every_seconds:
save_trainer(
trainer=trainer,
tracker=tracker,
is_best=False,
is_latest=True,
samples_seen=samples_seen,
epoch=epoch,
best_validation_loss=best_validation_loss,
)
save_timer.reset()
elif save_delta is None:
click.secho(
f"Error occured reading the save time on rank: {trainer.device}",
fg="yellow",
)
# evaluate on test data # evaluate on test data
if trainer.accelerator.is_main_process: eval_model(
click.secho(f"Starting Test", fg="red")
# save one last time as latest before beginning validation
save_trainer(
tracker=tracker,
trainer=trainer,
is_best=False,
is_latest=True,
samples_seen=samples_seen,
epoch=epoch,
best_validation_loss=best_validation_loss,
)
test_loss = eval_model(
trainer=trainer, trainer=trainer,
dataloader=test_loader, dataloader=test_loader,
text_conditioned=config.prior.condition_on_text_encodings, text_conditioned=config.prior.condition_on_text_encodings,
split="test",
tracker=tracker,
use_ema=True,
report_cosine=False,
report_loss=True,
timesteps=config.train.eval_timesteps,
loss_type=config.prior.loss_type, loss_type=config.prior.loss_type,
tracker_context="test",
tracker=tracker,
) )
if test_loss < best_validation_loss: report_cosine_sims(
best_validation_loss = test_loss trainer,
test_loader,
# go save the model as best config.prior.condition_on_text_encodings,
tracker,
save_trainer( tracker_context="test",
trainer=trainer, )
tracker=tracker,
is_best=True,
is_latest=False,
samples_seen=samples_seen,
epoch=epoch,
best_validation_loss=test_loss,
)
def initialize_training(config_file, accelerator): def initialize_training(config, accelerator=None):
""" """
Parse the configuration file, and prepare everything necessary for training Parse the configuration file, and prepare everything necessary for training
""" """
# load the configuration file
if accelerator.is_main_process:
click.secho(f"Loading configuration from {config_file}", fg="green")
config = TrainDiffusionPriorConfig.from_json_path(config_file)
# seed
set_seed(config.train.random_seed)
# get a device # get a device
device = accelerator.device if accelerator:
device = accelerator.device
click.secho(f"Accelerating on: {device}", fg="yellow")
else:
if torch.cuda.is_available():
click.secho("GPU detected, defaulting to cuda:0", fg="yellow")
device = "cuda:0"
else:
click.secho("No GPU detected...using cpu", fg="yellow")
device = "cpu"
# make the trainer (will automatically distribute if possible & configured) # make the trainer (will automatically distribute if possible & configured)
trainer: DiffusionPriorTrainer = make_model( trainer = make_model(config.prior, config.train, device, accelerator).to(device)
config.prior, config.train, device, accelerator
).to(device)
# create a tracker
tracker = create_tracker(
accelerator, config, config_file, dummy=accelerator.process_index != 0
)
# reload from checkpoint if config.load.resume == True:
if tracker.can_recall: if config.load.resume == True:
current_epoch, best_validation_loss, samples_seen = recall_trainer( click.secho(f"Loading checkpoint: {config.load.source}", fg="cyan")
tracker=tracker, trainer=trainer trainer.load(config.load.source)
)
# display best values
if trainer.accelerator.is_main_process:
click.secho(f"Current Epoch: {current_epoch} | Best Val Loss: {best_validation_loss} | Samples Seen: {samples_seen}", fg="yellow")
# update config to reflect recalled values
config.train.num_samples_seen = samples_seen
config.train.current_epoch = current_epoch
config.train.best_validation_loss = best_validation_loss
# fetch and prepare data # fetch and prepare data
if trainer.accelerator.is_main_process: if trainer.is_main_process():
click.secho("Grabbing data...", fg="blue", blink=True) click.secho("Grabbing data from source", fg="blue", blink=True)
trainer.accelerator.wait_for_everyone()
img_reader = get_reader( img_reader = get_reader(
text_conditioned=trainer.text_conditioned, text_conditioned=trainer.text_conditioned,
img_url=config.data.image_url, img_url=config.data.image_url,
meta_url=config.data.meta_url, meta_url=config.data.meta_url,
) )
# calculate start point within epoch
trainer.accelerator.wait_for_everyone()
train_loader, eval_loader, test_loader = make_splits( train_loader, eval_loader, test_loader = make_splits(
text_conditioned=trainer.text_conditioned, text_conditioned=trainer.text_conditioned,
batch_size=config.data.batch_size, batch_size=config.data.batch_size,
num_data_points=config.data.num_data_points, num_data_points=NUM_DATA_POINTS,
train_split=config.data.splits.train, train_split=config.data.splits.train,
eval_split=config.data.splits.val, eval_split=config.data.splits.val,
image_reader=img_reader, image_reader=img_reader,
rank=accelerator.state.process_index, rank=accelerator.state.process_index if exists(accelerator) else 0,
world_size=accelerator.state.num_processes, world_size=accelerator.state.num_processes if exists(accelerator) else 1,
start=0, start=START,
) )
# update the start point to finish out the epoch on a resumed run # wait for everyone to load data before continuing
trainer.wait_for_everyone()
if tracker.can_recall:
samples_seen = config.train.num_samples_seen
length = (
config.data.num_data_points
if samples_seen <= img_reader.count
else img_reader.count
)
scaled_samples = length * config.train.current_epoch
start_point = (
scaled_samples - samples_seen if scaled_samples > samples_seen else samples_seen
)
if trainer.accelerator.is_main_process:
click.secho(f"Resuming at sample: {start_point}", fg="yellow")
train_loader.dataset.set_start(start_point)
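The resume logic above reconstructs where in the dataset to restart from the recalled epoch and sample counters. Tracing the formula with hypothetical numbers:

```python
# hypothetical recalled values, purely to trace the arithmetic above
length = 1_000_000          # data points per epoch (config.data.num_data_points)
current_epoch = 3
samples_seen = 2_300_000

scaled_samples = length * current_epoch                                   # 3_000_000
start_point = scaled_samples - samples_seen if scaled_samples > samples_seen else samples_seen
print(start_point)          # 700000 -> handed to train_loader.dataset.set_start(...)
```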
# start training # start training
if trainer.accelerator.is_main_process:
click.secho(
f"Beginning Prior Training : Distributed={accelerator.state.distributed_type != accelerate_dataclasses.DistributedType.NO}",
fg="yellow",
)
train( train(
trainer=trainer, trainer=trainer,
tracker=tracker,
train_loader=train_loader, train_loader=train_loader,
eval_loader=eval_loader, eval_loader=eval_loader,
test_loader=test_loader, test_loader=test_loader,
@@ -757,13 +403,23 @@ def initialize_training(config_file, accelerator):
@click.command() @click.command()
@click.option("--config_file", default="configs/train_prior_config.example.json") @click.option("--hfa", default=True)
def main(config_file): @click.option("--config_path", default="configs/prior.json")
# start HFA def main(hfa, config_path):
accelerator = Accelerator() # start HFA if requested
if hfa:
accelerator = Accelerator()
else:
accelerator = None
# setup training # load the configuration file on main process
initialize_training(config_file, accelerator) if not exists(accelerator) or accelerator.is_main_process:
click.secho(f"Loading configuration from {config_path}", fg="green")
config = TrainDiffusionPriorConfig.from_json_path(config_path)
# send config to get processed
initialize_training(config, accelerator)
if __name__ == "__main__": if __name__ == "__main__":