Compare commits

...

158 Commits

Author SHA1 Message Date
Phil Wang
68e7d2f241 make sure gradient accumulation feature works even if all arguments passed in are keyword arguments 2022-05-15 11:16:16 -07:00
Phil Wang
74f222596a remove todo 2022-05-15 11:01:35 -07:00
Phil Wang
aa6772dcff make sure optimizer and scaler is reloaded on resume for training diffusion prior script, move argparse to click 2022-05-15 10:48:10 -07:00
Phil Wang
71d0c4edae cleanup to use diffusion prior trainer 2022-05-15 10:16:05 -07:00
Phil Wang
f7eee09d8b 0.2.30 2022-05-15 09:56:59 -07:00
Phil Wang
89de5af63e experiment tracker agnostic 2022-05-15 09:56:40 -07:00
Phil Wang
4ec6d0ba81 backwards pass is not recommended under the autocast context, per pytorch docs 2022-05-14 18:26:19 -07:00
Phil Wang
aee92dba4a simplify more 2022-05-14 17:16:46 -07:00
Phil Wang
b0cd5f24b6 take care of gradient accumulation automatically for researchers, by passing in a max_batch_size on the decoder or diffusion prior trainer forward 2022-05-14 17:04:09 -07:00
Phil Wang
b494ed81d4 take care of backwards within trainer classes for diffusion prior and decoder, readying to take care of gradient accumulation as well (plus, unsure if loss should be backwards within autocast block) 2022-05-14 15:49:24 -07:00
Phil Wang
ff3474f05c normalize conditioning tokens outside of cross attention blocks 2022-05-14 14:23:52 -07:00
Phil Wang
d5293f19f1 lineup with paper 2022-05-14 13:57:00 -07:00
Phil Wang
e697183849 be able to customize adam eps 2022-05-14 13:55:04 -07:00
Phil Wang
591d37e266 lower default initial learning rate to what Jonathan Ho had in his original repo 2022-05-14 13:22:43 -07:00
Phil Wang
d1f02e8f49 always use sandwich norm for attention layer 2022-05-14 12:13:41 -07:00
Phil Wang
9faab59b23 use post-attn-branch layernorm in attempt to stabilize cross attention conditioning in decoder 2022-05-14 11:58:09 -07:00
Phil Wang
5d27029e98 make sure lowres conditioning image is properly normalized to -1 to 1 for cascading ddpm 2022-05-14 01:23:54 -07:00
Phil Wang
3115fa17b3 fix everything around normalizing images to -1 to 1 for ddpm training automatically 2022-05-14 01:17:11 -07:00
Phil Wang
124d8577c8 move the inverse normalization function called before image embeddings are derived from clip to within the diffusion prior and decoder classes 2022-05-14 00:37:52 -07:00
Phil Wang
2db0c9794c comments 2022-05-12 14:25:20 -07:00
Phil Wang
2277b47ffd make sure learned variance can work for any number of unets in the decoder, defaults to first unet, as suggested was used in the paper 2022-05-12 14:18:15 -07:00
Phil Wang
28b58e568c cleanup in preparation of option for learned variance 2022-05-12 12:04:52 -07:00
Phil Wang
924455d97d align the ema model device back after sampling from the cascading ddpm in the decoder 2022-05-11 19:56:54 -07:00
Phil Wang
6021945fc8 default to l2 loss 2022-05-11 19:24:51 -07:00
Light-V
6f76652d11 fix typo in README.md (#85)
The default config for clip from openai should be ViT-B/32
2022-05-11 13:38:16 -07:00
Phil Wang
3dda2570ed fix amp issue for https://github.com/lucidrains/DALLE2-pytorch/issues/82 2022-05-11 08:21:39 -07:00
Phil Wang
2f3c02dba8 numerical accuracy for noise schedule parameters 2022-05-10 15:28:46 -07:00
Phil Wang
908088cfea wrap up cross embed layer feature 2022-05-10 12:19:34 -07:00
Phil Wang
8dc8a3de0d product management 2022-05-10 11:51:38 -07:00
Phil Wang
35f89556ba bring in the cross embed layer from Crossformer paper for initial convolution in unet 2022-05-10 11:50:38 -07:00
Phil Wang
2b55f753b9 fix new issue with github actions and auto pypi package uploading 2022-05-10 10:51:15 -07:00
Phil Wang
fc8fce38fb make sure cascading DDPM can be trained unconditionally, to ready for CLI one command training for the public 2022-05-10 10:48:10 -07:00
Phil Wang
a1bfb03ba4 project management 2022-05-10 10:13:51 -07:00
Phil Wang
b1e7b5f6bb make sure resnet groups in unet is finely customizable 2022-05-10 10:12:50 -07:00
z
10b905b445 smol typo (#81) 2022-05-10 09:52:50 -07:00
Phil Wang
9b322ea634 patch 2022-05-09 19:46:19 -07:00
Phil Wang
ba64ea45cc 0.2.3 2022-05-09 16:50:31 -07:00
Phil Wang
64f7be1926 some cleanup 2022-05-09 16:50:21 -07:00
Phil Wang
db805e73e1 fix a bug with numerical stability in attention, sorry! 🐛 2022-05-09 16:23:37 -07:00
z
cb07b37970 Ensure Eval Mode In Metric Functions (#79)
* add eval/train toggles

* train/eval flags

* shift train toggle

Co-authored-by: nousr <z@localhost.com>
2022-05-09 16:05:40 -07:00
Phil Wang
a774bfefe2 add attention and feedforward dropouts to train_diffusion_prior script 2022-05-09 13:57:15 -07:00
Phil Wang
2ae57f0cf5 cleanup 2022-05-09 13:51:26 -07:00
Phil Wang
e46eaec817 deal the diffusion prior problem yet another blow 2022-05-09 11:08:52 -07:00
Kumar R
8647cb5e76 Val loss changes, with quite a few other changes. This is in place of the earlier PR(https://github.com/lucidrains/DALLE2-pytorch/pull/67) (#77)
* Val_loss changes - no rebased with lucidrains' master.

* Val Loss changes - now rebased with lucidrains' master

* train_diffusion_prior.py updates

* dalle2_pytorch.py updates

* __init__.py changes

* Update train_diffusion_prior.py

* Update dalle2_pytorch.py

* Update train_diffusion_prior.py

* Update train_diffusion_prior.py

* Update dalle2_pytorch.py

* Update train_diffusion_prior.py

* Update train_diffusion_prior.py

* Update train_diffusion_prior.py

* Update train_diffusion_prior.py

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md
2022-05-09 08:53:29 -07:00
Phil Wang
53c189e46a give more surface area for attention in diffusion prior 2022-05-09 08:08:11 -07:00
Phil Wang
dde51fd362 revert restriction for classifier free guidance for diffusion prior, given @crowsonkb advice 2022-05-07 20:55:41 -07:00
Nasir Khalid
2eac7996fa Additional image_embed metric (#75)
Added metric to track image_embed vs predicted_image_embed
2022-05-07 14:32:33 -07:00
Phil Wang
4010aec033 turn off classifier free guidance if predicting x_start for diffusion prior 2022-05-07 09:38:17 -07:00
Phil Wang
c87b84a259 todo 2022-05-07 09:21:08 -07:00
Phil Wang
8b05468653 todo 2022-05-07 08:33:45 -07:00
Phil Wang
830afd3c15 sinusoidal embed time embeddings for diffusion prior as well, for continuous version 2022-05-07 08:32:43 -07:00
Phil Wang
8f93729d19 when in doubt, make it a hyperparameter 2022-05-07 07:52:17 -07:00
z
cd5f2c1de4 simulate unrelated captions as a training metric (#66)
* add unrelated embedding metric

* change to torch.roll

Co-authored-by: nousr <z@localhost.com>
Co-authored-by: nousr <>
2022-05-07 05:34:59 -07:00
Phil Wang
85ed77d512 fix a potentially huge bug thanks to @CiaoHe https://github.com/lucidrains/DALLE2-pytorch/issues/71 2022-05-07 05:05:54 -07:00
Piero Rolando
fd53fa17db Fix a typo in README (#70)
Change "pyhon" for "python" (correct)
2022-05-06 16:53:36 -07:00
Phil Wang
3676ef4d49 make sure vqgan-vae trainer supports mixed precision 2022-05-06 10:44:16 -07:00
Phil Wang
28e944f328 make sure openai clip adapter outputs l2normed embeddings 2022-05-06 10:12:03 -07:00
Phil Wang
14e63a3f67 also offer l2norm clamping in diffusion prior during training, if one were using predict x0 objective 2022-05-06 10:05:14 -07:00
Phil Wang
09e9eaa5a6 project management 2022-05-06 09:00:22 -07:00
Phil Wang
e6d752cf4a reprioritize 2022-05-06 08:55:26 -07:00
Phil Wang
ad20a14a4d bring in rotary embeddings for diffusion prior causal transformer (the most powerful relative positional encoding, used in PaLM) - 0.1.0 because of breaking change 2022-05-06 08:45:30 -07:00
Phil Wang
0be1e0d64c support CoCa, which seems to be better than CLIP (has an autoregressive text encoder) https://arxiv.org/abs/2205.01917 2022-05-06 08:27:12 -07:00
Phil Wang
98df1ba51e add diffusion prior trainer, which automatically takes care of the exponential moving average (training and sampling), as well as mixed precision, gradient clipping 2022-05-06 08:11:09 -07:00
Phil Wang
878b555ef7 fix training with clip 2022-05-06 07:37:57 -07:00
Phil Wang
63029f7388 remove l2norm output from train_diffusion_prior.py 2022-05-05 19:07:58 -07:00
Phil Wang
c76a964fd6 allow for CLIP to be optional in Decoder, and allow DecoderTrainer to work off training pre-encoded image embeddings 2022-05-05 08:11:01 -07:00
Phil Wang
79fabc4341 reorg readme 2022-05-05 07:54:12 -07:00
Kumar R
f7ef4bde38 Added some documentation for the diffusion prior in README.md (#62)
* Delete README.md

* Create README.md

* Update README.md

* Update README.md
2022-05-05 07:51:31 -07:00
Phil Wang
93ba019069 product management 2022-05-05 07:39:51 -07:00
Phil Wang
8518684ae9 does not make much sense, as researchers may want to try predicting noise with diffusionprior instead of predicting x0 2022-05-05 07:37:00 -07:00
Phil Wang
1d5dc08810 take @crowsonkb 's suggestion at https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132 2022-05-05 07:28:53 -07:00
Phil Wang
d8d8b6caf1 dataloaders for decoder training, from @Veldrovive 2022-05-05 07:09:45 -07:00
Aidan Dempster
15acc03bd4 Add a dataloader for training the decoder (#57)
* Added dataloader and updated requirements

* Added option to set embedding shard width separately from webdataset shard length.
There must be a better way to do this.

* Changed embedding loader to read using fsspec

* Moved the loader into a more compatible location

* Removed unnecessary package

* Fixed typo (Embeding -> Embedding)

* Simplified example embedding finder code to remove unnecessary get_file_list function

* Added example usage of ImageEmbeddingDataset

* Changed the name of create_dataloader to be more verbose
Added a dataloaders __init__.py
2022-05-05 07:08:45 -07:00
Phil Wang
896f19786d remove convnext blocks, they are illsuited for generative work, validated by early experimental results at https://github.com/lucidrains/video-diffusion-pytorch 2022-05-05 07:07:21 -07:00
Phil Wang
aec5575d09 take a bet on resize right, given Katherine is using it 2022-05-04 19:26:45 -07:00
Phil Wang
9773f10d6c use inference mode whenever possible, cleanup 2022-05-04 15:25:05 -07:00
Phil Wang
a6bf8ddef6 advertise laion 2022-05-04 15:04:05 -07:00
Phil Wang
86e692d24f fix random crop probability 2022-05-04 11:52:24 -07:00
Phil Wang
97b751209f allow for last unet in the cascade to be trained on crops, if it is convolution-only 2022-05-04 11:48:48 -07:00
Phil Wang
74103fd8d6 product management 2022-05-04 11:20:50 -07:00
Phil Wang
1992d25cad project management 2022-05-04 11:18:54 -07:00
Phil Wang
5b619c2fd5 make sure some hyperparameters for unet block is configurable 2022-05-04 11:18:32 -07:00
Phil Wang
9359ad2e91 0.0.95 2022-05-04 10:53:05 -07:00
Phil Wang
9ff228188b offer old resnet blocks, from the original DDPM paper, just in case convnexts are unsuitable for generative work 2022-05-04 10:52:58 -07:00
Kumar R
2d9963d30e Reporting metrics - Cosine similarity. (#55)
* Update train_diffusion_prior.py

* Delete train_diffusion_prior.py

* Cosine similarity logging.

* Update train_diffusion_prior.py

* Report Cosine metrics every N steps.
2022-05-04 08:04:36 -07:00
Phil Wang
58d9b422f3 0.0.94 2022-05-04 07:42:33 -07:00
Ray Bell
44b319cb57 add missing import (#56) 2022-05-04 07:42:20 -07:00
Phil Wang
c30f380689 final reminder 2022-05-03 08:18:53 -07:00
Phil Wang
e4e884bb8b keep all doors open 2022-05-03 08:17:02 -07:00
Phil Wang
803ad9c17d product management again 2022-05-03 08:15:25 -07:00
Phil Wang
a88dd6a9c0 todo 2022-05-03 08:09:02 -07:00
Kumar R
72c16b496e Update train_diffusion_prior.py (#53) 2022-05-02 22:44:57 -07:00
z
81d83dd7f2 defaults align with paper (#52)
Co-authored-by: nousr <>
2022-05-02 13:52:11 -07:00
Phil Wang
fa66f7e1e9 todo 2022-05-02 12:57:15 -07:00
Phil Wang
aa8d135245 allow laion to experiment with normformer in diffusion prior 2022-05-02 11:35:00 -07:00
Phil Wang
70282de23b add ability to turn on normformer settings, given @borisdayma reported good results and some personal anecdata 2022-05-02 11:33:15 -07:00
Phil Wang
83f761847e todo 2022-05-02 10:52:39 -07:00
Phil Wang
11469dc0c6 makes more sense to keep this as True as default, for stability 2022-05-02 10:50:55 -07:00
Romain Beaumont
2d25c89f35 Fix passing of l2norm_output to DiffusionPriorNetwork (#51) 2022-05-02 10:48:16 -07:00
Phil Wang
3fe96c208a add ability to train diffusion prior with l2norm on output image embed 2022-05-02 09:53:20 -07:00
Phil Wang
0fc6c9cdf3 provide option to l2norm the output of the diffusion prior 2022-05-02 09:41:03 -07:00
Phil Wang
7ee0ecc388 mixed precision for training diffusion prior + save optimizer and scaler states 2022-05-02 09:31:04 -07:00
Phil Wang
1924c7cc3d fix issue with mixed precision and gradient clipping 2022-05-02 09:20:19 -07:00
Phil Wang
f7df3caaf3 address not calculating average eval / test loss when training diffusion prior https://github.com/lucidrains/DALLE2-pytorch/issues/49 2022-05-02 08:51:41 -07:00
Phil Wang
fc954ee788 fix calculation of adaptive weight for vit-vqgan, thanks to @CiaoHe 2022-05-02 07:58:14 -07:00
Phil Wang
c1db2753f5 todo 2022-05-01 18:02:30 -07:00
Phil Wang
ad87bfe28f switch to using linear attention for the sparse attention layers within unet, given success in GAN projects 2022-05-01 17:59:03 -07:00
Phil Wang
76c767b1ce update deps, commit to using webdatasets, per @rom1504 consultation 2022-05-01 12:22:15 -07:00
Phil Wang
d991b8c39c just clip the diffusion prior network parameters 2022-05-01 12:01:08 -07:00
Phil Wang
902693e271 todo 2022-05-01 11:57:08 -07:00
Phil Wang
35cd63982d add gradient clipping, make sure weight decay is configurable, make sure learning rate is actually passed into get_optimizer, make sure model is set to training mode at beginning of each epoch 2022-05-01 11:55:38 -07:00
Kumar R
53ce6dfdf6 All changes implemented, current run happening. Link to wandb run in comments. (#43)
* Train DiffusionPrior with pre-computed embeddings

This is in response to https://github.com/lucidrains/DALLE2-pytorch/issues/29 - more metrics will get added.
2022-05-01 11:46:59 -07:00
Phil Wang
ad8d7a368b product management 2022-05-01 11:26:21 -07:00
Phil Wang
b8cf1e5c20 more attention 2022-05-01 11:00:33 -07:00
Phil Wang
94aaa08d97 product management 2022-05-01 09:43:10 -07:00
Phil Wang
8b9bbec7d1 project management 2022-05-01 09:32:57 -07:00
Phil Wang
1bb9fc9829 add convnext backbone for vqgan-vae, still need to fix groupnorms in resnet encdec 2022-05-01 09:32:24 -07:00
Phil Wang
5e421bd5bb let researchers do the hyperparameter search 2022-05-01 08:46:21 -07:00
Phil Wang
67fcab1122 add MLP based time conditioning to all convnexts, in addition to cross attention. also add an initial convolution, given convnext first depthwise conv 2022-05-01 08:41:02 -07:00
Phil Wang
5bfbccda22 port over vqgan vae trainer 2022-05-01 08:09:15 -07:00
Phil Wang
989275ff59 product management 2022-04-30 16:57:56 -07:00
Phil Wang
56408f4a40 project management 2022-04-30 16:57:02 -07:00
Phil Wang
d1a697ac23 allows one to shortcut sampling at a specific unet number, if one were to be training in stages 2022-04-30 16:05:13 -07:00
Phil Wang
ebe01749ed DecoderTrainer sample method uses the exponentially moving averaged 2022-04-30 14:55:34 -07:00
Phil Wang
63195cc2cb allow for division of loss prior to scaling, for gradient accumulation purposes 2022-04-30 12:56:47 -07:00
Phil Wang
a2ef69af66 take care of mixed precision, and make gradient accumulation do-able externally 2022-04-30 12:27:24 -07:00
Phil Wang
5fff22834e be able to finely customize learning parameters for each unet, take care of gradient clipping 2022-04-30 11:56:05 -07:00
Phil Wang
a9421f49ec simplify Decoder training for the public 2022-04-30 11:45:18 -07:00
Phil Wang
77fa34eae9 fix all clipping / clamping issues 2022-04-30 10:08:24 -07:00
Phil Wang
1c1e508369 fix all issues with text encodings conditioning in the decoder, using null padding tokens technique from dalle v1 2022-04-30 09:13:34 -07:00
Phil Wang
f19c99ecb0 fix decoder needing separate conditional dropping probabilities for image embeddings and text encodings, thanks to @xiankgx ! 2022-04-30 08:48:05 -07:00
Phil Wang
721a444686 Merge pull request #37 from ProGamerGov/patch-1
Fix spelling and grammatical errors
2022-04-30 08:19:07 -07:00
ProGamerGov
63450b466d Fix spelling and grammatical errors 2022-04-30 09:18:13 -06:00
Phil Wang
20e7eb5a9b cleanup 2022-04-30 07:22:57 -07:00
Phil Wang
e2f9615afa use @clip-anytorch , thanks to @rom1504 2022-04-30 06:40:54 -07:00
Phil Wang
0d1c07c803 fix a bug with classifier free guidance, thanks to @xiankgx again! 2022-04-30 06:34:57 -07:00
Phil Wang
a389f81138 todo 2022-04-29 15:40:51 -07:00
Phil Wang
0283556608 fix example in readme, since api changed 2022-04-29 13:40:55 -07:00
Phil Wang
5063d192b6 now completely OpenAI CLIP compatible for training
just take care of the logic for AdamW and transformers

used namedtuples for clip adapter embedding outputs
2022-04-29 13:05:01 -07:00
Phil Wang
f4a54e475e add some training fns 2022-04-29 09:44:55 -07:00
Phil Wang
fb662a62f3 fix another bug thanks to @xiankgx 2022-04-29 07:38:32 -07:00
Phil Wang
587c8c9b44 optimize for clarity 2022-04-28 21:59:13 -07:00
Phil Wang
aa900213e7 force first unet in the cascade to be conditioned on image embeds 2022-04-28 20:53:15 -07:00
Phil Wang
cb26187450 vqgan-vae codebook dims should be 256 or smaller 2022-04-28 08:59:03 -07:00
Phil Wang
625ce23f6b 🐛 2022-04-28 07:21:18 -07:00
Phil Wang
dbf4a281f1 make sure another CLIP can actually be passed in, as long as it is wrapped in an adapter extended from BaseClipAdapter 2022-04-27 20:45:27 -07:00
Phil Wang
4ab527e779 some extra asserts for text encoding of diffusion prior and decoder 2022-04-27 20:11:43 -07:00
Phil Wang
d0cdeb3247 add ability for DALL-E2 to return PIL images with return_pil_images = True on forward, for those who have no clue about deep learning 2022-04-27 19:58:06 -07:00
Phil Wang
8c610aad9a only pass text encodings conditioning in diffusion prior if specified on initialization 2022-04-27 19:48:16 -07:00
Phil Wang
6700381a37 prepare for ability to integrate other clips other than x-clip 2022-04-27 19:35:05 -07:00
Phil Wang
20377f889a todo 2022-04-27 17:22:14 -07:00
Phil Wang
6edb1c5dd0 fix issue with ema class 2022-04-27 16:40:02 -07:00
Phil Wang
b093f92182 inform what is possible 2022-04-27 08:25:16 -07:00
Phil Wang
fa3bb6ba5c make sure cpu-only still works 2022-04-27 08:02:10 -07:00
Phil Wang
2705e7c9b0 attention-based upsampling claims unsupported by local experiments, removing 2022-04-27 07:51:04 -07:00
Phil Wang
77141882c8 complete vit-vqgan from https://arxiv.org/abs/2110.04627 2022-04-26 17:20:47 -07:00
Phil Wang
4075d02139 nevermind, it could be working, but only when i stabilize it with the feedforward layer + tanh as proposed in vit-vqgan paper (which will be built into the repository later for the latent diffusion) 2022-04-26 12:43:31 -07:00
Phil Wang
de0296106b be able to turn off warning for use of LazyLinear by passing in text embedding dimension for unet 2022-04-26 11:42:46 -07:00
14 changed files with 2874 additions and 462 deletions

454
README.md
View File

@@ -10,7 +10,7 @@ The main novelty seems to be an extra layer of indirection with the prior networ
This model is SOTA for text-to-image for now.
Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication
Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication with the <a href="https://laion.ai/">LAION</a> community | <a href="https://www.youtube.com/watch?v=AIOE1l1W0Tw">Yannic Interview</a>
There was enough interest for a <a href="https://github.com/lucidrains/dalle2-jax">Jax version</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
@@ -47,7 +47,7 @@ clip = CLIP(
use_all_token_embeds = True, # whether to use fine-grained contrastive learning (FILIP)
decoupled_contrastive_learning = True, # use decoupled contrastive learning (DCL) objective function, removing positive pairs from the denominator of the InfoNCE loss (CLOOB + DCL)
extra_latent_projection = True, # whether to use separate projections for text-to-image vs image-to-text comparisons (CLOOB)
use_visual_ssl = True, # whether to do self supervised learning on iages
use_visual_ssl = True, # whether to do self supervised learning on images
visual_ssl_type = 'simclr', # can be either 'simclr' or 'simsiam', depending on using DeCLIP or SLIP
use_mlm = False, # use masked language learning (MLM) on text (DeCLIP)
text_ssl_loss_weight = 0.05, # weight for text MLM loss
@@ -110,7 +110,8 @@ decoder = Decoder(
unet = unet,
clip = clip,
timesteps = 100,
cond_drop_prob = 0.2
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5
).cuda()
# mock images (get a lot of this)
@@ -229,7 +230,8 @@ decoder = Decoder(
unet = (unet1, unet2), # insert both unets in order of low resolution to highest resolution (you can have as many stages as you want here)
image_sizes = (256, 512), # resolutions, 256 for first unet, 512 for second. these must be unique and in ascending order (matches with the unets passed in)
timesteps = 1000,
cond_drop_prob = 0.2
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5
).cuda()
# mock images (get a lot of this)
@@ -348,7 +350,8 @@ decoder = Decoder(
image_sizes = (128, 256),
clip = clip,
timesteps = 100,
cond_drop_prob = 0.2,
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5,
condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
).cuda()
@@ -430,8 +433,8 @@ images = torch.randn(4, 3, 256, 256).cuda()
# precompute the text and image embeddings
# here using the diffusion prior class, but could be done with CLIP alone
clip_image_embeds = diffusion_prior.get_image_embed(images)
clip_text_embeds = diffusion_prior.get_text_cond(text).get('text_embed')
clip_image_embeds = diffusion_prior.clip.embed_image(images).image_embed
clip_text_embeds = diffusion_prior.clip.embed_text(text).text_embed
# feed text and images into diffusion prior network
@@ -495,14 +498,105 @@ loss.backward()
# now the diffusion prior can generate image embeddings from the text embeddings
```
## OpenAI CLIP
Although there is the possibility they are using an unreleased, more powerful CLIP, you can use one of the released ones, if you do not wish to train your own CLIP from scratch. This will also allow the community to more quickly validate the conclusions of the paper.
To use a pretrained OpenAI CLIP, simply import `OpenAIClipAdapter` and pass it into the `DiffusionPrior` or `Decoder` like so
```python
import torch
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, OpenAIClipAdapter
# openai pretrained clip - defaults to ViT-B/32
clip = OpenAIClipAdapter()
# mock data
text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
# prior networks (with transformer)
prior_network = DiffusionPriorNetwork(
dim = 512,
depth = 6,
dim_head = 64,
heads = 8
).cuda()
diffusion_prior = DiffusionPrior(
net = prior_network,
clip = clip,
timesteps = 100,
cond_drop_prob = 0.2
).cuda()
loss = diffusion_prior(text, images)
loss.backward()
# do above for many steps ...
# decoder (with unet)
unet1 = Unet(
dim = 128,
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults=(1, 2, 4, 8)
).cuda()
unet2 = Unet(
dim = 16,
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults = (1, 2, 4, 8, 16)
).cuda()
decoder = Decoder(
unet = (unet1, unet2),
image_sizes = (128, 256),
clip = clip,
timesteps = 100,
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5,
condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
).cuda()
for unet_number in (1, 2):
loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
loss.backward()
# do above for many steps
dalle2 = DALLE2(
prior = diffusion_prior,
decoder = decoder
)
images = dalle2(
['a butterfly trying to escape a tornado'],
cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)
)
# save your image (in this example, of size 256x256)
```
Now you'll just have to worry about training the Prior and the Decoder!
## Experimental
### DALL-E2 with Latent Diffusion
This repository decides to take the next step and offer DALL-E2 combined with <a href="https://huggingface.co/spaces/multimodalart/latentdiffusion">latent diffusion</a>, from Rombach et al.
This repository decides to take the next step and offer DALL-E v2 combined with <a href="https://huggingface.co/spaces/multimodalart/latentdiffusion">latent diffusion</a>, from Rombach et al.
You can use it as follows. Latent diffusion can be limited to just the first U-Net in the cascade, or to any number you wish.
The repository also comes equipped with all the necessary settings to recreate `ViT-VQGan` from the <a href="https://arxiv.org/abs/2110.04627">Improved VQGans</a> paper. Furthermore, the <a href="https://github.com/lucidrains/vector-quantize-pytorch">vector quantization</a> library also comes equipped to do <a href="https://arxiv.org/abs/2203.01941">residual or multi-headed quantization</a>, which I believe will give an even further boost in performance to the autoencoder.
```python
import torch
from dalle2_pytorch import Unet, Decoder, CLIP, VQGanVAE
@@ -526,7 +620,7 @@ clip = CLIP(
# 3 unets for the decoder (a la cascading DDPM)
# first two unets are doing latent diffusion
# vqgan-vae must be trained before hand
# vqgan-vae must be trained beforehand
vae1 = VQGanVAE(
dim = 32,
@@ -579,7 +673,8 @@ decoder = Decoder(
unet = (unet1, unet2, unet3), # insert unets in order of low resolution to highest resolution (you can have as many stages as you want here)
image_sizes = (256, 512, 1024), # resolutions, 256 for first unet, 512 for second, 1024 for third
timesteps = 100,
cond_drop_prob = 0.2
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5
).cuda()
# mock images (get a lot of this)
@@ -613,7 +708,261 @@ images = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)
## Training wrapper (wip)
Offer training wrappers
### Decoder Training
Training the `Decoder` may be confusing, as one needs to keep track of an optimizer for each of the `Unet`(s) separately. Each `Unet` will also need its own corresponding exponential moving average. The `DecoderTrainer` hopes to make this simple, as shown below
```python
import torch
from dalle2_pytorch import DALLE2, Unet, Decoder, CLIP, DecoderTrainer
clip = CLIP(
dim_text = 512,
dim_image = 512,
dim_latent = 512,
num_text_tokens = 49408,
text_enc_depth = 6,
text_seq_len = 256,
text_heads = 8,
visual_enc_depth = 6,
visual_image_size = 256,
visual_patch_size = 32,
visual_heads = 8
).cuda()
# mock data
text = torch.randint(0, 49408, (32, 256)).cuda()
images = torch.randn(32, 3, 256, 256).cuda()
# decoder (with unet)
unet1 = Unet(
dim = 128,
image_embed_dim = 512,
text_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults=(1, 2, 4, 8)
).cuda()
unet2 = Unet(
dim = 16,
image_embed_dim = 512,
text_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults = (1, 2, 4, 8, 16),
cond_on_text_encodings = True
).cuda()
decoder = Decoder(
unet = (unet1, unet2),
image_sizes = (128, 256),
clip = clip,
timesteps = 1000,
condition_on_text_encodings = True
).cuda()
decoder_trainer = DecoderTrainer(
decoder,
lr = 3e-4,
wd = 1e-2,
ema_beta = 0.99,
ema_update_after_step = 1000,
ema_update_every = 10,
)
for unet_number in (1, 2):
loss = decoder_trainer(
images,
text = text,
unet_number = unet_number, # which unet to train on
max_batch_size = 4 # gradient accumulation - this sets the maximum batch size in which to do forward and backwards pass - for this example 32 / 4 == 8 times
)
decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
# after much training
# you can sample from the exponentially moving averaged unets as so
mock_image_embed = torch.randn(4, 512).cuda()
images = decoder_trainer.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
```
### Diffusion Prior Training
Similarly, one can use the `DiffusionPriorTrainer` to automatically instantiate and keep track of an exponential moving averaged prior.
```python
import torch
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, DiffusionPriorTrainer, Unet, Decoder, CLIP
clip = CLIP(
dim_text = 512,
dim_image = 512,
dim_latent = 512,
num_text_tokens = 49408,
text_enc_depth = 6,
text_seq_len = 256,
text_heads = 8,
visual_enc_depth = 6,
visual_image_size = 256,
visual_patch_size = 32,
visual_heads = 8
).cuda()
# mock data
text = torch.randint(0, 49408, (32, 256)).cuda()
images = torch.randn(32, 3, 256, 256).cuda()
# prior networks (with transformer)
prior_network = DiffusionPriorNetwork(
dim = 512,
depth = 6,
dim_head = 64,
heads = 8
).cuda()
diffusion_prior = DiffusionPrior(
net = prior_network,
clip = clip,
timesteps = 100,
cond_drop_prob = 0.2
).cuda()
diffusion_prior_trainer = DiffusionPriorTrainer(
diffusion_prior,
lr = 3e-4,
wd = 1e-2,
ema_beta = 0.99,
ema_update_after_step = 1000,
ema_update_every = 10,
)
loss = diffusion_prior_trainer(text, images, max_batch_size = 4)
diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior
# after much of the above three lines in a loop
# you can sample from the exponential moving average of the diffusion prior identically to how you do so for DiffusionPrior
image_embeds = diffusion_prior_trainer.sample(text) # (4, 512) - exponential moving averaged image embeddings
```
### Decoder Dataloaders
In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.
#### Decoder: Image Embedding Dataset
When training the decoder (and up samplers if training together) in isolation, you will need to load images and corresponding image embeddings. This dataset can read two similar types of datasets. First, it can read a [webdataset](https://github.com/webdataset/webdataset) that contains `.jpg` and `.npy` files in the `.tar`s that contain the images and associated image embeddings respectively. Alternatively, you can also specify a source for the embeddings outside of the webdataset. In this case, the path to the embeddings should contain `.npy` files with the same shard numbers as the webdataset and there should be a correspondence between the filename of the `.jpg` and the index of the embedding in the `.npy`. So, for example, `0001.tar` from the webdataset with image `00010509.jpg` (the first 4 digits are the shard number and the last 4 are the index) in it should be paralleled by a `img_emb_0001.npy` which contains a NumPy array with the embedding at index 509.
Generating a dataset of this type:
1. Use [img2dataset](https://github.com/rom1504/img2dataset) to generate a webdataset.
2. Use [clip-retrieval](https://github.com/rom1504/clip-retrieval) to convert the images to embeddings.
3. Use [embedding-dataset-reordering](https://github.com/Veldrovive/embedding-dataset-reordering) to reorder the embeddings into the expected format.
Usage:
```python
from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embedding_dataloader
# Create a dataloader directly.
dataloader = create_image_embedding_dataloader(
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses braket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
num_workers=4,
batch_size=32,
shard_width=4, # If a file in the webdataset shard 3 is named 0003039.jpg, we know the shard width is 4 and the last three digits are the index
shuffle_num=200, # Does a shuffle of the data with a buffer size of 200
shuffle_shards=True, # Shuffle the order the shards are read in
resample_shards=False, # Sample shards with replacement. If true, an epoch will be infinite unless stopped manually
)
for img, emb in dataloader:
print(img.shape) # torch.Size([32, 3, 256, 256])
print(emb.shape) # torch.Size([32, 512])
# Train decoder only as shown above
# Or create a dataset without a loader so you can configure it manually
dataset = ImageEmbeddingDataset(
urls="/path/or/url/to/webdataset/{0000..9999}.tar",
embedding_folder_url="path/or/url/to/embeddings/folder",
shard_width=4,
shuffle_shards=True,
resample=False
)
```
## Scripts
### Using the `train_diffusion_prior.py` script
This script allows training the DiffusionPrior on pre-computed text and image embeddings. The working example below elucidates this process.
Please note that the script internally passes text_embed and image_embed to the DiffusionPrior, unlike the example below.
### Usage
```bash
$ python train_diffusion_prior.py
```
The most significant parameters for the script are as follows:
--image-embed-url, default = "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
--text-embed-url, default = "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
--image-embed-dim, default=768 - 768 corresponds to the ViT iL/14 embedding size,change it to what your chosen ViT generates
--learning-rate, default=1.1e-4
--weight-decay, default=6.02e-2
--max-grad-norm, default=0.5
--batch-size, default=10 ** 4
--num-epochs, default=5
--clip, default=None # Signals the prior to use pre-computed embeddings
### Sample wandb run log
Please find a sample wandb run log at : https://wandb.ai/laion/diffusion-prior/runs/1blxu24j
### Loading and saving the Diffusion Prior model
Two methods are provided, load_diffusion_model and save_diffusion_model, the names being self-explanatory.
## from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model
load_diffusion_model(dprior_path, device)
dprior_path : path to saved model(.pth)
device : the cuda device you're running on
save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim)
save_path : path to save at
model : object of Diffusion_Prior
optimizer : optimizer object - see train_diffusion_prior.py for how to create one.
e.g: optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
scaler : a GradScaler object.
e.g: scaler = GradScaler(enabled=amp)
config : config object created in train_diffusion_prior.py - see file for example.
image_embed_dim - the dimension of the image_embedding
e.g: 768
## CLI (wip)
@@ -644,14 +993,34 @@ Once built, images will be saved to the same directory the command is invoked
- [x] for decoder, allow ability to customize objective (predict epsilon vs x0), in case latent diffusion does better with prediction of x0
- [x] use attention-based upsampling https://arxiv.org/abs/2112.11435
- [x] use inheritance just this once for sharing logic between decoder and prior network ddpms
- [ ] abstract interface for CLIP adapter class, so other CLIPs can be brought in
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
- [x] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
- [x] abstract interface for CLIP adapter class, so other CLIPs can be brought in
- [x] take care of mixed precision as well as gradient accumulation within decoder trainer
- [x] just take care of the training for the decoder in a wrapper class, as each unet in the cascade will need its own optimizer
- [x] bring in tools to train vqgan-vae
- [x] add convnext backbone for vqgan-vae (in addition to vit [vit-vqgan] + resnet)
- [x] make sure DDPMs can be run with traditional resnet blocks (but leave convnext as an option for experimentation)
- [x] make sure for the latter unets in the cascade, one can train on crops for learning super resolution (constrain the unet to be only convolutions in that case, or allow conv-like attention with rel pos bias)
- [x] offer setting in diffusion prior to split time and image embeddings into multiple tokens, configurable, for more surface area during attention
- [x] make sure resnet hyperparameters can be configurable across unet depth (groups and expansion factor)
- [x] pull logic for training diffusion prior into a class DiffusionPriorTrainer, for eventual script based + CLI based training
- [x] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images
- [x] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
- [x] cross embed layers for downsampling, as an option
- [x] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
- [ ] train on a toy task, offer in colab
- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
- [ ] bring in tools to train vqgan-vae
- [ ] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
- [ ] decoder needs one day worth of refactor for tech debt
- [ ] allow for unet to be able to condition non-cross attention style as well
## Citations
@@ -682,10 +1051,12 @@ Once built, images will be saved to the same directory the command is invoked
```
```bibtex
@inproceedings{Liu2022ACF,
title = {A ConvNet for the 2020https://arxiv.org/abs/2112.11435s},
author = {Zhuang Liu and Hanzi Mao and Chaozheng Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie},
year = {2022}
@article{shen2019efficient,
author = {Zhuoran Shen and Mingyuan Zhang and Haiyu Zhao and Shuai Yi and Hongsheng Li},
title = {Efficient Attention: Attention with Linear Complexities},
journal = {CoRR},
year = {2018},
url = {http://arxiv.org/abs/1812.01243},
}
```
@@ -697,4 +1068,45 @@ Once built, images will be saved to the same directory the command is invoked
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's <a href="https://arxiv.org/abs/2011.13456">paper</a>
```bibtex
@article{Yu2021VectorquantizedIM,
title = {Vector-quantized Image Modeling with Improved VQGAN},
author = {Jiahui Yu and Xin Li and Jing Yu Koh and Han Zhang and Ruoming Pang and James Qin and Alexander Ku and Yuanzhong Xu and Jason Baldridge and Yonghui Wu},
journal = {ArXiv},
year = {2021},
volume = {abs/2110.04627}
}
```
```bibtex
@article{Shleifer2021NormFormerIT,
title = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
author = {Sam Shleifer and Jason Weston and Myle Ott},
journal = {ArXiv},
year = {2021},
volume = {abs/2110.09456}
}
```
```bibtex
@article{Yu2022CoCaCC,
title = {CoCa: Contrastive Captioners are Image-Text Foundation Models},
author = {Jiahui Yu and Zirui Wang and Vijay Vasudevan and Legg Yeung and Mojtaba Seyedhosseini and Yonghui Wu},
journal = {ArXiv},
year = {2022},
volume = {abs/2205.01917}
}
```
```bibtex
@misc{wang2021crossformer,
title = {CrossFormer: A Versatile Vision Transformer Hinging on Cross-scale Attention},
author = {Wenxiao Wang and Lu Yao and Long Chen and Binbin Lin and Deng Cai and Xiaofei He and Wei Liu},
year = {2021},
eprint = {2108.00154},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>

View File

@@ -1,4 +1,6 @@
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
from dalle2_pytorch.train import DecoderTrainer, DiffusionPriorTrainer
from dalle2_pytorch.vqgan_vae import VQGanVAE
from x_clip import CLIP

View File

@@ -1,125 +0,0 @@
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
class LayerNormChan(nn.Module):
def __init__(
self,
dim,
eps = 1e-5
):
super().__init__()
self.eps = eps
self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))
def forward(self, x):
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.gamma
# attention-based upsampling
# from https://arxiv.org/abs/2112.11435
class QueryAndAttend(nn.Module):
def __init__(
self,
*,
dim,
num_queries = 1,
dim_head = 32,
heads = 8,
window_size = 3
):
super().__init__()
self.scale = dim_head ** -0.5
inner_dim = dim_head * heads
self.heads = heads
self.dim_head = dim_head
self.window_size = window_size
self.num_queries = num_queries
self.rel_pos_bias = nn.Parameter(torch.randn(heads, num_queries, window_size * window_size, 1, 1))
self.queries = nn.Parameter(torch.randn(heads, num_queries, dim_head))
self.to_kv = nn.Conv2d(dim, dim_head * 2, 1, bias = False)
self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)
def forward(self, x):
"""
einstein notation
b - batch
h - heads
l - num queries
d - head dimension
x - height
y - width
j - source sequence for attending to (kernel size squared in this case)
"""
wsz, heads, dim_head, num_queries = self.window_size, self.heads, self.dim_head, self.num_queries
batch, _, height, width = x.shape
is_one_query = self.num_queries == 1
# queries, keys, values
q = self.queries * self.scale
k, v = self.to_kv(x).chunk(2, dim = 1)
# similarities
sim = einsum('h l d, b d x y -> b h l x y', q, k)
sim = rearrange(sim, 'b ... x y -> b (...) x y')
# unfold the similarity scores, with float(-inf) as padding value
mask_value = -torch.finfo(sim.dtype).max
sim = F.pad(sim, ((wsz // 2,) * 4), value = mask_value)
sim = F.unfold(sim, kernel_size = wsz)
sim = rearrange(sim, 'b (h l j) (x y) -> b h l j x y', h = heads, l = num_queries, x = height, y = width)
# rel pos bias
sim = sim + self.rel_pos_bias
# numerically stable attention
sim = sim - sim.amax(dim = -3, keepdim = True).detach()
attn = sim.softmax(dim = -3)
# unfold values
v = F.pad(v, ((wsz // 2,) * 4), value = 0.)
v = F.unfold(v, kernel_size = wsz)
v = rearrange(v, 'b (d j) (x y) -> b d j x y', d = dim_head, x = height, y = width)
# aggregate values
out = einsum('b h l j x y, b d j x y -> b l h d x y', attn, v)
# combine heads
out = rearrange(out, 'b l h d x y -> (b l) (h d) x y')
out = self.to_out(out)
out = rearrange(out, '(b l) d x y -> b l d x y', b = batch)
# return original input if one query
if is_one_query:
out = rearrange(out, 'b 1 ... -> b ...')
return out
class QueryAttnUpsample(nn.Module):
def __init__(self, dim, **kwargs):
super().__init__()
self.norm = LayerNormChan(dim)
self.qna = QueryAndAttend(dim = dim, num_queries = 4, **kwargs)
def forward(self, x):
x = self.norm(x)
out = self.qna(x)
out = rearrange(out, 'b (w1 w2) c h w -> b c (h w1) (w w2)', w1 = 2, w2 = 2)
return out

View File

@@ -1,6 +1,7 @@
import click
import torch
import torchvision.transforms as T
from functools import reduce
from pathlib import Path
from dalle2_pytorch import DALLE2, Decoder, DiffusionPrior

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1 @@
from dalle2_pytorch.dataloaders.decoder_loader import ImageEmbeddingDataset, create_image_embedding_dataloader

View File

@@ -0,0 +1,170 @@
import os
import webdataset as wds
import torch
import numpy as np
import fsspec
def get_shard(filename):
"""
Filenames with shards in them have a consistent structure that we can take advantage of
Standard structure: path/to/file/prefix_string_00001.ext
"""
try:
return filename.split("_")[-1].split(".")[0]
except ValueError:
raise RuntimeError(f"Could not find shard for filename {filename}")
def get_example_file(fs, path, file_format):
"""
Given a file system and a file extension, return the example file
"""
return fs.glob(os.path.join(path, f"*.{file_format}"))[0]
def embedding_inserter(samples, embeddings_url, shard_width, handler=wds.handlers.reraise_exception):
"""Given a datum of {"__key__": str, "__url__": str, ...} adds the cooresponding embedding and yields"""
previous_tar_url = None
current_embeddings = None
# Get a reference to an abstract file system where the embeddings are stored
embeddings_fs, embeddings_path = fsspec.core.url_to_fs(embeddings_url)
example_embedding_file = get_example_file(embeddings_fs, embeddings_path, "npy")
example_embedding_shard = get_shard(example_embedding_file)
emb_shard_width = len(example_embedding_shard)
# Easier to get the basename without the shard once than search through for the correct file every time
embedding_file_basename = '_'.join(example_embedding_file.split("_")[:-1]) + "_"
def load_corresponding_embeds(tar_url):
"""Finds and reads the npy files that contains embeddings for the given webdataset tar"""
shard = int(tar_url.split("/")[-1].split(".")[0])
embedding_url = embedding_file_basename + str(shard).zfill(emb_shard_width) + '.npy'
with embeddings_fs.open(embedding_url) as f:
data = np.load(f)
return torch.from_numpy(data)
for sample in samples:
try:
tar_url = sample["__url__"]
key = sample["__key__"]
if tar_url != previous_tar_url:
# If the tar changed, we need to download new embeddings
# This means if we shuffle before inserting it will load many more files than we expect and be very inefficient.
previous_tar_url = tar_url
current_embeddings = load_corresponding_embeds(tar_url)
embedding_index = int(key[shard_width:])
sample["npy"] = current_embeddings[embedding_index]
yield sample
except Exception as exn: # From wds implementation
if handler(exn):
continue
else:
break
insert_embedding = wds.filters.pipelinefilter(embedding_inserter)
def verify_keys(samples, handler=wds.handlers.reraise_exception):
"""
Requires that both the image and embedding are present in the sample
This is important to do as a user may forget they do not have embeddings in their webdataset and neglect to add them using the embedding_folder_url parameter.
"""
for sample in samples:
try:
assert "jpg" in sample, f"Sample {sample['__key__']} missing image"
assert "npy" in sample, f"Sample {sample['__key__']} missing embedding. Did you set embedding_folder_url?"
yield sample
except Exception as exn: # From wds implementation
if handler(exn):
continue
else:
break
class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
"""
A fluid interface wrapper for DataPipline that returns image embedding pairs
Reads embeddings as npy files from the webdataset if they exist. If embedding_folder_url is set, they will be inserted in from the alternate source.
"""
def __init__(
self,
urls,
embedding_folder_url=None,
shard_width=None,
handler=wds.handlers.reraise_exception,
resample=False,
shuffle_shards=True
):
"""
Modeled directly off of the WebDataset constructor
:param urls: A url pointing to the tar files of the webdataset formatted as /path/to/webdataset/{0000..9999}.tar
:param embedding_folder_url: Required if webdataset does not contain embeddings. A url pointing to the npy files of the embeddings. Should have the same number of shards as the webdataset.
Webdataset image keys should align with the index of the embedding. This means missing image indices must have a corresponding embedding of all zeros.
:param shard_width: The number of digits in the shard number. This is used to align the embedding index with the image index.
For example, if a file in the webdataset shard 3 is named 0003039.jpg, we know the shard with this 4 and the last three digits are the index.
:param handler: A webdataset handler.
:param resample: If true, resample webdataset shards with replacement. You need to set your own epoch size if this is true since it will resample infinitely.
:param shuffle_shards: If true, shuffle the shards before resampling. This cannot be true if resample is true.
"""
super().__init__()
# Add the shardList and randomize or resample if requested
if resample:
assert not shuffle_shards, "Cannot both resample and shuffle"
self.append(wds.ResampledShards(urls))
else:
self.append(wds.SimpleShardList(urls))
if shuffle_shards:
self.append(wds.filters.shuffle(1000))
self.append(wds.split_by_node)
self.append(wds.split_by_worker)
self.append(wds.tarfile_to_samples(handler=handler))
self.append(wds.decode("torchrgb"))
if embedding_folder_url is not None:
assert shard_width is not None, "Reading embeddings separately requires shard length to be given"
self.append(insert_embedding(embeddings_url=embedding_folder_url, shard_width=shard_width, handler=handler))
self.append(verify_keys)
self.append(wds.to_tuple("jpg", "npy"))
def create_image_embedding_dataloader(
tar_url,
num_workers,
batch_size,
embeddings_url=None,
shard_width=None,
shuffle_num = None,
shuffle_shards = True,
resample_shards = False,
handler=wds.handlers.warn_and_continue
):
"""
Convenience function to create an image embedding dataseta and dataloader in one line
:param tar_url: A url pointing to the tar files of the webdataset formatted as /path/to/webdataset/{0000..9999}.tar
:param num_workers: The number of workers to use for the dataloader
:param batch_size: The batch size to use for the dataloader
:param embeddings_url: Required if webdataset does not contain embeddings. A url pointing to the npy files of the embeddings. Should have the same number of shards as the webdataset.
Webdataset image keys should align with the index of the embedding. This means missing image indices must have a corresponding embedding of all zeros.
:param shard_width: The number of digits in the shard number. This is used to align the embedding index with the image index.
For example, if a file in the webdataset shard 3 is named 0003039.jpg, we know the shard width is 4 and the last three digits are the index.
:param shuffle_num: If not None, shuffle the dataset with this size buffer after sampling.
:param shuffle_shards: If true, shuffle the shards before sampling. This cannot be true if resample is true.
:param resample_shards: If true, resample webdataset shards with replacement. You need to set your own epoch size if this is true since it will resample infinitely.
:param handler: A webdataset handler.
"""
ds = ImageEmbeddingDataset(
tar_url,
embeddings_url,
shard_width=shard_width,
shuffle_shards=shuffle_shards,
resample=resample_shards,
handler=handler
)
if shuffle_num is not None and shuffle_num > 0:
ds.shuffle(1000)
return wds.WebLoader(
ds,
num_workers=num_workers,
batch_size=batch_size,
prefetch_factor=2, # This might be good to have high so the next npy file is prefetched
pin_memory=True,
shuffle=False
)

View File

@@ -0,0 +1,30 @@
from torch.optim import AdamW, Adam
def separate_weight_decayable_params(params):
no_wd_params = set([param for param in params if param.ndim < 2])
wd_params = set(params) - no_wd_params
return wd_params, no_wd_params
def get_optimizer(
params,
lr = 2e-5,
wd = 1e-2,
betas = (0.9, 0.999),
eps = 1e-8,
filter_by_requires_grad = False
):
if filter_by_requires_grad:
params = list(filter(lambda t: t.requires_grad, params))
if wd == 0:
return Adam(params, lr = lr, betas = betas, eps = eps)
params = set(params)
wd_params, no_wd_params = separate_weight_decayable_params(params)
param_groups = [
{'params': list(wd_params)},
{'params': list(no_wd_params), 'weight_decay': 0},
]
return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas, eps = eps)

View File

@@ -0,0 +1,49 @@
import os
import torch
from torch import nn
# helper functions
def exists(val):
return val is not None
# base class
class BaseTracker(nn.Module):
def __init__(self):
super().__init__()
def init(self, config, **kwargs):
raise NotImplementedError
def log(self, log, **kwargs):
raise NotImplementedError
# basic stdout class
class ConsoleTracker(BaseTracker):
def init(self, **config):
print(config)
def log(self, log, **kwargs):
print(log)
# basic wandb class
class WandbTracker(BaseTracker):
def __init__(self):
super().__init__()
try:
import wandb
except ImportError as e:
print('`pip install wandb` to use the wandb experiment tracker')
raise e
os.environ["WANDB_SILENT"] = "true"
self.wandb = wandb
def init(self, **config):
self.wandb.init(**config)
def log(self, log, **kwargs):
self.wandb.log(log, **kwargs)

View File

@@ -1,6 +1,143 @@
import time
import copy
from math import ceil
from functools import partial
from collections.abc import Iterable
import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
from dalle2_pytorch.optimizer import get_optimizer
# helper functions
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
def pick_and_pop(keys, d):
values = list(map(lambda key: d.pop(key), keys))
return dict(zip(keys, values))
def group_dict_by_key(cond, d):
return_val = [dict(),dict()]
for key in d.keys():
match = bool(cond(key))
ind = int(not match)
return_val[ind][key] = d[key]
return (*return_val,)
def string_begins_with(prefix, str):
return str.startswith(prefix)
def group_by_key_prefix(prefix, d):
return group_dict_by_key(partial(string_begins_with, prefix), d)
def groupby_prefix_and_trim(prefix, d):
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs
# gradient accumulation functions
def split_iterable(it, split_size):
accum = []
for ind in range(ceil(len(it) / split_size)):
start_index = ind * split_size
accum.append(it[start_index: (start_index + split_size)])
return accum
def split(t, split_size = None):
if not exists(split_size):
return t
if isinstance(t, torch.Tensor):
return t.split(split_size, dim = 0)
if isinstance(t, Iterable):
return split_iterable(t, split_size)
return TypeError
def find_first(cond, arr):
for el in arr:
if cond(el):
return el
return None
def split_args_and_kwargs(*args, split_size = None, **kwargs):
all_args = (*args, *kwargs.values())
len_all_args = len(all_args)
first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args)
assert exists(first_tensor)
batch_size = len(first_tensor)
split_size = default(split_size, batch_size)
chunk_size = ceil(batch_size / split_size)
dict_len = len(kwargs)
dict_keys = kwargs.keys()
split_kwargs_index = len_all_args - dict_len
split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * chunk_size) for arg in all_args]
chunk_sizes = tuple(map(len, split_all_args[0]))
for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:]
chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
chunk_size_frac = chunk_size / batch_size
yield chunk_size_frac, (chunked_args, chunked_kwargs)
# print helpers
def print_ribbon(s, symbol = '=', repeat = 40):
flank = symbol * repeat
return f'{flank} {s} {flank}'
# saving and loading functions
# for diffusion prior
def load_diffusion_model(dprior_path, device):
dprior_path = Path(dprior_path)
assert dprior_path.exists(), 'Dprior model file does not exist'
loaded_obj = torch.load(str(dprior_path), map_location='cpu')
# Get hyperparameters of loaded model
dpn_config = loaded_obj['hparams']['diffusion_prior_network']
dp_config = loaded_obj['hparams']['diffusion_prior']
image_embed_dim = loaded_obj['image_embed_dim']['image_embed_dim']
# Create DiffusionPriorNetwork and DiffusionPrior with loaded hyperparameters
# DiffusionPriorNetwork
prior_network = DiffusionPriorNetwork( dim = image_embed_dim, **dpn_config).to(device)
# DiffusionPrior with text embeddings and image embeddings pre-computed
diffusion_prior = DiffusionPrior(net = prior_network, **dp_config, image_embed_dim = image_embed_dim).to(device)
# Load state dict from saved model
diffusion_prior.load_state_dict(loaded_obj['model'])
return diffusion_prior, loaded_obj
def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim):
# Saving State Dict
print_ribbon('Saving checkpoint')
state_dict = dict(model=model.state_dict(),
optimizer=optimizer.state_dict(),
scaler=scaler.state_dict(),
hparams = config,
image_embed_dim = {"image_embed_dim":image_embed_dim})
torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
# exponential moving average wrapper
@@ -8,25 +145,29 @@ class EMA(nn.Module):
def __init__(
self,
model,
beta = 0.99,
ema_update_after_step = 1000,
ema_update_every = 10,
beta = 0.9999,
update_after_step = 1000,
update_every = 10,
):
super().__init__()
self.beta = beta
self.online_model = model
self.ema_model = copy.deepcopy(model)
self.ema_update_after_step = ema_update_after_step # only start EMA after this step number, starting at 0
self.ema_update_every = ema_update_every
self.update_after_step = update_after_step # only start EMA after this step number, starting at 0
self.update_every = update_every
self.register_buffer('initted', torch.Tensor([False]))
self.register_buffer('step', torch.tensor([0.]))
def restore_ema_model_device(self):
device = self.initted.device
self.ema_model.to(device)
def update(self):
self.step += 1
if self.step <= self.ema_update_after_step or (self.step % self.ema_update_every) != 0:
if self.step <= self.update_after_step or (self.step % self.update_every) != 0:
return
if not self.initted:
@@ -35,7 +176,7 @@ class EMA(nn.Module):
self.update_moving_average(self.ema_model, self.online_model)
def update_moving_average(ma_model, current_model):
def update_moving_average(self, ma_model, current_model):
def calculate_ema(beta, old, new):
if not exists(old):
return new
@@ -51,3 +192,220 @@ class EMA(nn.Module):
def __call__(self, *args, **kwargs):
return self.ema_model(*args, **kwargs)
# diffusion prior trainer
class DiffusionPriorTrainer(nn.Module):
def __init__(
self,
diffusion_prior,
use_ema = True,
lr = 3e-4,
wd = 1e-2,
eps = 1e-6,
max_grad_norm = None,
amp = False,
**kwargs
):
super().__init__()
assert isinstance(diffusion_prior, DiffusionPrior)
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
self.diffusion_prior = diffusion_prior
# exponential moving average
self.use_ema = use_ema
if self.use_ema:
self.ema_diffusion_prior = EMA(diffusion_prior, **ema_kwargs)
# optimizer and mixed precision stuff
self.amp = amp
self.scaler = GradScaler(enabled = amp)
self.optimizer = get_optimizer(
diffusion_prior.parameters(),
lr = lr,
wd = wd,
eps = eps,
**kwargs
)
# gradient clipping if needed
self.max_grad_norm = max_grad_norm
self.register_buffer('step', torch.tensor([0.]))
def update(self):
if exists(self.max_grad_norm):
self.scaler.unscale_(self.optimizer)
nn.utils.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
self.scaler.step(self.optimizer)
self.scaler.update()
self.optimizer.zero_grad()
if self.use_ema:
self.ema_diffusion_prior.update()
self.step += 1
@torch.inference_mode()
def p_sample_loop(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
@torch.inference_mode()
def sample(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
@torch.inference_mode()
def sample_batch_size(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
def forward(
self,
*args,
max_batch_size = None,
**kwargs
):
total_loss = 0.
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
with autocast(enabled = self.amp):
loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
loss = loss * chunk_size_frac
total_loss += loss.item()
self.scaler.scale(loss).backward()
return total_loss
# decoder trainer
class DecoderTrainer(nn.Module):
def __init__(
self,
decoder,
use_ema = True,
lr = 2e-5,
wd = 1e-2,
eps = 1e-8,
max_grad_norm = None,
amp = False,
**kwargs
):
super().__init__()
assert isinstance(decoder, Decoder)
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
self.decoder = decoder
self.num_unets = len(self.decoder.unets)
self.use_ema = use_ema
if use_ema:
has_lazy_linear = any([type(module) == nn.LazyLinear for module in decoder.modules()])
assert not has_lazy_linear, 'you must set the text_embed_dim on your u-nets if you plan on doing automatic exponential moving average'
self.ema_unets = nn.ModuleList([])
self.amp = amp
# be able to finely customize learning rate, weight decay
# per unet
lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))
for ind, (unet, unet_lr, unet_wd, unet_eps) in enumerate(zip(self.decoder.unets, lr, wd, eps)):
optimizer = get_optimizer(
unet.parameters(),
lr = unet_lr,
wd = unet_wd,
eps = unet_eps,
**kwargs
)
setattr(self, f'optim{ind}', optimizer) # cannot use pytorch ModuleList for some reason with optimizers
if self.use_ema:
self.ema_unets.append(EMA(unet, **ema_kwargs))
scaler = GradScaler(enabled = amp)
setattr(self, f'scaler{ind}', scaler)
# gradient clipping if needed
self.max_grad_norm = max_grad_norm
self.register_buffer('step', torch.tensor([0.]))
@property
def unets(self):
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
def scale(self, loss, *, unet_number):
assert 1 <= unet_number <= self.num_unets
index = unet_number - 1
scaler = getattr(self, f'scaler{index}')
return scaler.scale(loss)
def update(self, unet_number):
assert 1 <= unet_number <= self.num_unets
index = unet_number - 1
unet = self.decoder.unets[index]
optimizer = getattr(self, f'optim{index}')
scaler = getattr(self, f'scaler{index}')
if exists(self.max_grad_norm):
scaler.unscale_(optimizer)
nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
if self.use_ema:
ema_unet = self.ema_unets[index]
ema_unet.update()
self.step += 1
@torch.no_grad()
def sample(self, *args, **kwargs):
if self.use_ema:
trainable_unets = self.decoder.unets
self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
output = self.decoder.sample(*args, **kwargs)
if self.use_ema:
self.decoder.unets = trainable_unets # restore original training unets
# cast the ema_model unets back to original device
for ema in self.ema_unets:
ema.restore_ema_model_device()
return output
def forward(
self,
*args,
unet_number,
max_batch_size = None,
**kwargs
):
total_loss = 0.
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
with autocast(enabled = self.amp):
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
loss = loss * chunk_size_frac
total_loss += loss.item()
self.scale(loss, unet_number = unet_number).backward()
return total_loss

View File

@@ -0,0 +1,277 @@
from math import sqrt
import copy
from random import choice
from pathlib import Path
from shutil import rmtree
from PIL import Image
import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as T
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid, save_image
from einops import rearrange
from dalle2_pytorch.train import EMA
from dalle2_pytorch.vqgan_vae import VQGanVAE
from dalle2_pytorch.optimizer import get_optimizer
# helpers
def exists(val):
return val is not None
def noop(*args, **kwargs):
pass
def cycle(dl):
while True:
for data in dl:
yield data
def cast_tuple(t):
return t if isinstance(t, (tuple, list)) else (t,)
def yes_or_no(question):
answer = input(f'{question} (y/n) ')
return answer.lower() in ('yes', 'y')
def accum_log(log, new_logs):
for key, new_value in new_logs.items():
old_value = log.get(key, 0.)
log[key] = old_value + new_value
return log
# classes
class ImageDataset(Dataset):
def __init__(
self,
folder,
image_size,
exts = ['jpg', 'jpeg', 'png']
):
super().__init__()
self.folder = folder
self.image_size = image_size
self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]
print(f'{len(self.paths)} training samples found at {folder}')
self.transform = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize(image_size),
T.RandomHorizontalFlip(),
T.CenterCrop(image_size),
T.ToTensor()
])
def __len__(self):
return len(self.paths)
def __getitem__(self, index):
path = self.paths[index]
img = Image.open(path)
return self.transform(img)
# main trainer class
class VQGanVAETrainer(nn.Module):
def __init__(
self,
vae,
*,
num_train_steps,
lr,
batch_size,
folder,
grad_accum_every,
wd = 0.,
save_results_every = 100,
save_model_every = 1000,
results_folder = './results',
valid_frac = 0.05,
random_split_seed = 42,
ema_beta = 0.995,
ema_update_after_step = 2000,
ema_update_every = 10,
apply_grad_penalty_every = 4,
amp = False
):
super().__init__()
assert isinstance(vae, VQGanVAE), 'vae must be instance of VQGanVAE'
image_size = vae.image_size
self.vae = vae
self.ema_vae = EMA(vae, update_after_step = ema_update_after_step, update_every = ema_update_every)
self.register_buffer('steps', torch.Tensor([0]))
self.num_train_steps = num_train_steps
self.batch_size = batch_size
self.grad_accum_every = grad_accum_every
all_parameters = set(vae.parameters())
discr_parameters = set(vae.discr.parameters())
vae_parameters = all_parameters - discr_parameters
self.optim = get_optimizer(vae_parameters, lr = lr, wd = wd)
self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd)
self.amp = amp
self.scaler = GradScaler(enabled = amp)
self.discr_scaler = GradScaler(enabled = amp)
# create dataset
self.ds = ImageDataset(folder, image_size = image_size)
# split for validation
if valid_frac > 0:
train_size = int((1 - valid_frac) * len(self.ds))
valid_size = len(self.ds) - train_size
self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
print(f'training with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples')
else:
self.valid_ds = self.ds
print(f'training with shared training and valid dataset of {len(self.ds)} samples')
# dataloader
self.dl = cycle(DataLoader(
self.ds,
batch_size = batch_size,
shuffle = True
))
self.valid_dl = cycle(DataLoader(
self.valid_ds,
batch_size = batch_size,
shuffle = True
))
self.save_model_every = save_model_every
self.save_results_every = save_results_every
self.apply_grad_penalty_every = apply_grad_penalty_every
self.results_folder = Path(results_folder)
if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'):
rmtree(str(self.results_folder))
self.results_folder.mkdir(parents = True, exist_ok = True)
def train_step(self):
device = next(self.vae.parameters()).device
steps = int(self.steps.item())
apply_grad_penalty = not (steps % self.apply_grad_penalty_every)
self.vae.train()
# logs
logs = {}
# update vae (generator)
for _ in range(self.grad_accum_every):
img = next(self.dl)
img = img.to(device)
with autocast(enabled = self.amp):
loss = self.vae(
img,
return_loss = True,
apply_grad_penalty = apply_grad_penalty
)
self.scaler.scale(loss / self.grad_accum_every).backward()
accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
self.scaler.step(self.optim)
self.scaler.update()
self.optim.zero_grad()
# update discriminator
if exists(self.vae.discr):
discr_loss = 0
for _ in range(self.grad_accum_every):
img = next(self.dl)
img = img.to(device)
with autocast(enabled = self.amp):
loss = self.vae(img, return_discr_loss = True)
self.discr_scaler.scale(loss / self.grad_accum_every).backward()
accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every})
self.discr_scaler.step(self.discr_optim)
self.discr_scaler.update()
self.discr_optim.zero_grad()
# log
print(f"{steps}: vae loss: {logs['loss']} - discr loss: {logs['discr_loss']}")
# update exponential moving averaged generator
self.ema_vae.update()
# sample results every so often
if not (steps % self.save_results_every):
for model, filename in ((self.ema_vae.ema_model, f'{steps}.ema'), (self.vae, str(steps))):
model.eval()
imgs = next(self.dl)
imgs = imgs.to(device)
recons = model(imgs)
nrows = int(sqrt(self.batch_size))
imgs_and_recons = torch.stack((imgs, recons), dim = 0)
imgs_and_recons = rearrange(imgs_and_recons, 'r b ... -> (b r) ...')
imgs_and_recons = imgs_and_recons.detach().cpu().float().clamp(0., 1.)
grid = make_grid(imgs_and_recons, nrow = 2, normalize = True, value_range = (0, 1))
logs['reconstructions'] = grid
save_image(grid, str(self.results_folder / f'{filename}.png'))
print(f'{steps}: saving to {str(self.results_folder)}')
# save model every so often
if not (steps % self.save_model_every):
state_dict = self.vae.state_dict()
model_path = str(self.results_folder / f'vae.{steps}.pt')
torch.save(state_dict, model_path)
ema_state_dict = self.ema_vae.state_dict()
model_path = str(self.results_folder / f'vae.{steps}.ema.pt')
torch.save(ema_state_dict, model_path)
print(f'{steps}: saving model to {str(self.results_folder)}')
self.steps += 1
return logs
def train(self, log_fn = noop):
device = next(self.vae.parameters()).device
while self.steps < self.num_train_steps:
logs = self.train_step()
log_fn(logs)
print('training complete')

View File

@@ -12,8 +12,8 @@ from torch.autograd import grad as torch_grad
import torchvision
from einops import rearrange, reduce, repeat
from dalle2_pytorch.attention import QueryAttnUpsample
from einops_exts import rearrange_many
from einops.layers.torch import Rearrange
# constants
@@ -146,6 +146,8 @@ class LayerNormChan(nn.Module):
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.gamma
# discriminator
class Discriminator(nn.Module):
def __init__(
self,
@@ -179,6 +181,8 @@ class Discriminator(nn.Module):
return self.to_logits(x)
# positional encoding
class ContinuousPositionBias(nn.Module):
""" from https://arxiv.org/abs/2111.09883 """
@@ -213,6 +217,88 @@ class ContinuousPositionBias(nn.Module):
bias = rearrange(rel_pos, 'i j h -> h i j')
return x + bias
# resnet encoder / decoder
class ResnetEncDec(nn.Module):
def __init__(
self,
dim,
*,
channels = 3,
layers = 4,
layer_mults = None,
num_resnet_blocks = 1,
resnet_groups = 16,
first_conv_kernel_size = 5,
use_attn = True,
attn_dim_head = 64,
attn_heads = 8,
attn_dropout = 0.,
):
super().__init__()
assert dim % resnet_groups == 0, f'dimension {dim} must be divisible by {resnet_groups} (groups for the groupnorm)'
self.layers = layers
self.encoders = MList([])
self.decoders = MList([])
layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'
layer_dims = [dim * mult for mult in layer_mults]
dims = (dim, *layer_dims)
self.encoded_dim = dims[-1]
dim_pairs = zip(dims[:-1], dims[1:])
append = lambda arr, t: arr.append(t)
prepend = lambda arr, t: arr.insert(0, t)
if not isinstance(num_resnet_blocks, tuple):
num_resnet_blocks = (*((0,) * (layers - 1)), num_resnet_blocks)
if not isinstance(use_attn, tuple):
use_attn = (*((False,) * (layers - 1)), use_attn)
assert len(num_resnet_blocks) == layers, 'number of resnet blocks config must be equal to number of layers'
assert len(use_attn) == layers
for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_resnet_blocks, use_attn):
append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
prepend(self.decoders, nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu()))
if layer_use_attn:
prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
for _ in range(layer_num_resnet_blocks):
append(self.encoders, ResBlock(dim_out, groups = resnet_groups))
prepend(self.decoders, GLUResBlock(dim_out, groups = resnet_groups))
if layer_use_attn:
append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
append(self.decoders, nn.Conv2d(dim, channels, 1))
def get_encoded_fmap_size(self, image_size):
return image_size // (2 ** self.layers)
@property
def last_dec_layer(self):
return self.decoders[-1].weight
def encode(self, x):
for enc in self.encoders:
x = enc(x)
return x
def decode(self, x):
for dec in self.decoders:
x = dec(x)
return x
class GLUResBlock(nn.Module):
def __init__(self, chan, groups = 16):
super().__init__()
@@ -246,6 +332,7 @@ class ResBlock(nn.Module):
return self.net(x) + x
# vqgan attention layer
class VQGanAttention(nn.Module):
def __init__(
self,
@@ -290,6 +377,149 @@ class VQGanAttention(nn.Module):
return out + residual
# ViT encoder / decoder
class RearrangeImage(nn.Module):
def forward(self, x):
n = x.shape[1]
w = h = int(sqrt(n))
return rearrange(x, 'b (h w) ... -> b h w ...', h = h, w = w)
class Attention(nn.Module):
def __init__(
self,
dim,
*,
heads = 8,
dim_head = 32
):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.heads = heads
self.scale = dim_head ** -0.5
inner_dim = dim_head * heads
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Linear(inner_dim, dim)
def forward(self, x):
h = self.heads
x = self.norm(x)
q, k, v = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = h)
q = q * self.scale
sim = einsum('b h i d, b h j d -> b h i j', q, k)
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
attn = sim.softmax(dim = -1)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
def FeedForward(dim, mult = 4):
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, dim * mult, bias = False),
nn.GELU(),
nn.Linear(dim * mult, dim, bias = False)
)
class Transformer(nn.Module):
def __init__(
self,
dim,
*,
layers,
dim_head = 32,
heads = 8,
ff_mult = 4
):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(layers):
self.layers.append(nn.ModuleList([
Attention(dim = dim, dim_head = dim_head, heads = heads),
FeedForward(dim = dim, mult = ff_mult)
]))
self.norm = nn.LayerNorm(dim)
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return self.norm(x)
class ViTEncDec(nn.Module):
def __init__(
self,
dim,
channels = 3,
layers = 4,
patch_size = 8,
dim_head = 32,
heads = 8,
ff_mult = 4
):
super().__init__()
self.encoded_dim = dim
self.patch_size = patch_size
input_dim = channels * (patch_size ** 2)
self.encoder = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
nn.Linear(input_dim, dim),
Transformer(
dim = dim,
dim_head = dim_head,
heads = heads,
ff_mult = ff_mult,
layers = layers
),
RearrangeImage(),
Rearrange('b h w c -> b c h w')
)
self.decoder = nn.Sequential(
Rearrange('b c h w -> b (h w) c'),
Transformer(
dim = dim,
dim_head = dim_head,
heads = heads,
ff_mult = ff_mult,
layers = layers
),
nn.Sequential(
nn.Linear(dim, dim * 4, bias = False),
nn.Tanh(),
nn.Linear(dim * 4, input_dim, bias = False),
),
RearrangeImage(),
Rearrange('b h w (p1 p2 c) -> b c (h p1) (w p2)', p1 = patch_size, p2 = patch_size)
)
def get_encoded_fmap_size(self, image_size):
return image_size // self.patch_size
@property
def last_dec_layer(self):
return self.decoder[-3][-1].weight
def encode(self, x):
return self.encoder(x)
def decode(self, x):
return self.decoder(x)
# main vqgan-vae classes
class NullVQGanVAE(nn.Module):
def __init__(
self,
@@ -320,81 +550,45 @@ class VQGanVAE(nn.Module):
image_size,
channels = 3,
layers = 4,
layer_mults = None,
l2_recon_loss = False,
use_hinge_loss = True,
num_resnet_blocks = 1,
vgg = None,
vq_codebook_dim = 256,
vq_codebook_size = 512,
vq_decay = 0.8,
vq_commitment_weight = 1.,
vq_kmeans_init = True,
vq_use_cosine_sim = True,
use_attn = True,
attn_dim_head = 64,
attn_heads = 8,
resnet_groups = 16,
attn_dropout = 0.,
first_conv_kernel_size = 5,
use_vgg_and_gan = True,
vae_type = 'resnet',
discr_layers = 4,
**kwargs
):
super().__init__()
assert dim % resnet_groups == 0, f'dimension {dim} must be divisible by {resnet_groups} (groups for the groupnorm)'
vq_kwargs, kwargs = groupby_prefix_and_trim('vq_', kwargs)
encdec_kwargs, kwargs = groupby_prefix_and_trim('encdec_', kwargs)
self.image_size = image_size
self.channels = channels
self.layers = layers
self.fmap_size = image_size // (layers ** 2)
self.codebook_size = vq_codebook_size
self.encoders = MList([])
self.decoders = MList([])
if vae_type == 'resnet':
enc_dec_klass = ResnetEncDec
elif vae_type == 'vit':
enc_dec_klass = ViTEncDec
else:
raise ValueError(f'{vae_type} not valid')
layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'
layer_dims = [dim * mult for mult in layer_mults]
dims = (dim, *layer_dims)
codebook_dim = layer_dims[-1]
self.encoded_dim = dims[-1]
dim_pairs = zip(dims[:-1], dims[1:])
append = lambda arr, t: arr.append(t)
prepend = lambda arr, t: arr.insert(0, t)
if not isinstance(num_resnet_blocks, tuple):
num_resnet_blocks = (*((0,) * (layers - 1)), num_resnet_blocks)
if not isinstance(use_attn, tuple):
use_attn = (*((False,) * (layers - 1)), use_attn)
assert len(num_resnet_blocks) == layers, 'number of resnet blocks config must be equal to number of layers'
assert len(use_attn) == layers
for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_resnet_blocks, use_attn):
append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
prepend(self.decoders, nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu()))
if layer_use_attn:
prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
for _ in range(layer_num_resnet_blocks):
append(self.encoders, ResBlock(dim_out, groups = resnet_groups))
prepend(self.decoders, GLUResBlock(dim_out, groups = resnet_groups))
if layer_use_attn:
append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
append(self.decoders, nn.Conv2d(dim, channels, 1))
self.enc_dec = enc_dec_klass(
dim = dim,
channels = channels,
layers = layers,
**encdec_kwargs
)
self.vq = VQ(
dim = codebook_dim,
dim = self.enc_dec.encoded_dim,
codebook_dim = vq_codebook_dim,
codebook_size = vq_codebook_size,
decay = vq_decay,
commitment_weight = vq_commitment_weight,
@@ -427,13 +621,21 @@ class VQGanVAE(nn.Module):
# gan related losses
layer_mults = list(map(lambda t: 2 ** t, range(discr_layers)))
layer_dims = [dim * mult for mult in layer_mults]
dims = (dim, *layer_dims)
self.discr = Discriminator(dims = dims, channels = channels)
self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss
@property
def encoded_dim(self):
return self.enc_dec.encoded_dim
def get_encoded_fmap_size(self, image_size):
return image_size // (2 ** self.layers)
return self.enc_dec.get_encoded_fmap_size(image_size)
def copy_for_eval(self):
device = next(self.parameters()).device
@@ -459,16 +661,13 @@ class VQGanVAE(nn.Module):
return self.vq.codebook
def encode(self, fmap):
for enc in self.encoders:
fmap = enc(fmap)
fmap = self.enc_dec.encode(fmap)
return fmap
def decode(self, fmap, return_indices_and_loss = False):
fmap, indices, commit_loss = self.vq(fmap)
for dec in self.decoders:
fmap = dec(fmap)
fmap = self.enc_dec.decode(fmap)
if not return_indices_and_loss:
return fmap
@@ -548,7 +747,7 @@ class VQGanVAE(nn.Module):
# calculate adaptive weight
last_dec_layer = self.decoders[-1].weight
last_dec_layer = self.enc_dec.last_dec_layer
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2)
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)

View File

@@ -10,11 +10,12 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.51',
version = '0.2.31',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
author_email = 'lucidrains@gmail.com',
long_description_content_type = 'text/markdown',
url = 'https://github.com/lucidrains/dalle2-pytorch',
keywords = [
'artificial intelligence',
@@ -23,16 +24,23 @@ setup(
],
install_requires=[
'click',
'clip-anytorch',
'coca-pytorch>=0.0.5',
'einops>=0.4',
'einops-exts>=0.0.3',
'embedding-reader',
'kornia>=0.5.4',
'pillow',
'resize-right>=0.0.2',
'rotary-embedding-torch',
'torch>=1.10',
'torchvision',
'tqdm',
'vector-quantize-pytorch',
'x-clip>=0.4.4',
'youtokentome'
'youtokentome',
'webdataset>=0.2.5',
'fsspec>=2022.1.0'
],
classifiers=[
'Development Status :: 4 - Beta',

408
train_diffusion_prior.py Normal file
View File

@@ -0,0 +1,408 @@
from pathlib import Path
import click
import math
import time
import numpy as np
import torch
from torch import nn
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
from dalle2_pytorch.train import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
from embedding_reader import EmbeddingReader
from tqdm import tqdm
# constants
NUM_TEST_EMBEDDINGS = 100 # for cosine similarity reporting during training
REPORT_METRICS_EVERY = 100 # for cosine similarity and other metric reporting during training
tracker = WandbTracker()
# helpers functions
def exists(val):
val is not None
class Timer:
def __init__(self):
self.reset()
def reset(self):
self.last_time = time.time()
def elapsed(self):
return time.time() - self.last_time
# functions
def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_type,phase="Validation"):
model.eval()
with torch.no_grad():
total_loss = 0.
total_samples = 0.
for emb_images, emb_text in zip(image_reader(batch_size=batch_size, start=start, end=end),
text_reader(batch_size=batch_size, start=start, end=end)):
emb_images_tensor = torch.tensor(emb_images[0]).to(device)
emb_text_tensor = torch.tensor(emb_text[0]).to(device)
batches = emb_images_tensor.shape[0]
loss = model(text_embed = emb_text_tensor, image_embed = emb_images_tensor)
total_loss += loss.item() * batches
total_samples += batches
avg_loss = (total_loss / total_samples)
tracker.log({f'{phase} {loss_type}': avg_loss})
def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,NUM_TEST_EMBEDDINGS,device):
diffusion_prior.eval()
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
tstart = train_set_size
tend = train_set_size+NUM_TEST_EMBEDDINGS
for embt, embi in zip(text_reader(batch_size=NUM_TEST_EMBEDDINGS, start=tstart, end=tend),
image_reader(batch_size=NUM_TEST_EMBEDDINGS, start=tstart, end=tend)):
# make a copy of the text embeddings for shuffling
text_embed = torch.tensor(embt[0]).to(device)
text_embed_shuffled = text_embed.clone()
# roll the text embeddings to simulate "unrelated" captions
rolled_idx = torch.roll(torch.arange(NUM_TEST_EMBEDDINGS), 1)
text_embed_shuffled = text_embed_shuffled[rolled_idx]
text_embed_shuffled = text_embed_shuffled / \
text_embed_shuffled.norm(dim=1, keepdim=True)
test_text_shuffled_cond = dict(text_embed=text_embed_shuffled)
# prepare the text embedding
text_embed = text_embed / text_embed.norm(dim=1, keepdim=True)
test_text_cond = dict(text_embed=text_embed)
# prepare image embeddings
test_image_embeddings = torch.tensor(embi[0]).to(device)
test_image_embeddings = test_image_embeddings / \
test_image_embeddings.norm(dim=1, keepdim=True)
# predict on the unshuffled text embeddings
predicted_image_embeddings = diffusion_prior.p_sample_loop(
(NUM_TEST_EMBEDDINGS, 768), text_cond=test_text_cond)
predicted_image_embeddings = predicted_image_embeddings / \
predicted_image_embeddings.norm(dim=1, keepdim=True)
# predict on the shuffled embeddings
predicted_unrelated_embeddings = diffusion_prior.p_sample_loop(
(NUM_TEST_EMBEDDINGS, 768), text_cond=test_text_shuffled_cond)
predicted_unrelated_embeddings = predicted_unrelated_embeddings / \
predicted_unrelated_embeddings.norm(dim=1, keepdim=True)
# calculate similarities
original_similarity = cos(
text_embed, test_image_embeddings).cpu().numpy()
predicted_similarity = cos(
text_embed, predicted_image_embeddings).cpu().numpy()
unrelated_similarity = cos(
text_embed, predicted_unrelated_embeddings).cpu().numpy()
predicted_img_similarity = cos(
test_image_embeddings, predicted_image_embeddings).cpu().numpy()
tracker.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity),
"CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity),
"CosineSimilarity(orig_image_embed,predicted_image_embed)":np.mean(predicted_img_similarity),
"CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(unrelated_similarity),
"Cosine similarity difference":np.mean(predicted_similarity - original_similarity)})
def train(image_embed_dim,
image_embed_url,
text_embed_url,
batch_size,
train_percent,
val_percent,
test_percent,
num_epochs,
dp_loss_type,
clip,
dp_condition_on_text_encodings,
dp_timesteps,
dp_normformer,
dp_cond_drop_prob,
dpn_depth,
dpn_dim_head,
dpn_heads,
save_interval,
save_path,
device,
RESUME,
DPRIOR_PATH,
config,
wandb_entity,
wandb_project,
learning_rate=0.001,
max_grad_norm=0.5,
weight_decay=0.01,
dropout=0.05,
amp=False):
# diffusion prior network
prior_network = DiffusionPriorNetwork(
dim = image_embed_dim,
depth = dpn_depth,
dim_head = dpn_dim_head,
heads = dpn_heads,
attn_dropout = dropout,
ff_dropout = dropout,
normformer = dp_normformer
)
# diffusion prior with text embeddings and image embeddings pre-computed
diffusion_prior = DiffusionPrior(
net = prior_network,
clip = clip,
image_embed_dim = image_embed_dim,
timesteps = dp_timesteps,
cond_drop_prob = dp_cond_drop_prob,
loss_type = dp_loss_type,
condition_on_text_encodings = dp_condition_on_text_encodings
)
# Load pre-trained model from DPRIOR_PATH
if RESUME:
diffusion_prior, loaded_obj = load_diffusion_model(DPRIOR_PATH, device)
tracker.init(entity = wandb_entity, project = wandb_project, config = config)
# diffusion prior trainer
trainer = DiffusionPriorTrainer(
diffusion_prior = diffusion_prior,
lr = learning_rate,
wd = weight_decay,
max_grad_norm = max_grad_norm,
amp = amp,
).to(device)
# load optimizer and scaler
if RESUME:
trainer.optimizer.load_state_dict(loaded_obj['optimizer'])
trainer.scaler.load_state_dict(loaded_obj['scaler'])
# Create save_path if it doesn't exist
Path(save_path).mkdir(exist_ok = True, parents = True)
# Get image and text embeddings from the servers
print_ribbon("Downloading embeddings - image and text")
image_reader = EmbeddingReader(embeddings_folder=image_embed_url, file_format="npy")
text_reader = EmbeddingReader(embeddings_folder=text_embed_url, file_format="npy")
num_data_points = text_reader.count
### Training code ###
timer = Timer()
epochs = num_epochs
train_set_size = int(train_percent*num_data_points)
val_set_size = int(val_percent*num_data_points)
eval_start = train_set_size
for _ in range(epochs):
for emb_images,emb_text in zip(image_reader(batch_size=batch_size, start=0, end=train_set_size),
text_reader(batch_size=batch_size, start=0, end=train_set_size)):
trainer.train()
emb_images_tensor = torch.tensor(emb_images[0]).to(device)
emb_text_tensor = torch.tensor(emb_text[0]).to(device)
loss = trainer(text_embed = emb_text_tensor, image_embed = emb_images_tensor)
# Samples per second
samples_per_sec = batch_size * step / timer.elapsed()
# Save checkpoint every save_interval minutes
if(int(timer.elapsed()) >= 60 * save_interval):
timer.reset()
save_diffusion_model(
save_path,
diffusion_prior,
trainer.optimizer,
trainer.scaler,
config,
image_embed_dim)
# Log to wandb
tracker.log({"Training loss": loss.item(),
"Steps": step,
"Samples per second": samples_per_sec})
# Log cosineSim(text_embed,predicted_image_embed) - cosineSim(text_embed,image_embed)
# Use NUM_TEST_EMBEDDINGS samples from the test set each time
# Get embeddings from the most recently saved model
if(step % REPORT_METRICS_EVERY) == 0:
report_cosine_sims(diffusion_prior,
image_reader,
text_reader,
train_set_size,
NUM_TEST_EMBEDDINGS,
device)
### Evaluate model(validation run) ###
eval_model(diffusion_prior,
device,
image_reader,
text_reader,
eval_start,
eval_start+NUM_TEST_EMBEDDINGS,
NUM_TEST_EMBEDDINGS,
dp_loss_type,
phase="Validation")
trainer.update()
### Test run ###
test_set_size = int(test_percent*train_set_size)
start = train_set_size+val_set_size
end = num_data_points
eval_model(diffusion_prior,device,image_reader,text_reader,start,end,batch_size,dp_loss_type,phase="Test")
@click.command()
@click.option("--wandb-entity", default="laion")
@click.option("--wandb-project", default="diffusion-prior")
@click.option("--wandb-dataset", default="LAION-5B")
@click.option("--wandb-arch", default="DiffusionPrior")
@click.option("--image-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
@click.option("--text-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
@click.option("--learning-rate", default=1.1e-4)
@click.option("--weight-decay", default=6.02e-2)
@click.option("--dropout", default=5e-2)
@click.option("--max-grad-norm", default=0.5)
@click.option("--batch-size", default=10**4)
@click.option("--num-epochs", default=5)
@click.option("--image-embed-dim", default=768)
@click.option("--train-percent", default=0.7)
@click.option("--val-percent", default=0.2)
@click.option("--test-percent", default=0.1)
@click.option("--dpn-depth", default=6)
@click.option("--dpn-dim-head", default=64)
@click.option("--dpn-heads", default=8)
@click.option("--dp-condition-on-text-encodings", default=False)
@click.option("--dp-timesteps", default=100)
@click.option("--dp-normformer", default=False)
@click.option("--dp-cond-drop-prob", default=0.1)
@click.option("--dp-loss-type", default="l2")
@click.option("--clip", default=None)
@click.option("--amp", default=False)
@click.option("--save-interval", default=30)
@click.option("--save-path", default="./diffusion_prior_checkpoints")
@click.option("--pretrained-model-path", default=None)
def main(
wandb_entity,
wandb_project,
wandb_dataset,
wandb_arch,
image_embed_url,
text_embed_url,
learning_rate,
weight_decay,
dropout,
max_grad_norm,
batch_size,
num_epochs,
image_embed_dim,
train_percent,
val_percent,
test_percent,
dpn_depth,
dpn_dim_head,
dpn_heads,
dp_condition_on_text_encodings,
dp_timesteps,
dp_normformer,
dp_cond_drop_prob,
dp_loss_type,
clip,
amp,
save_interval,
save_path,
pretrained_model_path
):
config = {
"learning_rate": learning_rate,
"architecture": wandb_arch,
"dataset": wandb_dataset,
"weight_decay": weight_decay,
"max_gradient_clipping_norm": max_grad_norm,
"batch_size": batch_size,
"epochs": num_epochs,
"diffusion_prior_network": {
"depth": dpn_depth,
"dim_head": dpn_dim_head,
"heads": dpn_heads,
"normformer": dp_normformer
},
"diffusion_prior": {
"condition_on_text_encodings": dp_condition_on_text_encodings,
"timesteps": dp_timesteps,
"cond_drop_prob": dp_cond_drop_prob,
"loss_type": dp_loss_type,
"clip": clip
}
}
# Check if DPRIOR_PATH exists(saved model path)
DPRIOR_PATH = args.pretrained_model_path
RESUME = exists(DPRIOR_PATH)
if not RESUME:
tracker.init(
entity = wandb_entity,
project = wandb_project,
config = config
)
# Obtain the utilized device.
has_cuda = torch.cuda.is_available()
if has_cuda:
device = torch.device("cuda:0")
torch.cuda.set_device(device)
# Training loop
train(image_embed_dim,
image_embed_url,
text_embed_url,
batch_size,
train_percent,
val_percent,
test_percent,
num_epochs,
dp_loss_type,
clip,
dp_condition_on_text_encodings,
dp_timesteps,
dp_normformer,
dp_cond_drop_prob,
dpn_depth,
dpn_dim_head,
dpn_heads,
save_interval,
save_path,
device,
RESUME,
DPRIOR_PATH,
config,
wandb_entity,
wandb_project,
learning_rate,
max_grad_norm,
weight_decay,
dropout,
amp)
if __name__ == "__main__":
main()