Compare commits

...

94 Commits
0.1.9 ... 0.4.1

Author SHA1 Message Date
Phil Wang
c6629c431a make training splits into its own pydantic base model, validate it sums to 1, make decoder script cleaner 2022-05-22 14:43:22 -07:00
Phil Wang
7ac2fc79f2 add renamed train decoder json file 2022-05-22 14:32:50 -07:00
Phil Wang
a1ef023193 use pydantic to manage decoder training configs + defaults and refactor training script 2022-05-22 14:27:40 -07:00
Phil Wang
d49eca62fa dep 2022-05-21 11:27:52 -07:00
Phil Wang
8aab69b91e final thought 2022-05-21 10:47:45 -07:00
Phil Wang
b432df2f7b final cleanup to decoder script 2022-05-21 10:42:16 -07:00
Phil Wang
ebaa0d28c2 product management 2022-05-21 10:30:52 -07:00
Phil Wang
8b0d459b25 move config parsing logic to own file, consider whether to find an off-the-shelf solution at future date 2022-05-21 10:30:10 -07:00
Phil Wang
0064661729 small cleanup of decoder train script 2022-05-21 10:17:13 -07:00
Phil Wang
b895f52843 appreciation section 2022-05-21 08:32:12 -07:00
Phil Wang
80497e9839 accept unets as list for decoder 2022-05-20 20:31:26 -07:00
Phil Wang
f526f14d7c bump 2022-05-20 20:20:40 -07:00
Phil Wang
8997f178d6 small cleanup with timer 2022-05-20 20:05:01 -07:00
Aidan Dempster
022c94e443 Added single GPU training script for decoder (#108)
Added config files for training

Changed example image generation to be more efficient

Added configuration description to README

Removed unused import
2022-05-20 19:46:19 -07:00
Phil Wang
430961cb97 it was correct the first time, my bad 2022-05-20 18:05:15 -07:00
Phil Wang
721f9687c1 fix wandb logging in tracker, and do some cleanup 2022-05-20 17:27:43 -07:00
Aidan Dempster
e0524a6aff Implemented the wandb tracker (#106)
Added a base_path parameter to all trackers for storing any local information they need to
2022-05-20 16:39:23 -07:00
Aidan Dempster
c85e0d5c35 Update decoder dataloader (#105)
* Updated the decoder dataloader
Removed unnecessary logging for required packages
Transferred to using index width instead of shard width
Added the ability to select extra keys to return from the webdataset

* Added README for decoder loader
2022-05-20 16:38:55 -07:00
Phil Wang
db0642c4cd quick fix for @marunine 2022-05-18 20:22:52 -07:00
Phil Wang
bb86ab2404 update sample, and set default gradient clipping value for decoder training 2022-05-16 17:38:30 -07:00
Phil Wang
ae056dd67c samples 2022-05-16 13:46:35 -07:00
Phil Wang
033d6b0ce8 last update 2022-05-16 13:38:33 -07:00
Phil Wang
c7ea8748db default decoder learning rate to what was in the paper 2022-05-16 13:33:54 -07:00
Phil Wang
13382885d9 final update to dalle2 repository for a while - sampling from prior in chunks automatically with max_batch_size keyword given 2022-05-16 12:57:31 -07:00
Phil Wang
c3d4a7ffe4 update working unconditional decoder example 2022-05-16 12:50:07 -07:00
Phil Wang
164d9be444 use a decorator and take care of sampling in chunks (max_batch_size keyword), in case one is sampling a huge grid of images 2022-05-16 12:34:28 -07:00
Phil Wang
5562ec6be2 status updates 2022-05-16 12:01:54 -07:00
Phil Wang
89ff04cfe2 final tweak to EMA class 2022-05-16 11:54:34 -07:00
Phil Wang
f4016f6302 allow for overriding use of EMA during sampling in decoder trainer with use_non_ema keyword, also fix some issues with automatic normalization of images and low res conditioning image if latent diffusion is in play 2022-05-16 11:18:30 -07:00
Phil Wang
1212f7058d allow text encodings and text mask to be passed in on forward and sampling for Decoder class 2022-05-16 10:40:32 -07:00
Phil Wang
dab106d4e5 back to no_grad for now, also keep track and restore unet devices in one_unet_in_gpu contextmanager 2022-05-16 09:36:14 -07:00
Phil Wang
bb151ca6b1 unet_number on decoder trainer only needs to be passed in if there is greater than 1 unet, so that unconditional training of a single ddpm is seamless (experiment in progress locally) 2022-05-16 09:17:17 -07:00
zion
4a59dea4cf Migrate to text-conditioned prior training (#95)
* migrate to conditioned prior

* unify reader logic with a wrapper (#1)

* separate out reader logic

* support both training methods

* Update train prior to use embedding wrapper (#3)

* Support Both Methods

* bug fixes

* small bug fixes

* embedding only wrapper bug

* use smaller val perc

* final bug fix for embedding-only

Co-authored-by: nousr <>
2022-05-15 20:16:38 -07:00
Phil Wang
ecf9e8027d make sure classifier free guidance is used only if conditional dropout is present on the DiffusionPrior and Decoder classes. also make sure prior can have a different conditional scale than decoder 2022-05-15 19:09:38 -07:00
Phil Wang
36c5079bd7 LazyLinear is not mature, make users pass in text_embed_dim if text conditioning is turned on 2022-05-15 18:56:52 -07:00
Phil Wang
4a4c7ac9e6 cond drop prob for diffusion prior network should default to 0 2022-05-15 18:47:45 -07:00
Phil Wang
fad7481479 todo 2022-05-15 17:00:25 -07:00
Phil Wang
123658d082 cite Ho et al, since cascading ddpm is now trainable 2022-05-15 16:56:53 -07:00
Phil Wang
11d4e11f10 allow for training unconditional ddpm or cascading ddpms 2022-05-15 16:54:56 -07:00
Phil Wang
99778e12de trainer classes now takes care of auto-casting numpy to torch tensors, and setting correct device based on model parameter devices 2022-05-15 15:25:45 -07:00
Phil Wang
0f0011caf0 todo 2022-05-15 14:28:35 -07:00
Phil Wang
7b7a62044a use eval vs training mode to determine whether to call backprop on trainer forward 2022-05-15 14:20:59 -07:00
Phil Wang
156fe5ed9f final cleanup for the day 2022-05-15 12:38:41 -07:00
Phil Wang
5ec34bebe1 cleanup readme 2022-05-15 12:29:26 -07:00
Phil Wang
8eaacf1ac1 remove indirection 2022-05-15 12:05:45 -07:00
Phil Wang
e66c7b0249 incorrect naming 2022-05-15 11:23:52 -07:00
Phil Wang
f7cd4a0992 product management 2022-05-15 11:21:12 -07:00
Phil Wang
68e7d2f241 make sure gradient accumulation feature works even if all arguments passed in are keyword arguments 2022-05-15 11:16:16 -07:00
Phil Wang
74f222596a remove todo 2022-05-15 11:01:35 -07:00
Phil Wang
aa6772dcff make sure optimizer and scaler is reloaded on resume for training diffusion prior script, move argparse to click 2022-05-15 10:48:10 -07:00
Phil Wang
71d0c4edae cleanup to use diffusion prior trainer 2022-05-15 10:16:05 -07:00
Phil Wang
f7eee09d8b 0.2.30 2022-05-15 09:56:59 -07:00
Phil Wang
89de5af63e experiment tracker agnostic 2022-05-15 09:56:40 -07:00
Phil Wang
4ec6d0ba81 backwards pass is not recommended under the autocast context, per pytorch docs 2022-05-14 18:26:19 -07:00
Phil Wang
aee92dba4a simplify more 2022-05-14 17:16:46 -07:00
Phil Wang
b0cd5f24b6 take care of gradient accumulation automatically for researchers, by passing in a max_batch_size on the decoder or diffusion prior trainer forward 2022-05-14 17:04:09 -07:00
Phil Wang
b494ed81d4 take care of backwards within trainer classes for diffusion prior and decoder, readying to take care of gradient accumulation as well (plus, unsure if loss should be backwards within autocast block) 2022-05-14 15:49:24 -07:00
Phil Wang
ff3474f05c normalize conditioning tokens outside of cross attention blocks 2022-05-14 14:23:52 -07:00
Phil Wang
d5293f19f1 lineup with paper 2022-05-14 13:57:00 -07:00
Phil Wang
e697183849 be able to customize adam eps 2022-05-14 13:55:04 -07:00
Phil Wang
591d37e266 lower default initial learning rate to what Jonathan Ho had in his original repo 2022-05-14 13:22:43 -07:00
Phil Wang
d1f02e8f49 always use sandwich norm for attention layer 2022-05-14 12:13:41 -07:00
Phil Wang
9faab59b23 use post-attn-branch layernorm in attempt to stabilize cross attention conditioning in decoder 2022-05-14 11:58:09 -07:00
Phil Wang
5d27029e98 make sure lowres conditioning image is properly normalized to -1 to 1 for cascading ddpm 2022-05-14 01:23:54 -07:00
Phil Wang
3115fa17b3 fix everything around normalizing images to -1 to 1 for ddpm training automatically 2022-05-14 01:17:11 -07:00
Phil Wang
124d8577c8 move the inverse normalization function called before image embeddings are derived from clip to within the diffusion prior and decoder classes 2022-05-14 00:37:52 -07:00
Phil Wang
2db0c9794c comments 2022-05-12 14:25:20 -07:00
Phil Wang
2277b47ffd make sure learned variance can work for any number of unets in the decoder, defaults to first unet, as suggested was used in the paper 2022-05-12 14:18:15 -07:00
Phil Wang
28b58e568c cleanup in preparation of option for learned variance 2022-05-12 12:04:52 -07:00
Phil Wang
924455d97d align the ema model device back after sampling from the cascading ddpm in the decoder 2022-05-11 19:56:54 -07:00
Phil Wang
6021945fc8 default to l2 loss 2022-05-11 19:24:51 -07:00
Light-V
6f76652d11 fix typo in README.md (#85)
The default config for clip from openai should be ViT-B/32
2022-05-11 13:38:16 -07:00
Phil Wang
3dda2570ed fix amp issue for https://github.com/lucidrains/DALLE2-pytorch/issues/82 2022-05-11 08:21:39 -07:00
Phil Wang
2f3c02dba8 numerical accuracy for noise schedule parameters 2022-05-10 15:28:46 -07:00
Phil Wang
908088cfea wrap up cross embed layer feature 2022-05-10 12:19:34 -07:00
Phil Wang
8dc8a3de0d product management 2022-05-10 11:51:38 -07:00
Phil Wang
35f89556ba bring in the cross embed layer from Crossformer paper for initial convolution in unet 2022-05-10 11:50:38 -07:00
Phil Wang
2b55f753b9 fix new issue with github actions and auto pypi package uploading 2022-05-10 10:51:15 -07:00
Phil Wang
fc8fce38fb make sure cascading DDPM can be trained unconditionally, to ready for CLI one command training for the public 2022-05-10 10:48:10 -07:00
Phil Wang
a1bfb03ba4 project management 2022-05-10 10:13:51 -07:00
Phil Wang
b1e7b5f6bb make sure resnet groups in unet is finely customizable 2022-05-10 10:12:50 -07:00
z
10b905b445 smol typo (#81) 2022-05-10 09:52:50 -07:00
Phil Wang
9b322ea634 patch 2022-05-09 19:46:19 -07:00
Phil Wang
ba64ea45cc 0.2.3 2022-05-09 16:50:31 -07:00
Phil Wang
64f7be1926 some cleanup 2022-05-09 16:50:21 -07:00
Phil Wang
db805e73e1 fix a bug with numerical stability in attention, sorry! 🐛 2022-05-09 16:23:37 -07:00
z
cb07b37970 Ensure Eval Mode In Metric Functions (#79)
* add eval/train toggles

* train/eval flags

* shift train toggle

Co-authored-by: nousr <z@localhost.com>
2022-05-09 16:05:40 -07:00
Phil Wang
a774bfefe2 add attention and feedforward dropouts to train_diffusion_prior script 2022-05-09 13:57:15 -07:00
Phil Wang
2ae57f0cf5 cleanup 2022-05-09 13:51:26 -07:00
Phil Wang
e46eaec817 deal the diffusion prior problem yet another blow 2022-05-09 11:08:52 -07:00
Kumar R
8647cb5e76 Val loss changes, with quite a few other changes. This is in place of the earlier PR(https://github.com/lucidrains/DALLE2-pytorch/pull/67) (#77)
* Val_loss changes - no rebased with lucidrains' master.

* Val Loss changes - now rebased with lucidrains' master

* train_diffusion_prior.py updates

* dalle2_pytorch.py updates

* __init__.py changes

* Update train_diffusion_prior.py

* Update dalle2_pytorch.py

* Update train_diffusion_prior.py

* Update train_diffusion_prior.py

* Update dalle2_pytorch.py

* Update train_diffusion_prior.py

* Update train_diffusion_prior.py

* Update train_diffusion_prior.py

* Update train_diffusion_prior.py

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md
2022-05-09 08:53:29 -07:00
Phil Wang
53c189e46a give more surface area for attention in diffusion prior 2022-05-09 08:08:11 -07:00
Phil Wang
dde51fd362 revert restriction for classifier free guidance for diffusion prior, given @crowsonkb advice 2022-05-07 20:55:41 -07:00
Nasir Khalid
2eac7996fa Additional image_embed metric (#75)
Added metric to track image_embed vs predicted_image_embed
2022-05-07 14:32:33 -07:00
22 changed files with 2604 additions and 685 deletions

9
.gitignore vendored
View File

@@ -1,3 +1,12 @@
# default experiment tracker data
.tracker-data/
# Configuration Files
configs/*
!configs/*.example
!configs/*_defaults.py
!configs/README.md
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

192
README.md
View File

@@ -14,6 +14,16 @@ Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord
There was enough interest for a <a href="https://github.com/lucidrains/dalle2-jax">Jax version</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
## Status
- A research group has used the code in this repository to train a functional diffusion prior for their CLIP generations. Will share their work once they release their preprint. This, and <a href="https://github.com/crowsonkb">Katherine's</a> own experiments, validate OpenAI's finding that the extra prior increases variety of generations.
- Decoder is now verified working for unconditional generation on my experimental setup for Oxford flowers. 2 researchers have also confirmed Decoder is working for them.
<img src="./samples/oxford.png" width="600px" />
*ongoing at 21k steps*
## Install
```bash
@@ -508,7 +518,7 @@ To use a pretrained OpenAI CLIP, simply import `OpenAIClipAdapter` and pass it i
import torch
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, OpenAIClipAdapter
# openai pretrained clip - defaults to ViT/B-32
# openai pretrained clip - defaults to ViT-B/32
clip = OpenAIClipAdapter()
@@ -706,7 +716,7 @@ mock_image_embed = torch.randn(1, 512).cuda()
images = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)
```
## Training wrapper (wip)
## Training wrapper
### Decoder Training
@@ -732,8 +742,8 @@ clip = CLIP(
# mock data
text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
text = torch.randint(0, 49408, (32, 256)).cuda()
images = torch.randn(32, 3, 256, 256).cuda()
# decoder (with unet)
@@ -774,8 +784,12 @@ decoder_trainer = DecoderTrainer(
)
for unet_number in (1, 2):
loss = decoder_trainer(images, text = text, unet_number = unet_number) # use the decoder_trainer forward
loss.backward()
loss = decoder_trainer(
images,
text = text,
unet_number = unet_number, # which unet to train on
max_batch_size = 4 # gradient accumulation - this sets the maximum batch size in which to do forward and backwards pass - for this example 32 / 4 == 8 times
)
decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
@@ -810,8 +824,8 @@ clip = CLIP(
# mock data
text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
text = torch.randint(0, 49408, (512, 256)).cuda()
images = torch.randn(512, 3, 256, 256).cuda()
# prior networks (with transformer)
@@ -838,16 +852,70 @@ diffusion_prior_trainer = DiffusionPriorTrainer(
ema_update_every = 10,
)
loss = diffusion_prior_trainer(text, images)
loss.backward()
loss = diffusion_prior_trainer(text, images, max_batch_size = 4)
diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior
# after much of the above three lines in a loop
# you can sample from the exponential moving average of the diffusion prior identically to how you do so for DiffusionPrior
image_embeds = diffusion_prior_trainer.sample(text) # (4, 512) - exponential moving averaged image embeddings
image_embeds = diffusion_prior_trainer.sample(text, max_batch_size = 4) # (512, 512) - exponential moving averaged image embeddings
```
## Bonus
### Unconditional Training
The repository also contains the means to train unconditional DDPM model, or even cascading DDPMs. You simply have to set `unconditional = True` in the `Decoder`
ex.
```python
import torch
from dalle2_pytorch import Unet, Decoder, DecoderTrainer
# unet for the cascading ddpm
unet1 = Unet(
dim = 128,
dim_mults=(1, 2, 4, 8)
).cuda()
unet2 = Unet(
dim = 32,
dim_mults = (1, 2, 4, 8, 16)
).cuda()
# decoder, which contains the unets
decoder = Decoder(
unet = (unet1, unet2),
image_sizes = (256, 512), # first unet up to 256px, then second to 512px
timesteps = 1000,
unconditional = True
).cuda()
# decoder trainer
decoder_trainer = DecoderTrainer(decoder)
# images (get a lot of this)
images = torch.randn(1, 3, 512, 512).cuda()
# feed images into decoder
for i in (1, 2):
loss = decoder_trainer(images, unet_number = i)
decoder_trainer.update(unet_number = i)
# do the above for many many many many images
# then it will learn to generate images
images = decoder_trainer.sample(batch_size = 36, max_batch_size = 4) # (36, 3, 512, 512)
```
## Dataloaders
### Decoder Dataloaders
In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.
@@ -892,14 +960,14 @@ dataset = ImageEmbeddingDataset(
)
```
## Scripts
### Scripts (wip)
### Using the `train_diffusion_prior.py` script
#### `train_diffusion_prior.py`
This script allows training the DiffusionPrior on pre-computed text and image embeddings. The working example below elucidates this process.
Please note that the script internally passes text_embed and image_embed to the DiffusionPrior, unlike the example below.
### Usage
#### Usage
```bash
$ python train_diffusion_prior.py
@@ -907,27 +975,50 @@ $ python train_diffusion_prior.py
The most significant parameters for the script are as follows:
--image-embed-url, default = "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
- `image-embed-url`, default = `"https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/"`
--text-embed-url, default = "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
- `text-embed-url`, default = `"https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/"`
--image-embed-dim, default=768 - 768 corresponds to the ViT iL/14 embedding size,change it to what your chosen ViT generates
- `image-embed-dim`, default = `768` - 768 corresponds to the ViT iL/14 embedding size,change it to what your chosen ViT generates
--learning-rate, default=1.1e-4
- `learning-rate`, default = `1.1e-4`
--weight-decay, default=6.02e-2
- `weight-decay`, default = `6.02e-2`
--max-grad-norm, default=0.5
- `max-grad-norm`, default = `0.5`
--batch-size, default=10 ** 4
- `batch-size`, default = `10 ** 4`
--num-epochs, default=5
- `num-epochs`, default = `5`
--clip, default=None # Signals the prior to use pre-computed embeddings
- `clip`, default = `None` # Signals the prior to use pre-computed embeddings
### Sample wandb run log
#### Loading and Saving the DiffusionPrior model
Please find a sample wandb run log at : https://wandb.ai/laion/diffusion-prior/runs/aul0rhv5?workspace=
Two methods are provided, load_diffusion_model and save_diffusion_model, the names being self-explanatory.
```python
from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model
```
##### Loading
load_diffusion_model(dprior_path, device)
dprior_path : path to saved model(.pth)
device : the cuda device you're running on
##### Saving
save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim)
save_path : path to save at
model : object of Diffusion_Prior
optimizer : optimizer object - see train_diffusion_prior.py for how to create one.
e.g: optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
scaler : a GradScaler object.
e.g: scaler = GradScaler(enabled=amp)
config : config object created in train_diffusion_prior.py - see file for example.
image_embed_dim - the dimension of the image_embedding
e.g: 768
## CLI (wip)
@@ -943,6 +1034,18 @@ Once built, images will be saved to the same directory the command is invoked
<a href="https://github.com/lucidrains/stylegan2-pytorch">template</a>
## Appreciation
This library would not have gotten to this working state without the help of
- <a href="https://github.com/nousr">Zion</a> and <a href="https://github.com/krish240574">Kumar</a> for the diffusion training script
- <a href="https://github.com/Veldrovive">Aidan</a> for the decoder training script and dataloaders
- <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management
- <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying of critical bugs
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice
... and many others. Thank you! 🙏
## Todo
- [x] finish off gaussian diffusion class for latent embedding - allow for prediction of epsilon
@@ -966,23 +1069,30 @@ Once built, images will be saved to the same directory the command is invoked
- [x] add convnext backbone for vqgan-vae (in addition to vit [vit-vqgan] + resnet)
- [x] make sure DDPMs can be run with traditional resnet blocks (but leave convnext as an option for experimentation)
- [x] make sure for the latter unets in the cascade, one can train on crops for learning super resolution (constrain the unet to be only convolutions in that case, or allow conv-like attention with rel pos bias)
- [x] offer setting in diffusion prior to split time and image embeddings into multiple tokens, configurable, for more surface area during attention
- [x] make sure resnet hyperparameters can be configurable across unet depth (groups and expansion factor)
- [x] pull logic for training diffusion prior into a class DiffusionPriorTrainer, for eventual script based + CLI based training
- [x] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images
- [x] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
- [x] cross embed layers for downsampling, as an option
- [x] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
- [x] use pydantic for config drive training
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
- [ ] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
- [ ] pull logic for training diffusion prior into a class DiffusionPriorTrainer, for eventual script based + CLI based training
- [ ] train on a toy task, offer in colab
- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
- [ ] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
- [ ] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
- [ ] make sure resnet hyperparameters can be configurable across unet depth (groups and expansion factor)
- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
- [ ] offer setting in diffusion prior to split time and image embeddings into multiple tokens, configurable, for more surface area during attention
- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
- [ ] decoder needs one day worth of refactor for tech debt
- [ ] allow for unet to be able to condition non-cross attention style as well
- [ ] for all model classes with hyperparameters that changes the network architecture, make it requirement that they must expose a config property, and write a simple function that asserts that it restores the object correctly
- [ ] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
- [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89
## Citations
@@ -1060,4 +1170,24 @@ Once built, images will be saved to the same directory the command is invoked
}
```
```bibtex
@misc{wang2021crossformer,
title = {CrossFormer: A Versatile Vision Transformer Hinging on Cross-scale Attention},
author = {Wenxiao Wang and Lu Yao and Long Chen and Binbin Lin and Deng Cai and Xiaofei He and Wei Liu},
year = {2021},
eprint = {2108.00154},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@article{ho2021cascaded,
title = {Cascaded Diffusion Models for High Fidelity Image Generation},
author = {Ho, Jonathan and Saharia, Chitwan and Chan, William and Fleet, David J and Norouzi, Mohammad and Salimans, Tim},
journal = {arXiv preprint arXiv:2106.15282},
year = {2021}
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>

109
configs/README.md Normal file
View File

@@ -0,0 +1,109 @@
## DALLE2 Training Configurations
For more complex configuration, we provide the option of using a configuration file instead of command line arguments.
### Decoder Trainer
The decoder trainer has 7 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.example.json).
**<ins>Unets</ins>:**
Each member of this array defines a single unet that will be added to the decoder.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `dim` | Yes | N/A | The starting channels of the unet. |
| `image_embed_dim` | Yes | N/A | The dimension of the image embeddings. |
| `dim_mults` | No | `(1, 2, 4, 8)` | The growth factors of the channels. |
Any parameter from the `Unet` constructor can also be given here.
**<ins>Decoder</ins>:**
Defines the configuration options for the decoder model. The unets defined above will automatically be inserted.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `image_sizes` | Yes | N/A | The resolution of the image after each upsampling step. The length of this array should be the number of unets defined. |
| `image_size` | Yes | N/A | Not used. Can be any number. |
| `timesteps` | No | `1000` | The number of diffusion timesteps used for generation. |
| `loss_type` | No | `l2` | The loss function. Options are `l1`, `huber`, or `l2`. |
| `beta_schedule` | No | `cosine` | The noising schedule. Options are `cosine`, `linear`, `quadratic`, `jsd`, or `sigmoid`. |
| `learned_variance` | No | `True` | Whether to learn the variance. |
Any parameter from the `Decoder` constructor can also be given here.
**<ins>Data</ins>:**
Settings for creation of the dataloaders.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `webdataset_base_url` | Yes | N/A | The url of a shard in the webdataset with the shard replaced with `{}`[^1]. |
| `embeddings_url` | No | N/A | The url of the folder containing embeddings shards. Not required if embeddings are in webdataset. |
| `num_workers` | No | `4` | The number of workers used in the dataloader. |
| `batch_size` | No | `64` | The batch size. |
| `start_shard` | No | `0` | Defines the start of the shard range the dataset will recall. |
| `end_shard` | No | `9999999` | Defines the end of the shard range the dataset will recall. |
| `shard_width` | No | `6` | Defines the width of one webdataset shard number[^2]. |
| `index_width` | No | `4` | Defines the width of the index of a file inside a shard[^3]. |
| `splits` | No | `{ "train": 0.75, "val": 0.15, "test": 0.1 }` | Defines the proportion of shards that will be allocated to the training, validation, and testing datasets. |
| `shuffle_train` | No | `True` | Whether to shuffle the shards of the training dataset. |
| `resample_train` | No | `False` | If true, shards will be randomly sampled with replacement from the datasets making the epoch length infinite if a limit is not set. Cannot be enabled if `shuffle_train` is enabled. |
| `preprocessing` | No | `{ "ToTensor": True }` | Defines preprocessing applied to images from the datasets. |
[^1]: If your shard files have the paths `protocol://path/to/shard/00104.tar`, then the base url would be `protocol://path/to/shard/{}.tar`. If you are using a protocol like `s3`, you need to pipe the tars. For example `pipe:s3cmd get s3://bucket/path/{}.tar -`.
[^2]: This refers to the string length of the shard number for your webdataset shards. For instance, if your webdataset shard has the filename `00104.tar`, your shard length is 5.
[^3]: Inside the webdataset `tar`, you have files named something like `001045945.jpg`. 5 of these characters refer to the shard, and 4 refer to the index of the file in the webdataset (shard is `001041` and index is `5945`). The `index_width` in this case is 4.
**<ins>Train</ins>:**
Settings for controlling the training hyperparameters.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `epochs` | No | `20` | The number of epochs in the training run. |
| `lr` | No | `1e-4` | The learning rate. |
| `wd` | No | `0.01` | The weight decay. |
| `max_grad_norm`| No | `0.5` | The grad norm clipping. |
| `save_every_n_samples` | No | `100000` | Samples will be generated and a checkpoint will be saved every `save_every_n_samples` samples. |
| `device` | No | `cuda:0` | The device to train on. |
| `epoch_samples` | No | `None` | Limits the number of samples iterated through in each epoch. This must be set if resampling. None means no limit. |
| `validation_samples` | No | `None` | The number of samples to use for validation. None mean the entire validation set. |
| `use_ema` | No | `True` | Whether to use exponential moving average models for sampling. |
| `ema_beta` | No | `0.99` | The ema coefficient. |
| `save_all` | No | `False` | If True, preserves a checkpoint for every epoch. |
| `save_latest` | No | `True` | If True, overwrites the `latest.pth` every time the model is saved. |
| `save_best` | No | `True` | If True, overwrites the `best.pth` every time the model has a lower validation loss than all previous models. |
| `unet_training_mask` | No | `None` | A boolean array of the same length as the number of unets. If false, the unet is frozen. A value of `None` trains all unets. |
**<ins>Evaluate</ins>:**
Defines which evaluation metrics will be used to test the model.
Each metric can be enabled by setting its configuration. The configuration keys for each metric are defined by the torchmetrics constructors which will be linked.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `n_evalation_samples` | No | `1000` | The number of samples to generate to test the model. |
| `FID` | No | `None` | Setting to an object enables the [Frechet Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/frechet_inception_distance.html) metric.
| `IS` | No | `None` | Setting to an object enables the [Inception Score](https://torchmetrics.readthedocs.io/en/stable/image/inception_score.html) metric.
| `KID` | No | `None` | Setting to an object enables the [Kernel Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/kernel_inception_distance.html) metric. |
| `LPIPS` | No | `None` | Setting to an object enables the [Learned Perceptual Image Patch Similarity](https://torchmetrics.readthedocs.io/en/stable/image/learned_perceptual_image_patch_similarity.html) metric. |
**<ins>Tracker</ins>:**
Selects which tracker to use and configures it.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `tracker_type` | No | `console` | Which tracker to use. Currently accepts `console` or `wandb`. |
| `data_path` | No | `./models` | Where the tracker will store local data. |
| `verbose` | No | `False` | Enables console logging for non-console trackers. |
Other configuration options are required for the specific trackers. To see which are required, reference the initializer parameters of each [tracker](../dalle2_pytorch/trackers.py).
**<ins>Load</ins>:**
Selects where to load a pretrained model from.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `source` | No | `None` | Supports `file` or `wandb`. |
| `resume` | No | `False` | If the tracker support resuming the run, resume it. |
Other configuration options are required for loading from a specific source. To see which are required, reference the load methods at the top of the [tracker file](../dalle2_pytorch/trackers.py).

View File

@@ -0,0 +1,99 @@
{
"unets": [
{
"dim": 128,
"image_embed_dim": 768,
"cond_dim": 64,
"channels": 3,
"dim_mults": [1, 2, 4, 8],
"attn_dim_head": 32,
"attn_heads": 16
}
],
"decoder": {
"image_sizes": [64],
"channels": 3,
"timesteps": 1000,
"loss_type": "l2",
"beta_schedule": "cosine",
"learned_variance": true
},
"data": {
"webdataset_base_url": "pipe:s3cmd get s3://bucket/path/{}.tar -",
"embeddings_url": "s3://bucket/embeddings/path/",
"num_workers": 4,
"batch_size": 64,
"start_shard": 0,
"end_shard": 9999999,
"shard_width": 6,
"index_width": 4,
"splits": {
"train": 0.75,
"val": 0.15,
"test": 0.1
},
"shuffle_train": true,
"resample_train": false,
"preprocessing": {
"RandomResizedCrop": {
"size": [128, 128],
"scale": [0.75, 1.0],
"ratio": [1.0, 1.0]
},
"ToTensor": true
}
},
"train": {
"epochs": 20,
"lr": 1e-4,
"wd": 0.01,
"max_grad_norm": 0.5,
"save_every_n_samples": 100000,
"n_sample_images": 6,
"device": "cuda:0",
"epoch_samples": null,
"validation_samples": null,
"use_ema": true,
"ema_beta": 0.99,
"amp": false,
"save_all": false,
"save_latest": true,
"save_best": true,
"unet_training_mask": [true]
},
"evaluate": {
"n_evaluation_samples": 1000,
"FID": {
"feature": 64
},
"IS": {
"feature": 64,
"splits": 10
},
"KID": {
"feature": 64,
"subset_size": 10
},
"LPIPS": {
"net_type": "vgg",
"reduction": "mean"
}
},
"tracker": {
"tracker_type": "console",
"data_path": "./models",
"wandb_entity": "",
"wandb_project": "",
"verbose": false
},
"load": {
"source": null,
"run_path": "",
"file_path": "",
"resume": false
}
}

View File

@@ -1,6 +1,6 @@
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
from dalle2_pytorch.train import DecoderTrainer, DiffusionPriorTrainer
from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
from dalle2_pytorch.vqgan_vae import VQGanVAE
from x_clip import CLIP

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,41 @@
## Dataloaders
In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.
### Decoder: Image Embedding Dataset
When training the decoder (and up samplers if training together) in isolation, you will need to load images and corresponding image embeddings. This dataset can read two similar types of datasets. First, it can read a [webdataset](https://github.com/webdataset/webdataset) that contains `.jpg` and `.npy` files in the `.tar`s that contain the images and associated image embeddings respectively. Alternatively, you can also specify a source for the embeddings outside of the webdataset. In this case, the path to the embeddings should contain `.npy` files with the same shard numbers as the webdataset and there should be a correspondence between the filename of the `.jpg` and the index of the embedding in the `.npy`. So, for example, `0001.tar` from the webdataset with image `00010509.jpg` (the first 4 digits are the shard number and the last 4 are the index) in it should be paralleled by a `img_emb_0001.npy` which contains a NumPy array with the embedding at index 509.
Generating a dataset of this type:
1. Use [img2dataset](https://github.com/rom1504/img2dataset) to generate a webdataset.
2. Use [clip-retrieval](https://github.com/rom1504/clip-retrieval) to convert the images to embeddings.
3. Use [embedding-dataset-reordering](https://github.com/Veldrovive/embedding-dataset-reordering) to reorder the embeddings into the expected format.
Usage:
```python
from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embedding_dataloader
# Create a dataloader directly.
dataloader = create_image_embedding_dataloader(
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses braket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
num_workers=4,
batch_size=32,
shard_width=4, # If a file in the webdataset shard 3 is named 0003039.jpg, we know the shard width is 4 and the last three digits are the index
shuffle_num=200, # Does a shuffle of the data with a buffer size of 200
shuffle_shards=True, # Shuffle the order the shards are read in
resample_shards=False, # Sample shards with replacement. If true, an epoch will be infinite unless stopped manually
)
for img, emb in dataloader:
print(img.shape) # torch.Size([32, 3, 256, 256])
print(emb.shape) # torch.Size([32, 512])
# Train decoder only as shown above
# Or create a dataset without a loader so you can configure it manually
dataset = ImageEmbeddingDataset(
urls="/path/or/url/to/webdataset/{0000..9999}.tar",
embedding_folder_url="path/or/url/to/embeddings/folder",
shard_width=4,
shuffle_shards=True,
resample=False
)
```

View File

@@ -1 +1,2 @@
from dalle2_pytorch.dataloaders.decoder_loader import ImageEmbeddingDataset, create_image_embedding_dataloader
from dalle2_pytorch.dataloaders.decoder_loader import ImageEmbeddingDataset, create_image_embedding_dataloader
from dalle2_pytorch.dataloaders.embedding_wrapper import make_splits

View File

@@ -3,6 +3,7 @@ import webdataset as wds
import torch
import numpy as np
import fsspec
import shutil
def get_shard(filename):
"""
@@ -20,7 +21,7 @@ def get_example_file(fs, path, file_format):
"""
return fs.glob(os.path.join(path, f"*.{file_format}"))[0]
def embedding_inserter(samples, embeddings_url, shard_width, handler=wds.handlers.reraise_exception):
def embedding_inserter(samples, embeddings_url, index_width, handler=wds.handlers.reraise_exception):
"""Given a datum of {"__key__": str, "__url__": str, ...} adds the cooresponding embedding and yields"""
previous_tar_url = None
current_embeddings = None
@@ -50,8 +51,12 @@ def embedding_inserter(samples, embeddings_url, shard_width, handler=wds.handler
previous_tar_url = tar_url
current_embeddings = load_corresponding_embeds(tar_url)
embedding_index = int(key[shard_width:])
sample["npy"] = current_embeddings[embedding_index]
embedding_index = int(key[-index_width:])
embedding = current_embeddings[embedding_index]
# We need to check if this sample is nonzero. If it is, this embedding is not valid and we should continue to the next loop
if torch.count_nonzero(embedding) == 0:
raise RuntimeError(f"Webdataset had a sample, but no embedding was found. ImgShard: {key[:-index_width]} - Index: {key[-index_width:]}")
sample["npy"] = embedding
yield sample
except Exception as exn: # From wds implementation
if handler(exn):
@@ -60,6 +65,28 @@ def embedding_inserter(samples, embeddings_url, shard_width, handler=wds.handler
break
insert_embedding = wds.filters.pipelinefilter(embedding_inserter)
def unassociated_shard_skipper(tarfiles, embeddings_url, handler=wds.handlers.reraise_exception):
"""Finds if the is a corresponding embedding for the tarfile at { url: [URL] }"""
embeddings_fs, embeddings_path = fsspec.core.url_to_fs(embeddings_url)
embedding_files = embeddings_fs.ls(embeddings_path)
get_embedding_shard = lambda embedding_file: int(embedding_file.split("_")[-1].split(".")[0])
embedding_shards = set([get_embedding_shard(filename) for filename in embedding_files]) # Sets have O(1) check for member
get_tar_shard = lambda tar_file: int(tar_file.split("/")[-1].split(".")[0])
for tarfile in tarfiles:
try:
webdataset_shard = get_tar_shard(tarfile["url"])
# If this shard has an associated embeddings file, we pass it through. Otherwise we iterate until we do have one
if webdataset_shard in embedding_shards:
yield tarfile
except Exception as exn: # From wds implementation
if handler(exn):
continue
else:
break
skip_unassociated_shards = wds.filters.pipelinefilter(unassociated_shard_skipper)
def verify_keys(samples, handler=wds.handlers.reraise_exception):
"""
Requires that both the image and embedding are present in the sample
@@ -86,7 +113,9 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
self,
urls,
embedding_folder_url=None,
shard_width=None,
index_width=None,
img_preproc=None,
extra_keys=[],
handler=wds.handlers.reraise_exception,
resample=False,
shuffle_shards=True
@@ -97,13 +126,31 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
:param urls: A url pointing to the tar files of the webdataset formatted as /path/to/webdataset/{0000..9999}.tar
:param embedding_folder_url: Required if webdataset does not contain embeddings. A url pointing to the npy files of the embeddings. Should have the same number of shards as the webdataset.
Webdataset image keys should align with the index of the embedding. This means missing image indices must have a corresponding embedding of all zeros.
:param shard_width: The number of digits in the shard number. This is used to align the embedding index with the image index.
For example, if a file in the webdataset shard 3 is named 0003039.jpg, we know the shard with this 4 and the last three digits are the index.
:param index_width: The number of digits in the index. This is used to align the embedding index with the image index.
For example, if a file in the webdataset shard 3 is named 0003039.jpg, we know the shard is 4 digits and the last 3 digits are the index_width.
:param img_preproc: This function is run on the img before it is batched and returned. Useful for data augmentation or converting to torch tensor.
:param handler: A webdataset handler.
:param resample: If true, resample webdataset shards with replacement. You need to set your own epoch size if this is true since it will resample infinitely.
:param shuffle_shards: If true, shuffle the shards before resampling. This cannot be true if resample is true.
"""
super().__init__()
keys = ["jpg", "npy"] + extra_keys
self.key_map = {key: i for i, key in enumerate(keys)}
self.resampling = resample
self.img_preproc = img_preproc
# If s3, check if s3fs is installed and s3cmd is installed and check if the data is piped instead of straight up
if (isinstance(urls, str) and "s3:" in urls) or (isinstance(urls, list) and any(["s3:" in url for url in urls])):
# Then this has an s3 link for the webdataset and we need extra packages
if shutil.which("s3cmd") is None:
raise RuntimeError("s3cmd is required for s3 webdataset")
if "s3:" in embedding_folder_url:
# Then the embeddings are being loaded from s3 and fsspec requires s3fs
try:
import s3fs
except ImportError:
raise RuntimeError("s3fs is required to load embeddings from s3")
# Add the shardList and randomize or resample if requested
if resample:
assert not shuffle_shards, "Cannot both resample and shuffle"
@@ -112,28 +159,43 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
self.append(wds.SimpleShardList(urls))
if shuffle_shards:
self.append(wds.filters.shuffle(1000))
if embedding_folder_url is not None:
# There may be webdataset shards that do not have a embedding shard associated with it. If we do not skip these, they would cause issues.
self.append(skip_unassociated_shards(embeddings_url=embedding_folder_url, handler=handler))
self.append(wds.split_by_node)
self.append(wds.split_by_worker)
self.append(wds.tarfile_to_samples(handler=handler))
self.append(wds.decode("torchrgb"))
self.append(wds.decode("pilrgb", handler=handler))
if embedding_folder_url is not None:
assert shard_width is not None, "Reading embeddings separately requires shard length to be given"
self.append(insert_embedding(embeddings_url=embedding_folder_url, shard_width=shard_width, handler=handler))
# Then we are loading embeddings for a remote source
assert index_width is not None, "Reading embeddings separately requires index width length to be given"
self.append(insert_embedding(embeddings_url=embedding_folder_url, index_width=index_width, handler=handler))
self.append(verify_keys)
self.append(wds.to_tuple("jpg", "npy"))
# Apply preprocessing
self.append(wds.map(self.preproc))
self.append(wds.to_tuple(*keys))
def preproc(self, sample):
"""Applies the preprocessing for images"""
if self.img_preproc is not None:
sample["jpg"] = self.img_preproc(sample["jpg"])
return sample
def create_image_embedding_dataloader(
tar_url,
num_workers,
batch_size,
embeddings_url=None,
shard_width=None,
index_width=None,
shuffle_num = None,
shuffle_shards = True,
resample_shards = False,
handler=wds.handlers.warn_and_continue
img_preproc=None,
extra_keys=[],
handler=wds.handlers.reraise_exception#warn_and_continue
):
"""
Convenience function to create an image embedding dataseta and dataloader in one line
@@ -143,8 +205,8 @@ def create_image_embedding_dataloader(
:param batch_size: The batch size to use for the dataloader
:param embeddings_url: Required if webdataset does not contain embeddings. A url pointing to the npy files of the embeddings. Should have the same number of shards as the webdataset.
Webdataset image keys should align with the index of the embedding. This means missing image indices must have a corresponding embedding of all zeros.
:param shard_width: The number of digits in the shard number. This is used to align the embedding index with the image index.
For example, if a file in the webdataset shard 3 is named 0003039.jpg, we know the shard width is 4 and the last three digits are the index.
:param index_width: The number of digits in the index. This is used to align the embedding index with the image index.
For example, if a file in the webdataset shard 3 is named 0003039.jpg, we know the shard is 4 digits and the last 3 digits are the index_width.
:param shuffle_num: If not None, shuffle the dataset with this size buffer after sampling.
:param shuffle_shards: If true, shuffle the shards before sampling. This cannot be true if resample is true.
:param resample_shards: If true, resample webdataset shards with replacement. You need to set your own epoch size if this is true since it will resample infinitely.
@@ -153,9 +215,11 @@ def create_image_embedding_dataloader(
ds = ImageEmbeddingDataset(
tar_url,
embeddings_url,
shard_width=shard_width,
index_width=index_width,
shuffle_shards=shuffle_shards,
resample=resample_shards,
extra_keys=extra_keys,
img_preproc=img_preproc,
handler=handler
)
if shuffle_num is not None and shuffle_num > 0:

View File

@@ -0,0 +1,180 @@
from torch.utils.data import IterableDataset
from torch import from_numpy
from clip import tokenize
from embedding_reader import EmbeddingReader
class PriorEmbeddingLoader(IterableDataset):
def __init__(
self,
text_conditioned: bool,
batch_size: int,
start: int,
stop: int,
image_reader,
text_reader: EmbeddingReader = None,
device: str = "cpu",
) -> None:
super(PriorEmbeddingLoader).__init__()
self.text_conditioned = text_conditioned
if not self.text_conditioned:
self.text_reader = text_reader
self.image_reader = image_reader
self.batch_size = batch_size
self.start = start
self.stop = stop
self.device = device
def __iter__(self):
self.n = 0
loader_args = dict(
batch_size=self.batch_size,
start=self.start,
end=self.stop,
show_progress=False,
)
if self.text_conditioned:
self.loader = self.image_reader(**loader_args)
else:
self.loader = zip(
self.image_reader(**loader_args), self.text_reader(**loader_args)
)
return self
def __next__(self):
try:
return self.get_sample()
except StopIteration:
raise StopIteration
def get_sample(self):
"""
pre-proocess data from either reader into a common format
"""
self.n += 1
if self.text_conditioned:
image_embedding, caption = next(self.loader)
image_embedding = from_numpy(image_embedding).to(self.device)
tokenized_caption = tokenize(
caption["caption"].to_list(), truncate=True
).to(self.device)
return image_embedding, tokenized_caption
else:
(image_embedding, _), (text_embedding, _) = next(self.loader)
image_embedding = from_numpy(image_embedding).to(self.device)
text_embedding = from_numpy(text_embedding).to(self.device)
return image_embedding, text_embedding
def make_splits(
text_conditioned: bool,
batch_size: int,
num_data_points: int,
train_split: float,
eval_split: float,
device: str,
img_url: str,
meta_url: str = None,
txt_url: str = None,
):
assert img_url is not None, "Must supply some image embeddings"
if text_conditioned:
assert meta_url is not None, "Must supply metadata url if text-conditioning"
image_reader = EmbeddingReader(
embeddings_folder=img_url,
file_format="parquet_npy",
meta_columns=["caption"],
metadata_folder=meta_url,
)
# compute split points
if num_data_points > image_reader.count:
print("Specified point count is larger than the number of points available...defaulting to max length of reader.")
num_data_points = image_reader.count
train_set_size = int(train_split * num_data_points)
eval_set_size = int(eval_split * num_data_points)
eval_stop = int(train_set_size + eval_set_size)
train_loader = PriorEmbeddingLoader(
text_conditioned=text_conditioned,
image_reader=image_reader,
batch_size=batch_size,
start=0,
stop=train_set_size,
device=device,
)
eval_loader = PriorEmbeddingLoader(
text_conditioned=text_conditioned,
image_reader=image_reader,
batch_size=batch_size,
start=train_set_size,
stop=eval_stop,
device=device,
)
test_loader = PriorEmbeddingLoader(
text_conditioned=text_conditioned,
image_reader=image_reader,
batch_size=batch_size,
start=eval_stop,
stop=int(num_data_points),
device=device,
)
else:
assert (
txt_url is not None
), "Must supply text embedding url if not text-conditioning"
image_reader = EmbeddingReader(img_url, file_format="npy")
text_reader = EmbeddingReader(txt_url, file_format="npy")
# compute split points
if num_data_points > image_reader.count:
print("Specified point count is larger than the number of points available...defaulting to max length of reader.")
num_data_points = image_reader.count
train_set_size = int(train_split * num_data_points)
eval_set_size = int(eval_split * num_data_points)
eval_stop = int(train_set_size + eval_set_size)
train_loader = PriorEmbeddingLoader(
text_conditioned=text_conditioned,
image_reader=image_reader,
text_reader=text_reader,
batch_size=batch_size,
start=0,
stop=train_set_size,
device=device,
)
eval_loader = PriorEmbeddingLoader(
text_conditioned=text_conditioned,
image_reader=image_reader,
text_reader=text_reader,
batch_size=batch_size,
start=train_set_size,
stop=eval_stop,
device=device,
)
test_loader = PriorEmbeddingLoader(
text_conditioned=text_conditioned,
image_reader=image_reader,
text_reader=text_reader,
batch_size=batch_size,
start=eval_stop,
stop=int(num_data_points),
device=device,
)
return train_loader, eval_loader, test_loader

View File

@@ -0,0 +1,59 @@
from pathlib import Path
import torch
from torch.utils import data
from torchvision import transforms, utils
from PIL import Image
# helpers functions
def cycle(dl):
while True:
for data in dl:
yield data
# dataset and dataloader
class Dataset(data.Dataset):
def __init__(
self,
folder,
image_size,
exts = ['jpg', 'jpeg', 'png']
):
super().__init__()
self.folder = folder
self.image_size = image_size
self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]
self.transform = transforms.Compose([
transforms.Resize(image_size),
transforms.RandomHorizontalFlip(),
transforms.CenterCrop(image_size),
transforms.ToTensor()
])
def __len__(self):
return len(self.paths)
def __getitem__(self, index):
path = self.paths[index]
img = Image.open(path)
return self.transform(img)
def get_images_dataloader(
folder,
*,
batch_size,
image_size,
shuffle = True,
cycle_dl = True,
pin_memory = True
):
ds = Dataset(folder, image_size)
dl = data.DataLoader(ds, batch_size = batch_size, shuffle = shuffle, pin_memory = pin_memory)
if cycle_dl:
dl = cycle(dl)
return dl

View File

@@ -7,16 +7,18 @@ def separate_weight_decayable_params(params):
def get_optimizer(
params,
lr = 3e-4,
lr = 1e-4,
wd = 1e-2,
betas = (0.9, 0.999),
filter_by_requires_grad = False
eps = 1e-8,
filter_by_requires_grad = False,
**kwargs
):
if filter_by_requires_grad:
params = list(filter(lambda t: t.requires_grad, params))
if wd == 0:
return Adam(params, lr = lr, betas = betas)
return Adam(params, lr = lr, betas = betas, eps = eps)
params = set(params)
wd_params, no_wd_params = separate_weight_decayable_params(params)
@@ -26,4 +28,4 @@ def get_optimizer(
{'params': list(no_wd_params), 'weight_decay': 0},
]
return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas)
return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas, eps = eps)

115
dalle2_pytorch/trackers.py Normal file
View File

@@ -0,0 +1,115 @@
import os
from pathlib import Path
import importlib
from itertools import zip_longest
import torch
from torch import nn
# constants
DEFAULT_DATA_PATH = './.tracker-data'
# helper functions
def exists(val):
return val is not None
def import_or_print_error(pkg_name, err_str = None):
try:
return importlib.import_module(pkg_name)
except ModuleNotFoundError as e:
if exists(err_str):
print(err_str)
exit()
# load state dict functions
def load_wandb_state_dict(run_path, file_path, **kwargs):
wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function')
file_reference = wandb.restore(file_path, run_path=run_path)
return torch.load(file_reference.name)
def load_local_state_dict(file_path, **kwargs):
return torch.load(file_path)
# base class
class BaseTracker(nn.Module):
def __init__(self, data_path = DEFAULT_DATA_PATH):
super().__init__()
self.data_path = Path(data_path)
self.data_path.mkdir(parents = True, exist_ok = True)
def init(self, config, **kwargs):
raise NotImplementedError
def log(self, log, **kwargs):
raise NotImplementedError
def log_images(self, images, **kwargs):
raise NotImplementedError
def save_state_dict(self, state_dict, relative_path, **kwargs):
raise NotImplementedError
def recall_state_dict(self, recall_source, *args, **kwargs):
"""
Loads a state dict from any source.
Since a user may wish to load a model from a different source than their own tracker (i.e. tracking using wandb but recalling from disk),
this should not be linked to any individual tracker.
"""
# TODO: Pull this into a dict or something similar so that we can add more sources without having a massive switch statement
if recall_source == 'wandb':
return load_wandb_state_dict(*args, **kwargs)
elif recall_source == 'local':
return load_local_state_dict(*args, **kwargs)
else:
raise ValueError('`recall_source` must be one of `wandb` or `local`')
# basic stdout class
class ConsoleTracker(BaseTracker):
def init(self, **config):
print(config)
def log(self, log, **kwargs):
print(log)
def log_images(self, images, **kwargs): # noop for logging images
pass
def save_state_dict(self, state_dict, relative_path, **kwargs):
torch.save(state_dict, str(self.data_path / relative_path))
# basic wandb class
class WandbTracker(BaseTracker):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb experiment tracker')
os.environ["WANDB_SILENT"] = "true"
def init(self, **config):
self.wandb.init(**config)
def log(self, log, verbose=False, **kwargs):
if verbose:
print(log)
self.wandb.log(log, **kwargs)
def log_images(self, images, captions=[], image_section="images", **kwargs):
"""
Takes a tensor of images and a list of captions and logs them to wandb.
"""
wandb_images = [self.wandb.Image(image, caption=caption) for image, caption in zip_longest(images, captions)]
self.log({ image_section: wandb_images }, **kwargs)
def save_state_dict(self, state_dict, relative_path, **kwargs):
"""
Saves a state_dict to disk and uploads it
"""
full_path = str(self.data_path / relative_path)
torch.save(state_dict, full_path)
self.wandb.save(full_path, base_path = str(self.data_path)) # Upload and keep relative to data_path

View File

@@ -1,275 +0,0 @@
import copy
from functools import partial
import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
from dalle2_pytorch.optimizer import get_optimizer
# helper functions
def exists(val):
return val is not None
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
def pick_and_pop(keys, d):
values = list(map(lambda key: d.pop(key), keys))
return dict(zip(keys, values))
def group_dict_by_key(cond, d):
return_val = [dict(),dict()]
for key in d.keys():
match = bool(cond(key))
ind = int(not match)
return_val[ind][key] = d[key]
return (*return_val,)
def string_begins_with(prefix, str):
return str.startswith(prefix)
def group_by_key_prefix(prefix, d):
return group_dict_by_key(partial(string_begins_with, prefix), d)
def groupby_prefix_and_trim(prefix, d):
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs
# exponential moving average wrapper
class EMA(nn.Module):
def __init__(
self,
model,
beta = 0.99,
update_after_step = 1000,
update_every = 10,
):
super().__init__()
self.beta = beta
self.online_model = model
self.ema_model = copy.deepcopy(model)
self.update_after_step = update_after_step # only start EMA after this step number, starting at 0
self.update_every = update_every
self.register_buffer('initted', torch.Tensor([False]))
self.register_buffer('step', torch.tensor([0.]))
def update(self):
self.step += 1
if self.step <= self.update_after_step or (self.step % self.update_every) != 0:
return
if not self.initted:
self.ema_model.state_dict(self.online_model.state_dict())
self.initted.data.copy_(torch.Tensor([True]))
self.update_moving_average(self.ema_model, self.online_model)
def update_moving_average(self, ma_model, current_model):
def calculate_ema(beta, old, new):
if not exists(old):
return new
return old * beta + (1 - beta) * new
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
old_weight, up_weight = ma_params.data, current_params.data
ma_params.data = calculate_ema(self.beta, old_weight, up_weight)
for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer)
ma_buffer.copy_(new_buffer_value)
def __call__(self, *args, **kwargs):
return self.ema_model(*args, **kwargs)
# diffusion prior trainer
class DiffusionPriorTrainer(nn.Module):
def __init__(
self,
diffusion_prior,
use_ema = True,
lr = 3e-4,
wd = 1e-2,
max_grad_norm = None,
amp = False,
**kwargs
):
super().__init__()
assert isinstance(diffusion_prior, DiffusionPrior)
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
self.diffusion_prior = diffusion_prior
# exponential moving average
self.use_ema = use_ema
if self.use_ema:
self.ema_diffusion_prior = EMA(diffusion_prior, **ema_kwargs)
# optimizer and mixed precision stuff
self.amp = amp
self.scaler = GradScaler(enabled = amp)
self.optimizer = get_optimizer(
diffusion_prior.parameters(),
lr = lr,
wd = wd,
**kwargs
)
# gradient clipping if needed
self.max_grad_norm = max_grad_norm
def update(self):
if exists(self.max_grad_norm):
self.scaler.unscale_(self.optimizer)
nn.utils.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
self.scaler.step(self.optimizer)
self.scaler.update()
self.optimizer.zero_grad()
if self.use_ema:
self.ema_diffusion_prior.update()
@torch.inference_mode()
def p_sample_loop(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
@torch.inference_mode()
def sample(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
@torch.inference_mode()
def sample_batch_size(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
def forward(
self,
*args,
divisor = 1,
**kwargs
):
with autocast(enabled = self.amp):
loss = self.diffusion_prior(*args, **kwargs)
return self.scaler.scale(loss / divisor)
# decoder trainer
class DecoderTrainer(nn.Module):
def __init__(
self,
decoder,
use_ema = True,
lr = 3e-4,
wd = 1e-2,
max_grad_norm = None,
amp = False,
**kwargs
):
super().__init__()
assert isinstance(decoder, Decoder)
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
self.decoder = decoder
self.num_unets = len(self.decoder.unets)
self.use_ema = use_ema
if use_ema:
has_lazy_linear = any([type(module) == nn.LazyLinear for module in decoder.modules()])
assert not has_lazy_linear, 'you must set the text_embed_dim on your u-nets if you plan on doing automatic exponential moving average'
self.ema_unets = nn.ModuleList([])
self.amp = amp
# be able to finely customize learning rate, weight decay
# per unet
lr, wd = map(partial(cast_tuple, length = self.num_unets), (lr, wd))
for ind, (unet, unet_lr, unet_wd) in enumerate(zip(self.decoder.unets, lr, wd)):
optimizer = get_optimizer(
unet.parameters(),
lr = unet_lr,
wd = unet_wd,
**kwargs
)
setattr(self, f'optim{ind}', optimizer) # cannot use pytorch ModuleList for some reason with optimizers
if self.use_ema:
self.ema_unets.append(EMA(unet, **ema_kwargs))
scaler = GradScaler(enabled = amp)
setattr(self, f'scaler{ind}', scaler)
# gradient clipping if needed
self.max_grad_norm = max_grad_norm
@property
def unets(self):
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
def scale(self, loss, *, unet_number):
assert 1 <= unet_number <= self.num_unets
index = unet_number - 1
scaler = getattr(self, f'scaler{index}')
return scaler.scale(loss)
def update(self, unet_number):
assert 1 <= unet_number <= self.num_unets
index = unet_number - 1
unet = self.decoder.unets[index]
optimizer = getattr(self, f'optim{index}')
scaler = getattr(self, f'scaler{index}')
if exists(self.max_grad_norm):
scaler.unscale_(optimizer)
nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
if self.use_ema:
ema_unet = self.ema_unets[index]
ema_unet.update()
@torch.no_grad()
def sample(self, *args, **kwargs):
if self.use_ema:
trainable_unets = self.decoder.unets
self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
output = self.decoder.sample(*args, **kwargs)
if self.use_ema:
self.decoder.unets = trainable_unets # restore original training unets
return output
def forward(
self,
x,
*,
unet_number,
divisor = 1,
**kwargs
):
with autocast(enabled = self.amp):
loss = self.decoder(x, unet_number = unet_number, **kwargs)
return self.scale(loss / divisor, unet_number = unet_number)

View File

@@ -0,0 +1,128 @@
from torchvision import transforms as T
from pydantic import BaseModel, validator, root_validator
from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
class UnetConfig(BaseModel):
dim: int
dim_mults: List[int]
image_embed_dim: int = None
cond_dim: int = None
channels: int = 3
attn_dim_head: int = 32
attn_heads: int = 16
class Config:
extra = "allow"
class DecoderConfig(BaseModel):
image_size: int = None
image_sizes: Union[List[int], Tuple[int]] = None
channels: int = 3
timesteps: int = 1000
loss_type: str = 'l2'
beta_schedule: str = 'cosine'
learned_variance: bool = True
@validator('image_sizes')
def check_image_sizes(cls, image_sizes, values):
if exists(values.get('image_size')) ^ exists(image_sizes):
return image_sizes
raise ValueError('either image_size or image_sizes is required, but not both')
class Config:
extra = "allow"
class TrainSplitConfig(BaseModel):
train: float = 0.75
val: float = 0.15
test: float = 0.1
@root_validator
def validate_all(cls, fields):
if sum([*fields.values()]) != 1.:
raise ValueError(f'{fields.keys()} must sum to 1.0')
return fields
class DecoderDataConfig(BaseModel):
webdataset_base_url: str # path to a webdataset with jpg images
embeddings_url: str # path to .npy files with embeddings
num_workers: int = 4
batch_size: int = 64
start_shard: int = 0
end_shard: int = 9999999
shard_width: int = 6
index_width: int = 4
splits: TrainSplitConfig
shuffle_train: bool = True
resample_train: bool = False
preprocessing: Dict[str, Any] = {'ToTensor': True}
class DecoderTrainConfig(BaseModel):
epochs: int = 20
lr: float = 1e-4
wd: float = 0.01
max_grad_norm: float = 0.5
save_every_n_samples: int = 100000
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
device: str = 'cuda:0'
epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
validation_samples: int = None # Same as above but for validation.
use_ema: bool = True
ema_beta: float = 0.99
amp: bool = False
save_all: bool = False # Whether to preserve all checkpoints
save_latest: bool = True # Whether to always save the latest checkpoint
save_best: bool = True # Whether to save the best checkpoint
unet_training_mask: List[bool] = None # If None, use all unets
class DecoderEvaluateConfig(BaseModel):
n_evaluation_samples: int = 1000
FID: Dict[str, Any] = None
IS: Dict[str, Any] = None
KID: Dict[str, Any] = None
LPIPS: Dict[str, Any] = None
class TrackerConfig(BaseModel):
tracker_type: str = 'console' # Decoder currently supports console and wandb
data_path: str = './models' # The path where files will be saved locally
init_config: Dict[str, Any] = None
wandb_entity: str = '' # Only needs to be set if tracker_type is wandb
wandb_project: str = ''
verbose: bool = False # Whether to print console logging for non-console trackers
class DecoderLoadConfig(BaseModel):
source: str = None # Supports file and wandb
run_path: str = '' # Used only if source is wandb
file_path: str = '' # The local filepath if source is file. If source is wandb, the relative path to the model file in wandb.
resume: bool = False # If using wandb, whether to resume the run
class TrainDecoderConfig(BaseModel):
unets: List[UnetConfig]
decoder: DecoderConfig
data: DecoderDataConfig
train: DecoderTrainConfig
evaluate: DecoderEvaluateConfig
tracker: TrackerConfig
load: DecoderLoadConfig
@property
def img_preproc(self):
def _get_transformation(transformation_name, **kwargs):
if transformation_name == "RandomResizedCrop":
return T.RandomResizedCrop(**kwargs)
elif transformation_name == "RandomHorizontalFlip":
return T.RandomHorizontalFlip()
elif transformation_name == "ToTensor":
return T.ToTensor()
transforms = []
for transform_name, transform_kwargs_or_bool in self.data.preprocessing.items():
transform_kwargs = {} if not isinstance(transform_kwargs_or_bool, dict) else transform_kwargs_or_bool
transforms.append(_get_transformation(transform_name, **transform_kwargs))
return T.Compose(transforms)

491
dalle2_pytorch/trainer.py Normal file
View File

@@ -0,0 +1,491 @@
import time
import copy
from math import ceil
from functools import partial, wraps
from collections.abc import Iterable
import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
from dalle2_pytorch.optimizer import get_optimizer
import numpy as np
# helper functions
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
def pick_and_pop(keys, d):
values = list(map(lambda key: d.pop(key), keys))
return dict(zip(keys, values))
def group_dict_by_key(cond, d):
return_val = [dict(),dict()]
for key in d.keys():
match = bool(cond(key))
ind = int(not match)
return_val[ind][key] = d[key]
return (*return_val,)
def string_begins_with(prefix, str):
return str.startswith(prefix)
def group_by_key_prefix(prefix, d):
return group_dict_by_key(partial(string_begins_with, prefix), d)
def groupby_prefix_and_trim(prefix, d):
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs
def num_to_groups(num, divisor):
groups = num // divisor
remainder = num % divisor
arr = [divisor] * groups
if remainder > 0:
arr.append(remainder)
return arr
# decorators
def cast_torch_tensor(fn):
@wraps(fn)
def inner(model, *args, **kwargs):
device = kwargs.pop('_device', next(model.parameters()).device)
cast_device = kwargs.pop('_cast_device', True)
kwargs_keys = kwargs.keys()
all_args = (*args, *kwargs.values())
split_kwargs_index = len(all_args) - len(kwargs_keys)
all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args))
if cast_device:
all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))
out = fn(model, *args, **kwargs)
return out
return inner
# gradient accumulation functions
def split_iterable(it, split_size):
accum = []
for ind in range(ceil(len(it) / split_size)):
start_index = ind * split_size
accum.append(it[start_index: (start_index + split_size)])
return accum
def split(t, split_size = None):
if not exists(split_size):
return t
if isinstance(t, torch.Tensor):
return t.split(split_size, dim = 0)
if isinstance(t, Iterable):
return split_iterable(t, split_size)
return TypeError
def find_first(cond, arr):
for el in arr:
if cond(el):
return el
return None
def split_args_and_kwargs(*args, split_size = None, **kwargs):
all_args = (*args, *kwargs.values())
len_all_args = len(all_args)
first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args)
assert exists(first_tensor)
batch_size = len(first_tensor)
split_size = default(split_size, batch_size)
num_chunks = ceil(batch_size / split_size)
dict_len = len(kwargs)
dict_keys = kwargs.keys()
split_kwargs_index = len_all_args - dict_len
split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args]
chunk_sizes = tuple(map(len, split_all_args[0]))
for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:]
chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
chunk_size_frac = chunk_size / batch_size
yield chunk_size_frac, (chunked_args, chunked_kwargs)
# print helpers
def print_ribbon(s, symbol = '=', repeat = 40):
flank = symbol * repeat
return f'{flank} {s} {flank}'
# saving and loading functions
# for diffusion prior
def load_diffusion_model(dprior_path, device):
dprior_path = Path(dprior_path)
assert dprior_path.exists(), 'Dprior model file does not exist'
loaded_obj = torch.load(str(dprior_path), map_location='cpu')
# Get hyperparameters of loaded model
dpn_config = loaded_obj['hparams']['diffusion_prior_network']
dp_config = loaded_obj['hparams']['diffusion_prior']
image_embed_dim = loaded_obj['image_embed_dim']['image_embed_dim']
# Create DiffusionPriorNetwork and DiffusionPrior with loaded hyperparameters
# DiffusionPriorNetwork
prior_network = DiffusionPriorNetwork( dim = image_embed_dim, **dpn_config).to(device)
# DiffusionPrior with text embeddings and image embeddings pre-computed
diffusion_prior = DiffusionPrior(net = prior_network, **dp_config, image_embed_dim = image_embed_dim).to(device)
# Load state dict from saved model
diffusion_prior.load_state_dict(loaded_obj['model'])
return diffusion_prior, loaded_obj
def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim):
# Saving State Dict
print_ribbon('Saving checkpoint')
state_dict = dict(model=model.state_dict(),
optimizer=optimizer.state_dict(),
scaler=scaler.state_dict(),
hparams = config,
image_embed_dim = {"image_embed_dim":image_embed_dim})
torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
# exponential moving average wrapper
class EMA(nn.Module):
def __init__(
self,
model,
beta = 0.9999,
update_after_step = 1000,
update_every = 10,
):
super().__init__()
self.beta = beta
self.online_model = model
self.ema_model = copy.deepcopy(model)
self.update_every = update_every
self.update_after_step = update_after_step // update_every # only start EMA after this step number, starting at 0
self.register_buffer('initted', torch.Tensor([False]))
self.register_buffer('step', torch.tensor([0.]))
def restore_ema_model_device(self):
device = self.initted.device
self.ema_model.to(device)
def copy_params_from_model_to_ema(self):
self.ema_model.state_dict(self.online_model.state_dict())
def update(self):
self.step += 1
if (self.step % self.update_every) != 0:
return
if self.step <= self.update_after_step:
self.copy_params_from_model_to_ema()
return
if not self.initted:
self.copy_params_from_model_to_ema()
self.initted.data.copy_(torch.Tensor([True]))
self.update_moving_average(self.ema_model, self.online_model)
def update_moving_average(self, ma_model, current_model):
def calculate_ema(beta, old, new):
if not exists(old):
return new
return old * beta + (1 - beta) * new
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
old_weight, up_weight = ma_params.data, current_params.data
ma_params.data = calculate_ema(self.beta, old_weight, up_weight)
for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer)
ma_buffer.copy_(new_buffer_value)
def __call__(self, *args, **kwargs):
return self.ema_model(*args, **kwargs)
# diffusion prior trainer
def prior_sample_in_chunks(fn):
@wraps(fn)
def inner(self, *args, max_batch_size = None, **kwargs):
if not exists(max_batch_size):
return fn(self, *args, **kwargs)
outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]
return torch.cat(outputs, dim = 0)
return inner
class DiffusionPriorTrainer(nn.Module):
def __init__(
self,
diffusion_prior,
use_ema = True,
lr = 3e-4,
wd = 1e-2,
eps = 1e-6,
max_grad_norm = None,
amp = False,
**kwargs
):
super().__init__()
assert isinstance(diffusion_prior, DiffusionPrior)
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
self.diffusion_prior = diffusion_prior
# exponential moving average
self.use_ema = use_ema
if self.use_ema:
self.ema_diffusion_prior = EMA(diffusion_prior, **ema_kwargs)
# optimizer and mixed precision stuff
self.amp = amp
self.scaler = GradScaler(enabled = amp)
self.optimizer = get_optimizer(
diffusion_prior.parameters(),
lr = lr,
wd = wd,
eps = eps,
**kwargs
)
# gradient clipping if needed
self.max_grad_norm = max_grad_norm
self.register_buffer('step', torch.tensor([0.]))
def update(self):
if exists(self.max_grad_norm):
self.scaler.unscale_(self.optimizer)
nn.utils.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
self.scaler.step(self.optimizer)
self.scaler.update()
self.optimizer.zero_grad()
if self.use_ema:
self.ema_diffusion_prior.update()
self.step += 1
@torch.no_grad()
@cast_torch_tensor
@prior_sample_in_chunks
def p_sample_loop(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
@torch.no_grad()
@cast_torch_tensor
@prior_sample_in_chunks
def sample(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
@torch.no_grad()
def sample_batch_size(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
@cast_torch_tensor
def forward(
self,
*args,
max_batch_size = None,
**kwargs
):
total_loss = 0.
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
with autocast(enabled = self.amp):
loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
loss = loss * chunk_size_frac
total_loss += loss.item()
if self.training:
self.scaler.scale(loss).backward()
return total_loss
# decoder trainer
def decoder_sample_in_chunks(fn):
@wraps(fn)
def inner(self, *args, max_batch_size = None, **kwargs):
if not exists(max_batch_size):
return fn(self, *args, **kwargs)
if self.decoder.unconditional:
batch_size = kwargs.get('batch_size')
batch_sizes = num_to_groups(batch_size, max_batch_size)
outputs = [fn(self, *args, **{**kwargs, 'batch_size': sub_batch_size}) for sub_batch_size in batch_sizes]
else:
outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]
return torch.cat(outputs, dim = 0)
return inner
class DecoderTrainer(nn.Module):
def __init__(
self,
decoder,
use_ema = True,
lr = 1e-4,
wd = 1e-2,
eps = 1e-8,
max_grad_norm = 0.5,
amp = False,
**kwargs
):
super().__init__()
assert isinstance(decoder, Decoder)
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
self.decoder = decoder
self.num_unets = len(self.decoder.unets)
self.use_ema = use_ema
self.ema_unets = nn.ModuleList([])
self.amp = amp
# be able to finely customize learning rate, weight decay
# per unet
lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))
for ind, (unet, unet_lr, unet_wd, unet_eps) in enumerate(zip(self.decoder.unets, lr, wd, eps)):
optimizer = get_optimizer(
unet.parameters(),
lr = unet_lr,
wd = unet_wd,
eps = unet_eps,
**kwargs
)
setattr(self, f'optim{ind}', optimizer) # cannot use pytorch ModuleList for some reason with optimizers
if self.use_ema:
self.ema_unets.append(EMA(unet, **ema_kwargs))
scaler = GradScaler(enabled = amp)
setattr(self, f'scaler{ind}', scaler)
# gradient clipping if needed
self.max_grad_norm = max_grad_norm
self.register_buffer('step', torch.tensor([0.]))
@property
def unets(self):
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
def scale(self, loss, *, unet_number):
assert 1 <= unet_number <= self.num_unets
index = unet_number - 1
scaler = getattr(self, f'scaler{index}')
return scaler.scale(loss)
def update(self, unet_number = None):
if self.num_unets == 1:
unet_number = default(unet_number, 1)
assert exists(unet_number) and 1 <= unet_number <= self.num_unets
index = unet_number - 1
unet = self.decoder.unets[index]
optimizer = getattr(self, f'optim{index}')
scaler = getattr(self, f'scaler{index}')
if exists(self.max_grad_norm):
scaler.unscale_(optimizer)
nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
if self.use_ema:
ema_unet = self.ema_unets[index]
ema_unet.update()
self.step += 1
@torch.no_grad()
@cast_torch_tensor
@decoder_sample_in_chunks
def sample(self, *args, **kwargs):
if kwargs.pop('use_non_ema', False) or not self.use_ema:
return self.decoder.sample(*args, **kwargs)
trainable_unets = self.decoder.unets
self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
output = self.decoder.sample(*args, **kwargs)
self.decoder.unets = trainable_unets # restore original training unets
# cast the ema_model unets back to original device
for ema in self.ema_unets:
ema.restore_ema_model_device()
return output
@cast_torch_tensor
def forward(
self,
*args,
unet_number = None,
max_batch_size = None,
**kwargs
):
if self.num_unets == 1:
unet_number = default(unet_number, 1)
total_loss = 0.
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
with autocast(enabled = self.amp):
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
loss = loss * chunk_size_frac
total_loss += loss.item()
if self.training:
self.scale(loss, unet_number = unet_number).backward()
return total_loss

11
dalle2_pytorch/utils.py Normal file
View File

@@ -0,0 +1,11 @@
import time
class Timer:
def __init__(self):
self.reset()
def reset(self):
self.last_time = time.time()
def elapsed(self):
return time.time() - self.last_time

BIN
samples/oxford.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 985 KiB

View File

@@ -10,11 +10,12 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.1.9',
version = '0.4.1',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
author_email = 'lucidrains@gmail.com',
long_description_content_type = 'text/markdown',
url = 'https://github.com/lucidrains/dalle2-pytorch',
keywords = [
'artificial intelligence',
@@ -29,7 +30,9 @@ setup(
'einops-exts>=0.0.3',
'embedding-reader',
'kornia>=0.5.4',
'numpy',
'pillow',
'pydantic',
'resize-right>=0.0.2',
'rotary-embedding-torch',
'torch>=1.10',
@@ -39,7 +42,8 @@ setup(
'x-clip>=0.4.4',
'youtokentome',
'webdataset>=0.2.5',
'fsspec>=2022.1.0'
'fsspec>=2022.1.0',
'torchmetrics[image]>=0.8.0'
],
classifiers=[
'Development Status :: 4 - Beta',

459
train_decoder.py Normal file
View File

@@ -0,0 +1,459 @@
from dalle2_pytorch import Unet, Decoder
from dalle2_pytorch.trainer import DecoderTrainer, print_ribbon
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
from dalle2_pytorch.train_configs import TrainDecoderConfig
from dalle2_pytorch.utils import Timer
import json
import torchvision
import torch
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.kid import KernelInceptionDistance
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
import webdataset as wds
import click
# constants
TRAIN_CALC_LOSS_EVERY_ITERS = 10
VALID_CALC_LOSS_EVERY_ITERS = 10
# helpers functions
def exists(val):
return val is not None
# main functions
def create_dataloaders(
available_shards,
webdataset_base_url,
embeddings_url,
shard_width=6,
num_workers=4,
batch_size=32,
n_sample_images=6,
shuffle_train=True,
resample_train=False,
img_preproc = None,
index_width=4,
train_prop = 0.75,
val_prop = 0.15,
test_prop = 0.10,
**kwargs
):
"""
Randomly splits the available shards into train, val, and test sets and returns a dataloader for each
"""
assert train_prop + test_prop + val_prop == 1
num_train = round(train_prop*len(available_shards))
num_test = round(test_prop*len(available_shards))
num_val = len(available_shards) - num_train - num_test
assert num_train + num_test + num_val == len(available_shards), f"{num_train} + {num_test} + {num_val} = {num_train + num_test + num_val} != {len(available_shards)}"
train_split, test_split, val_split = torch.utils.data.random_split(available_shards, [num_train, num_test, num_val], generator=torch.Generator().manual_seed(0))
# The shard number in the webdataset file names has a fixed width. We zero pad the shard numbers so they correspond to a filename.
train_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in train_split]
test_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in test_split]
val_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in val_split]
create_dataloader = lambda tar_urls, shuffle=False, resample=False, with_text=False, for_sampling=False: create_image_embedding_dataloader(
tar_url=tar_urls,
num_workers=num_workers,
batch_size=batch_size if not for_sampling else n_sample_images,
embeddings_url=embeddings_url,
index_width=index_width,
shuffle_num = None,
extra_keys= ["txt"] if with_text else [],
shuffle_shards = shuffle,
resample_shards = resample,
img_preproc=img_preproc,
handler=wds.handlers.warn_and_continue
)
train_dataloader = create_dataloader(train_urls, shuffle=shuffle_train, resample=resample_train)
train_sampling_dataloader = create_dataloader(train_urls, shuffle=False, for_sampling=True)
val_dataloader = create_dataloader(val_urls, shuffle=False, with_text=True)
test_dataloader = create_dataloader(test_urls, shuffle=False, with_text=True)
test_sampling_dataloader = create_dataloader(test_urls, shuffle=False, for_sampling=True)
return {
"train": train_dataloader,
"train_sampling": train_sampling_dataloader,
"val": val_dataloader,
"test": test_dataloader,
"test_sampling": test_sampling_dataloader
}
def create_decoder(device, decoder_config, unets_config):
"""Creates a sample decoder"""
unets = [Unet(**config.dict()) for config in unets_config]
decoder = Decoder(
unet=unets,
**decoder_config.dict()
)
decoder.to(device=device)
return decoder
def get_dataset_keys(dataloader):
"""
It is sometimes neccesary to get the keys the dataloader is returning. Since the dataset is burried in the dataloader, we need to do a process to recover it.
"""
# If the dataloader is actually a WebLoader, we need to extract the real dataloader
if isinstance(dataloader, wds.WebLoader):
dataloader = dataloader.pipeline[0]
return dataloader.dataset.key_map
def get_example_data(dataloader, device, n=5):
"""
Samples the dataloader and returns a zipped list of examples
"""
images = []
embeddings = []
captions = []
dataset_keys = get_dataset_keys(dataloader)
has_caption = "txt" in dataset_keys
for data in dataloader:
if has_caption:
img, emb, txt = data
else:
img, emb = data
txt = [""] * emb.shape[0]
img = img.to(device=device, dtype=torch.float)
emb = emb.to(device=device, dtype=torch.float)
images.extend(list(img))
embeddings.extend(list(emb))
captions.extend(list(txt))
if len(images) >= n:
break
print("Generated {} examples".format(len(images)))
return list(zip(images[:n], embeddings[:n], captions[:n]))
def generate_samples(trainer, example_data, text_prepend=""):
"""
Takes example data and generates images from the embeddings
Returns three lists: real images, generated images, and captions
"""
real_images, embeddings, txts = zip(*example_data)
embeddings_tensor = torch.stack(embeddings)
samples = trainer.sample(embeddings_tensor)
generated_images = list(samples)
captions = [text_prepend + txt for txt in txts]
return real_images, generated_images, captions
def generate_grid_samples(trainer, examples, text_prepend=""):
"""
Generates samples and uses torchvision to put them in a side by side grid for easy viewing
"""
real_images, generated_images, captions = generate_samples(trainer, examples, text_prepend)
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
return grid_images, captions
def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
"""
Computes evaluation metrics for the decoder
"""
metrics = {}
# Prepare the data
examples = get_example_data(dataloader, device, n_evaluation_samples)
real_images, generated_images, captions = generate_samples(trainer, examples)
real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
# Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
int_real_images = real_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
int_generated_images = generated_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
if exists(FID):
fid = FrechetInceptionDistance(**FID)
fid.to(device=device)
fid.update(int_real_images, real=True)
fid.update(int_generated_images, real=False)
metrics["FID"] = fid.compute().item()
if exists(IS):
inception = InceptionScore(**IS)
inception.to(device=device)
inception.update(int_real_images)
is_mean, is_std = inception.compute()
metrics["IS_mean"] = is_mean.item()
metrics["IS_std"] = is_std.item()
if exists(KID):
kernel_inception = KernelInceptionDistance(**KID)
kernel_inception.to(device=device)
kernel_inception.update(int_real_images, real=True)
kernel_inception.update(int_generated_images, real=False)
kid_mean, kid_std = kernel_inception.compute()
metrics["KID_mean"] = kid_mean.item()
metrics["KID_std"] = kid_std.item()
if exists(LPIPS):
# Convert from [0, 1] to [-1, 1]
renorm_real_images = real_images.mul(2).sub(1)
renorm_generated_images = generated_images.mul(2).sub(1)
lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS)
lpips.to(device=device)
lpips.update(renorm_real_images, renorm_generated_images)
metrics["LPIPS"] = lpips.compute().item()
return metrics
def save_trainer(tracker, trainer, epoch, step, validation_losses, relative_paths):
"""
Logs the model with an appropriate method depending on the tracker
"""
if isinstance(relative_paths, str):
relative_paths = [relative_paths]
trainer_state_dict = {}
trainer_state_dict["trainer"] = trainer.state_dict()
trainer_state_dict['epoch'] = epoch
trainer_state_dict['step'] = step
trainer_state_dict['validation_losses'] = validation_losses
for relative_path in relative_paths:
tracker.save_state_dict(trainer_state_dict, relative_path)
def recall_trainer(tracker, trainer, recall_source=None, **load_config):
"""
Loads the model with an appropriate method depending on the tracker
"""
print(print_ribbon(f"Loading model from {recall_source}"))
state_dict = tracker.recall_state_dict(recall_source, **load_config)
trainer.load_state_dict(state_dict["trainer"])
print("Model loaded")
return state_dict["epoch"], state_dict["step"], state_dict["validation_losses"]
def train(
dataloaders,
decoder,
tracker,
inference_device,
load_config=None,
evaluate_config=None,
epoch_samples = None, # If the training dataset is resampling, we have to manually stop an epoch
validation_samples = None,
epochs = 20,
n_sample_images = 5,
save_every_n_samples = 100000,
save_all=False,
save_latest=True,
save_best=True,
unet_training_mask=None,
**kwargs
):
"""
Trains a decoder on a dataset.
"""
trainer = DecoderTrainer( # TODO: Change the get_optimizer function so that it can take arbitrary named args so we can just put **kwargs as an argument here
decoder,
**kwargs
)
# Set up starting model and parameters based on a recalled state dict
start_step = 0
start_epoch = 0
validation_losses = []
if exists(load_config) and exists(load_config.source):
start_epoch, start_step, validation_losses = recall_trainer(tracker, trainer, recall_source=load_config.source, **load_config)
trainer.to(device=inference_device)
if not exists(unet_training_mask):
# Then the unet mask should be true for all unets in the decoder
unet_training_mask = [True] * trainer.num_unets
assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"
print(print_ribbon("Generating Example Data", repeat=40))
print("This can take a while to load the shard lists...")
train_example_data = get_example_data(dataloaders["train_sampling"], inference_device, n_sample_images)
test_example_data = get_example_data(dataloaders["test_sampling"], inference_device, n_sample_images)
send_to_device = lambda arr: [x.to(device=inference_device, dtype=torch.float) for x in arr]
step = start_step
for epoch in range(start_epoch, epochs):
print(print_ribbon(f"Starting epoch {epoch}", repeat=40))
timer = Timer()
sample = 0
last_sample = 0
last_snapshot = 0
losses = []
for i, (img, emb) in enumerate(dataloaders["train"]):
step += 1
sample += img.shape[0]
img, emb = send_to_device((img, emb))
trainer.train()
for unet in range(1, trainer.num_unets+1):
# Check if this is a unet we are training
if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
continue
loss = trainer.forward(img, image_embed=emb, unet_number=unet)
trainer.update(unet_number=unet)
losses.append(loss)
samples_per_sec = (sample - last_sample) / timer.elapsed()
timer.reset()
last_sample = sample
if i % TRAIN_CALC_LOSS_EVERY_ITERS == 0:
average_loss = sum(losses) / len(losses)
log_data = {
"Training loss": average_loss,
"Epoch": epoch,
"Sample": sample,
"Step": i,
"Samples per second": samples_per_sec
}
tracker.log(log_data, step=step, verbose=True)
losses = []
if last_snapshot + save_every_n_samples < sample: # This will miss by some amount every time, but it's not a big deal... I hope
last_snapshot = sample
# We need to know where the model should be saved
save_paths = []
if save_latest:
save_paths.append("latest.pth")
if save_all:
save_paths.append(f"checkpoints/epoch_{epoch}_step_{step}.pth")
save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths)
if exists(n_sample_images) and n_sample_images > 0:
trainer.eval()
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step)
if exists(epoch_samples) and sample >= epoch_samples:
break
trainer.eval()
print(print_ribbon(f"Starting Validation {epoch}", repeat=40))
with torch.no_grad():
sample = 0
average_loss = 0
timer = Timer()
for i, (img, emb, txt) in enumerate(dataloaders["val"]):
sample += img.shape[0]
img, emb = send_to_device((img, emb))
for unet in range(1, len(decoder.unets)+1):
loss = trainer.forward(img.float(), image_embed=emb.float(), unet_number=unet)
average_loss += loss
if i % VALID_CALC_LOSS_EVERY_ITERS == 0:
print(f"Epoch {epoch}/{epochs} - {sample / timer.elapsed():.2f} samples/sec")
print(f"Loss: {average_loss / (i+1)}")
print("")
if exists(validation_samples) and sample >= validation_samples:
break
average_loss /= i+1
log_data = {
"Validation loss": average_loss
}
tracker.log(log_data, step=step, verbose=True)
# Compute evaluation metrics
if exists(evaluate_config):
print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config)
tracker.log(evaluation, step=step, verbose=True)
# Generate sample images
print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
test_images, test_captions = generate_grid_samples(trainer, test_example_data, "Test: ")
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step)
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step)
print(print_ribbon(f"Starting Saving {epoch}", repeat=40))
# Get the same paths
save_paths = []
if save_latest:
save_paths.append("latest.pth")
if save_best and (len(validation_losses) == 0 or average_loss < min(validation_losses)):
save_paths.append("best.pth")
validation_losses.append(average_loss)
save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths)
def create_tracker(config, tracker_type=None, data_path=None, **kwargs):
"""
Creates a tracker of the specified type and initializes special features based on the full config
"""
tracker_config = config.tracker
init_config = {}
if exists(tracker_config.init_config):
init_config["config"] = tracker_config.init_config
if tracker_type == "console":
tracker = ConsoleTracker(**init_config)
elif tracker_type == "wandb":
# We need to initialize the resume state here
load_config = config.load
if load_config.source == "wandb" and load_config.resume:
# Then we are resuming the run load_config["run_path"]
run_id = load_config.run_path.split("/")[-1]
init_config["id"] = run_id
init_config["resume"] = "must"
init_config["entity"] = tracker_config.wandb_entity
init_config["project"] = tracker_config.wandb_project
tracker = WandbTracker(data_path)
tracker.init(**init_config)
else:
raise ValueError(f"Tracker type {tracker_type} not supported by decoder trainer")
return tracker
def initialize_training(config):
# Create the save path
if "cuda" in config.train.device:
assert torch.cuda.is_available(), "CUDA is not available"
device = torch.device(config.train.device)
torch.cuda.set_device(device)
all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
dataloaders = create_dataloaders (
available_shards=all_shards,
img_preproc = config.img_preproc,
train_prop = config.data.splits.train,
val_prop = config.data.splits.val,
test_prop = config.data.splits.test,
n_sample_images=config.train.n_sample_images,
**config.data.dict()
)
decoder = create_decoder(device, config.decoder, config.unets)
num_parameters = sum(p.numel() for p in decoder.parameters())
print(print_ribbon("Loaded Config", repeat=40))
print(f"Number of parameters: {num_parameters}")
tracker = create_tracker(config, **config.tracker.dict())
train(dataloaders, decoder,
tracker=tracker,
inference_device=device,
load_config=config.load,
evaluate_config=config.evaluate,
**config.train.dict(),
)
# Create a simple click command line interface to load the config and start the training
@click.command()
@click.option("--config_file", default="./train_decoder_config.json", help="Path to config file")
def main(config_file):
print("Recalling config from {}".format(config_file))
with open(config_file) as f:
config = json.load(f)
config = TrainDecoderConfig(**config)
initialize_training(config)
if __name__ == "__main__":
main()

View File

@@ -1,320 +1,365 @@
import os
from pathlib import Path
import click
import math
import argparse
import numpy as np
import torch
import clip
from torch import nn
from embedding_reader import EmbeddingReader
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
from dalle2_pytorch.optimizer import get_optimizer
from torch.cuda.amp import autocast,GradScaler
import time
from dalle2_pytorch.dataloaders import make_splits
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
from dalle2_pytorch.utils import Timer
from embedding_reader import EmbeddingReader
from tqdm import tqdm
import wandb
os.environ["WANDB_SILENT"] = "true"
NUM_TEST_EMBEDDINGS = 100 # for cosine similarity reporting during training
REPORT_METRICS_EVERY = 100 # for cosine similarity and other metric reporting during training
# constants
REPORT_METRICS_EVERY = 250 # for cosine similarity and other metric reporting during training
def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_type,phase="Validation"):
tracker = WandbTracker()
# helpers functions
def exists(val):
val is not None
# functions
def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation"):
model.eval()
with torch.no_grad():
total_loss = 0.
total_samples = 0.
for emb_images, emb_text in zip(image_reader(batch_size=batch_size, start=start, end=end),
text_reader(batch_size=batch_size, start=start, end=end)):
for image_embeddings, text_data in tqdm(dataloader):
emb_images_tensor = torch.tensor(emb_images[0]).to(device)
emb_text_tensor = torch.tensor(emb_text[0]).to(device)
batches = image_embeddings.shape[0]
batches = emb_images_tensor.shape[0]
input_args = dict(image_embed=image_embeddings)
if text_conditioned:
input_args = dict(**input_args, text = text_data)
else:
input_args = dict(**input_args, text_embed=text_data)
loss = model(text_embed = emb_text_tensor, image_embed = emb_images_tensor)
loss = model(**input_args)
total_loss += loss.item() * batches
total_loss += loss * batches
total_samples += batches
avg_loss = (total_loss / total_samples)
wandb.log({f'{phase} {loss_type}': avg_loss})
def save_model(save_path, state_dict):
# Saving State Dict
print("====================================== Saving checkpoint ======================================")
torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
tracker.log({f'{phase} {loss_type}': avg_loss})
def report_cosine_sims(diffusion_prior, dataloader, text_conditioned):
diffusion_prior.eval()
def report_cosine_sims(diffusion_prior, image_reader, text_reader, train_set_size, val_set_size, NUM_TEST_EMBEDDINGS, device):
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
tstart = train_set_size+val_set_size
tend = train_set_size+val_set_size+NUM_TEST_EMBEDDINGS
for test_image_embeddings, text_data in tqdm(dataloader):
# we are text conditioned, we produce an embedding from the tokenized text
if text_conditioned:
text_embedding, text_encodings, text_mask = diffusion_prior.clip.embed_text(
text_data)
text_cond = dict(text_embed=text_embedding,
text_encodings=text_encodings, mask=text_mask)
else:
text_embedding = text_data
text_cond = dict(text_embed=text_embedding)
for embt, embi in zip(text_reader(batch_size=NUM_TEST_EMBEDDINGS, start=tstart, end=tend), image_reader(batch_size=NUM_TEST_EMBEDDINGS, start=tstart, end=tend)):
# make a copy of the text embeddings for shuffling
text_embed = torch.tensor(embt[0]).to(device)
text_embed_shuffled = text_embed.clone()
text_embed_shuffled = text_embedding.clone()
# roll the text embeddings to simulate "unrelated" captions
rolled_idx = torch.roll(torch.arange(NUM_TEST_EMBEDDINGS), 1)
# roll the text to simulate "unrelated" captions
rolled_idx = torch.roll(torch.arange(text_embedding.shape[0]), 1)
text_embed_shuffled = text_embed_shuffled[rolled_idx]
text_embed_shuffled = text_embed_shuffled / \
text_embed_shuffled.norm(dim=1, keepdim=True)
test_text_shuffled_cond = dict(text_embed=text_embed_shuffled)
if text_conditioned:
text_encodings_shuffled = text_encodings[rolled_idx]
text_mask_shuffled = text_mask[rolled_idx]
else:
text_encodings_shuffled = None
text_mask_shuffled = None
text_cond_shuffled = dict(text_embed=text_embed_shuffled,
text_encodings=text_encodings_shuffled, mask=text_mask_shuffled)
# prepare the text embedding
text_embed = text_embed / text_embed.norm(dim=1, keepdim=True)
test_text_cond = dict(text_embed=text_embed)
text_embed = text_embedding / text_embedding.norm(dim=1, keepdim=True)
# prepare image embeddings
test_image_embeddings = torch.tensor(embi[0]).to(device)
test_image_embeddings = test_image_embeddings / \
test_image_embeddings.norm(dim=1, keepdim=True)
# predict on the unshuffled text embeddings
predicted_image_embeddings = diffusion_prior.p_sample_loop(
(NUM_TEST_EMBEDDINGS, 768), text_cond=test_text_cond)
test_image_embeddings.shape, text_cond)
predicted_image_embeddings = predicted_image_embeddings / \
predicted_image_embeddings.norm(dim=1, keepdim=True)
# predict on the shuffled embeddings
predicted_unrelated_embeddings = diffusion_prior.p_sample_loop(
(NUM_TEST_EMBEDDINGS, 768), text_cond=test_text_shuffled_cond)
test_image_embeddings.shape, text_cond_shuffled)
predicted_unrelated_embeddings = predicted_unrelated_embeddings / \
predicted_unrelated_embeddings.norm(dim=1, keepdim=True)
# calculate similarities
original_similarity = cos(
text_embed, test_image_embeddings).cpu().numpy()
text_embed, test_image_embeddings).cpu().numpy()
predicted_similarity = cos(
text_embed, predicted_image_embeddings).cpu().numpy()
text_embed, predicted_image_embeddings).cpu().numpy()
unrelated_similarity = cos(
text_embed, predicted_unrelated_embeddings).cpu().numpy()
wandb.log(
{"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity)})
wandb.log({"CosineSimilarity(text_embed,predicted_image_embed)": np.mean(
predicted_similarity)})
wandb.log({"CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(
unrelated_similarity)})
return np.mean(predicted_similarity - original_similarity)
text_embed, predicted_unrelated_embeddings).cpu().numpy()
predicted_img_similarity = cos(
test_image_embeddings, predicted_image_embeddings).cpu().numpy()
tracker.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity),
"CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity),
"CosineSimilarity(orig_image_embed,predicted_image_embed)":np.mean(predicted_img_similarity),
"CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(unrelated_similarity),
"Cosine similarity difference":np.mean(predicted_similarity - original_similarity)})
@click.command()
@click.option("--wandb-entity", default="laion")
@click.option("--wandb-project", default="diffusion-prior")
@click.option("--wandb-dataset", default="LAION-5B")
@click.option("--wandb-arch", default="DiffusionPrior")
@click.option("--image-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
@click.option("--text-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
@click.option("--meta-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/laion2B-en-metadata/")
@click.option("--learning-rate", default=1.1e-4)
@click.option("--weight-decay", default=6.02e-2)
@click.option("--dropout", default=5e-2)
@click.option("--max-grad-norm", default=0.5)
@click.option("--num-data-points", default=250e6)
@click.option("--batch-size", default=320)
@click.option("--num-epochs", default=5)
@click.option("--image-embed-dim", default=768)
@click.option("--train-percent", default=0.9)
@click.option("--val-percent", default=1e-7)
@click.option("--test-percent", default=0.0999999)
@click.option("--dpn-depth", default=12)
@click.option("--dpn-dim-head", default=64)
@click.option("--dpn-heads", default=12)
@click.option("--dp-condition-on-text-encodings", default=True)
@click.option("--dp-timesteps", default=1000)
@click.option("--dp-normformer", default=True)
@click.option("--dp-cond-drop-prob", default=0.1)
@click.option("--dp-loss-type", default="l2")
@click.option("--clip", default="ViT-L/14")
@click.option("--amp", default=False)
@click.option("--save-interval", default=120)
@click.option("--save-path", default="./diffusion_prior_checkpoints")
@click.option("--pretrained-model-path", default=None)
@click.option("--gpu-device", default=0)
def train(
wandb_entity,
wandb_project,
wandb_dataset,
wandb_arch,
image_embed_url,
text_embed_url,
meta_url,
learning_rate,
weight_decay,
dropout,
max_grad_norm,
num_data_points,
batch_size,
num_epochs,
image_embed_dim,
train_percent,
val_percent,
test_percent,
dpn_depth,
dpn_dim_head,
dpn_heads,
dp_condition_on_text_encodings,
dp_timesteps,
dp_normformer,
dp_cond_drop_prob,
dp_loss_type,
clip,
amp,
save_interval,
save_path,
pretrained_model_path,
gpu_device
):
config = {
"learning_rate": learning_rate,
"architecture": wandb_arch,
"dataset": wandb_dataset,
"weight_decay": weight_decay,
"max_gradient_clipping_norm": max_grad_norm,
"batch_size": batch_size,
"epochs": num_epochs,
"diffusion_prior_network": {
"depth": dpn_depth,
"dim_head": dpn_dim_head,
"heads": dpn_heads,
"normformer": dp_normformer
},
"diffusion_prior": {
"condition_on_text_encodings": dp_condition_on_text_encodings,
"timesteps": dp_timesteps,
"cond_drop_prob": dp_cond_drop_prob,
"loss_type": dp_loss_type,
"clip": clip
}
}
def train(image_embed_dim,
image_embed_url,
text_embed_url,
batch_size,
train_percent,
val_percent,
test_percent,
num_epochs,
dp_loss_type,
clip,
dp_condition_on_text_encodings,
dp_timesteps,
dp_normformer,
dp_cond_drop_prob,
dpn_depth,
dpn_dim_head,
dpn_heads,
save_interval,
save_path,
device,
learning_rate=0.001,
max_grad_norm=0.5,
weight_decay=0.01,
amp=False):
# Check if DPRIOR_PATH exists(saved model path)
DPRIOR_PATH = pretrained_model_path
RESUME = exists(DPRIOR_PATH)
if not RESUME:
tracker.init(
entity = wandb_entity,
project = wandb_project,
config = config
)
# Obtain the utilized device.
has_cuda = torch.cuda.is_available()
if has_cuda:
device = torch.device(f"cuda:{gpu_device}")
torch.cuda.set_device(device)
# Training loop
# diffusion prior network
# DiffusionPriorNetwork
prior_network = DiffusionPriorNetwork(
dim = image_embed_dim,
depth = dpn_depth,
dim_head = dpn_dim_head,
heads = dpn_heads,
normformer = dp_normformer).to(device)
dim = image_embed_dim,
depth = dpn_depth,
dim_head = dpn_dim_head,
heads = dpn_heads,
attn_dropout = dropout,
ff_dropout = dropout,
normformer = dp_normformer
)
# DiffusionPrior with text embeddings and image embeddings pre-computed
diffusion_prior = DiffusionPrior(
net = prior_network,
clip = clip,
image_embed_dim = image_embed_dim,
timesteps = dp_timesteps,
cond_drop_prob = dp_cond_drop_prob,
loss_type = dp_loss_type,
condition_on_text_encodings = dp_condition_on_text_encodings).to(device)
# Load clip model if text-conditioning
if dp_condition_on_text_encodings:
clip_adapter = OpenAIClipAdapter(clip)
else:
clip_adapter = None
# diffusion prior with text embeddings and image embeddings pre-computed
# Get image and text embeddings from the servers
print("==============Downloading embeddings - image and text====================")
image_reader = EmbeddingReader(embeddings_folder=image_embed_url, file_format="npy")
text_reader = EmbeddingReader(embeddings_folder=text_embed_url, file_format="npy")
num_data_points = text_reader.count
diffusion_prior = DiffusionPrior(
net = prior_network,
clip = clip_adapter,
image_embed_dim = image_embed_dim,
timesteps = dp_timesteps,
cond_drop_prob = dp_cond_drop_prob,
loss_type = dp_loss_type,
condition_on_text_encodings = dp_condition_on_text_encodings
)
# Load pre-trained model from DPRIOR_PATH
if RESUME:
diffusion_prior, loaded_obj = load_diffusion_model(DPRIOR_PATH, device)
tracker.init(entity = wandb_entity, project = wandb_project, config = config)
# diffusion prior trainer
trainer = DiffusionPriorTrainer(
diffusion_prior = diffusion_prior,
lr = learning_rate,
wd = weight_decay,
max_grad_norm = max_grad_norm,
amp = amp,
).to(device)
# load optimizer and scaler
if RESUME:
trainer.optimizer.load_state_dict(loaded_obj['optimizer'])
trainer.scaler.load_state_dict(loaded_obj['scaler'])
# Create save_path if it doesn't exist
if not os.path.exists(save_path):
os.makedirs(save_path)
Path(save_path).mkdir(exist_ok = True, parents = True)
# Utilize wrapper to abstract away loader logic
print_ribbon("Downloading Embeddings")
loader_args = dict(text_conditioned=dp_condition_on_text_encodings, batch_size=batch_size, num_data_points=num_data_points,
train_split=train_percent, eval_split=val_percent, device=device, img_url=image_embed_url)
if dp_condition_on_text_encodings:
loader_args = dict(**loader_args, meta_url=meta_url)
else:
loader_args = dict(**loader_args, txt_url=text_embed_url)
train_loader, eval_loader, test_loader = make_splits(**loader_args)
### Training code ###
scaler = GradScaler(enabled=amp)
optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
step = 1
timer = Timer()
epochs = num_epochs
step = 0
t = time.time()
train_set_size = int(train_percent*num_data_points)
val_set_size = int(val_percent*num_data_points)
for _ in range(epochs):
diffusion_prior.train()
for emb_images,emb_text in zip(image_reader(batch_size=batch_size, start=0, end=train_set_size),
text_reader(batch_size=batch_size, start=0, end=train_set_size)):
emb_images_tensor = torch.tensor(emb_images[0]).to(device)
emb_text_tensor = torch.tensor(emb_text[0]).to(device)
for image, text in tqdm(train_loader):
diffusion_prior.train()
input_args = dict(image_embed=image)
if dp_condition_on_text_encodings:
input_args = dict(**input_args, text = text)
else:
input_args = dict(**input_args, text_embed=text)
with autocast(enabled=amp):
loss = diffusion_prior(text_embed = emb_text_tensor,image_embed = emb_images_tensor)
scaler.scale(loss).backward()
loss = trainer(**input_args)
# Samples per second
step+=1
samples_per_sec = batch_size*step/(time.time()-t)
# Save checkpoint every save_interval minutes
if(int(time.time()-t) >= 60*save_interval):
t = time.time()
save_model(
samples_per_sec = batch_size * step / timer.elapsed()
# Save checkpoint every save_interval minutes
if(int(timer.elapsed()) >= 60 * save_interval):
timer.reset()
save_diffusion_model(
save_path,
dict(model=diffusion_prior.state_dict(), optimizer=optimizer.state_dict(), scaler=scaler.state_dict()))
diffusion_prior,
trainer.optimizer,
trainer.scaler,
config,
image_embed_dim)
# Log to wandb
wandb.log({"Training loss": loss.item(),
tracker.log({"Training loss": loss,
"Steps": step,
"Samples per second": samples_per_sec})
# Log cosineSim(text_embed,predicted_image_embed) - cosineSim(text_embed,image_embed)
# Use NUM_TEST_EMBEDDINGS samples from the test set each time
# Get embeddings from the most recently saved model
if(step % REPORT_METRICS_EVERY) == 0:
diff_cosine_sim = report_cosine_sims(diffusion_prior,
image_reader,
text_reader,
train_set_size,
val_set_size,
NUM_TEST_EMBEDDINGS,
device)
wandb.log({"Cosine similarity difference": diff_cosine_sim})
report_cosine_sims(diffusion_prior, eval_loader, dp_condition_on_text_encodings)
### Evaluate model(validation run) ###
eval_model(diffusion_prior, eval_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Validation")
scaler.unscale_(optimizer)
nn.utils.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
### Evaluate model(validation run) ###
start = train_set_size
end=start+val_set_size
eval_model(diffusion_prior,device,image_reader,text_reader,start,end,batch_size,dp_loss_type,phase="Validation")
step += 1
trainer.update()
### Test run ###
test_set_size = int(test_percent*train_set_size)
start=train_set_size+val_set_size
end=num_data_points
eval_model(diffusion_prior,device,image_reader,text_reader,start,end,batch_size,dp_loss_type,phase="Test")
eval_model(diffusion_prior, test_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Test")
def main():
parser = argparse.ArgumentParser()
# Logging
parser.add_argument("--wandb-entity", type=str, default="laion")
parser.add_argument("--wandb-project", type=str, default="diffusion-prior")
parser.add_argument("--wandb-name", type=str, default="laion-dprior")
parser.add_argument("--wandb-dataset", type=str, default="LAION-5B")
parser.add_argument("--wandb-arch", type=str, default="DiffusionPrior")
# URLs for embeddings
parser.add_argument("--image-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
parser.add_argument("--text-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
# Hyperparameters
parser.add_argument("--learning-rate", type=float, default=1.1e-4)
parser.add_argument("--weight-decay", type=float, default=6.02e-2)
parser.add_argument("--max-grad-norm", type=float, default=0.5)
parser.add_argument("--batch-size", type=int, default=10**4)
parser.add_argument("--num-epochs", type=int, default=5)
# Image embed dimension
parser.add_argument("--image-embed-dim", type=int, default=768)
# Train-test split
parser.add_argument("--train-percent", type=float, default=0.7)
parser.add_argument("--val-percent", type=float, default=0.2)
parser.add_argument("--test-percent", type=float, default=0.1)
# LAION training(pre-computed embeddings)
# DiffusionPriorNetwork(dpn) parameters
parser.add_argument("--dpn-depth", type=int, default=6)
parser.add_argument("--dpn-dim-head", type=int, default=64)
parser.add_argument("--dpn-heads", type=int, default=8)
# DiffusionPrior(dp) parameters
parser.add_argument("--dp-condition-on-text-encodings", type=bool, default=False)
parser.add_argument("--dp-timesteps", type=int, default=100)
parser.add_argument("--dp-l2norm-output", type=bool, default=False)
parser.add_argument("--dp-normformer", type=bool, default=False)
parser.add_argument("--dp-cond-drop-prob", type=float, default=0.1)
parser.add_argument("--dp-loss-type", type=str, default="l2")
parser.add_argument("--clip", type=str, default=None)
parser.add_argument("--amp", type=bool, default=False)
# Model checkpointing interval(minutes)
parser.add_argument("--save-interval", type=int, default=30)
parser.add_argument("--save-path", type=str, default="./diffusion_prior_checkpoints")
args = parser.parse_args()
print("Setting up wandb logging... Please wait...")
wandb.init(
entity=args.wandb_entity,
project=args.wandb_project,
config={
"learning_rate": args.learning_rate,
"architecture": args.wandb_arch,
"dataset": args.wandb_dataset,
"epochs": args.num_epochs,
})
print("wandb logging setup done!")
# Obtain the utilized device.
has_cuda = torch.cuda.is_available()
if has_cuda:
device = torch.device("cuda:0")
torch.cuda.set_device(device)
# Training loop
train(args.image_embed_dim,
args.image_embed_url,
args.text_embed_url,
args.batch_size,
args.train_percent,
args.val_percent,
args.test_percent,
args.num_epochs,
args.dp_loss_type,
args.clip,
args.dp_condition_on_text_encodings,
args.dp_timesteps,
args.dp_normformer,
args.dp_cond_drop_prob,
args.dpn_depth,
args.dpn_dim_head,
args.dpn_heads,
args.save_interval,
args.save_path,
device,
args.learning_rate,
args.max_grad_norm,
args.weight_decay,
args.amp)
if __name__ == "__main__":
main()
train()