mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-15 00:15:07 +01:00
Compare commits
64 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6651eafa93 | ||
|
|
e6bb75e5ab | ||
|
|
b4c3e5b854 | ||
|
|
b7f9607258 | ||
|
|
2219348a6e | ||
|
|
9eea9b9862 | ||
|
|
5d958713c0 | ||
|
|
0f31980362 | ||
|
|
bee5bf3815 | ||
|
|
350a3d6045 | ||
|
|
1a81670718 | ||
|
|
934c9728dc | ||
|
|
ce4b0107c1 | ||
|
|
64c2f9c4eb | ||
|
|
22cc613278 | ||
|
|
83517849e5 | ||
|
|
708809ed6c | ||
|
|
9cc475f6e7 | ||
|
|
ffd342e9d0 | ||
|
|
f8bfd3493a | ||
|
|
9025345e29 | ||
|
|
8cc278447e | ||
|
|
38cd62010c | ||
|
|
1cc288af39 | ||
|
|
a851168633 | ||
|
|
1ffeecd0ca | ||
|
|
3df899f7a4 | ||
|
|
09534119a1 | ||
|
|
6f8b90d4d7 | ||
|
|
b588286288 | ||
|
|
b693e0be03 | ||
|
|
a0bed30a84 | ||
|
|
387c5bf774 | ||
|
|
a13d2d89c5 | ||
|
|
44d4b1bba9 | ||
|
|
f12a7589c5 | ||
|
|
b8af2210df | ||
|
|
f4fe6c570d | ||
|
|
645e207441 | ||
|
|
00743b3a0b | ||
|
|
01589aff6a | ||
|
|
7ecfd76cc0 | ||
|
|
6161b61c55 | ||
|
|
1ed0f9d80b | ||
|
|
f326a95e26 | ||
|
|
d7a0a2ce4b | ||
|
|
f23fab7ef7 | ||
|
|
857b9fbf1e | ||
|
|
8864fd0aa7 | ||
|
|
72bf159331 | ||
|
|
e5e47cfecb | ||
|
|
fa533962bd | ||
|
|
276abf337b | ||
|
|
ae42d03006 | ||
|
|
4d346e98d9 | ||
|
|
2b1fd1ad2e | ||
|
|
82a2ef37d9 | ||
|
|
5c397c9d66 | ||
|
|
0f4edff214 | ||
|
|
501a8c7c46 | ||
|
|
4e49373fc5 | ||
|
|
49de72040c | ||
|
|
271a376eaf | ||
|
|
e527002472 |
45
README.md
45
README.md
@@ -12,7 +12,7 @@ This model is SOTA for text-to-image for now.
|
|||||||
|
|
||||||
Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication with the <a href="https://laion.ai/">LAION</a> community | <a href="https://www.youtube.com/watch?v=AIOE1l1W0Tw">Yannic Interview</a>
|
Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication with the <a href="https://laion.ai/">LAION</a> community | <a href="https://www.youtube.com/watch?v=AIOE1l1W0Tw">Yannic Interview</a>
|
||||||
|
|
||||||
There was enough interest for a <a href="https://github.com/lucidrains/dalle2-jax">Jax version</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
|
As of 5/23/22, it is no longer SOTA. SOTA will be <a href="https://github.com/lucidrains/imagen-pytorch">here</a>. Jax versions as well as text-to-video project will be shifted towards the Imagen architecture, as it is way simpler.
|
||||||
|
|
||||||
## Status
|
## Status
|
||||||
|
|
||||||
@@ -24,6 +24,13 @@ There was enough interest for a <a href="https://github.com/lucidrains/dalle2-ja
|
|||||||
|
|
||||||
*ongoing at 21k steps*
|
*ongoing at 21k steps*
|
||||||
|
|
||||||
|
- <a href="https://twitter.com/Buntworthy/status/1529475416775434240?t=0GEge3Kr9I36cjcUVCQUTg">Justin Pinkney</a> successfully trained the diffusion prior in the repository for his CLIP to Stylegan2 text-to-image application
|
||||||
|
|
||||||
|
## Pre-Trained Models
|
||||||
|
- LAION is training prior models. Checkpoints are available on <a href="https://huggingface.co/zenglishuci/conditioned-prior">🤗huggingface</a> and the training statistics are available on <a href="https://wandb.ai/nousr_laion/conditioned-prior/reports/LAION-DALLE2-PyTorch-Prior--VmlldzoyMDI2OTIx">🐝WANDB</a>.
|
||||||
|
- Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/jkrtg0so?workspace=user-veldrovive">In-progress test run</a> 🚧
|
||||||
|
- DALL-E 2 🚧
|
||||||
|
|
||||||
## Install
|
## Install
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@@ -936,7 +943,7 @@ from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embed
|
|||||||
|
|
||||||
# Create a dataloader directly.
|
# Create a dataloader directly.
|
||||||
dataloader = create_image_embedding_dataloader(
|
dataloader = create_image_embedding_dataloader(
|
||||||
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses braket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
|
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses bracket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
|
||||||
embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
|
embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
|
||||||
num_workers=4,
|
num_workers=4,
|
||||||
batch_size=32,
|
batch_size=32,
|
||||||
@@ -1043,6 +1050,7 @@ This library would not have gotten to this working state without the help of
|
|||||||
- <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management
|
- <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management
|
||||||
- <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying of critical bugs
|
- <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying of critical bugs
|
||||||
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice
|
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice
|
||||||
|
- <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship
|
||||||
|
|
||||||
... and many others. Thank you! 🙏
|
... and many others. Thank you! 🙏
|
||||||
|
|
||||||
@@ -1077,21 +1085,21 @@ This library would not have gotten to this working state without the help of
|
|||||||
- [x] cross embed layers for downsampling, as an option
|
- [x] cross embed layers for downsampling, as an option
|
||||||
- [x] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
|
- [x] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
|
||||||
- [x] use pydantic for config drive training
|
- [x] use pydantic for config drive training
|
||||||
|
- [x] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
|
||||||
|
- [x] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
|
||||||
|
- [x] allow for creation of diffusion prior model off pydantic config classes - consider the same for tracker configs
|
||||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
|
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
|
||||||
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
|
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
|
||||||
- [ ] train on a toy task, offer in colab
|
- [ ] train on a toy task, offer in colab
|
||||||
- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
|
- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
|
||||||
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
|
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
|
||||||
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
|
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
|
||||||
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
|
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697
|
||||||
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
|
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
|
||||||
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
|
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
|
||||||
- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
|
- [ ] bring in skip-layer excitations (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
|
||||||
- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
|
|
||||||
- [ ] decoder needs one day worth of refactor for tech debt
|
- [ ] decoder needs one day worth of refactor for tech debt
|
||||||
- [ ] allow for unet to be able to condition non-cross attention style as well
|
- [ ] allow for unet to be able to condition non-cross attention style as well
|
||||||
- [ ] for all model classes with hyperparameters that changes the network architecture, make it requirement that they must expose a config property, and write a simple function that asserts that it restores the object correctly
|
|
||||||
- [ ] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
|
|
||||||
- [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89
|
- [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89
|
||||||
|
|
||||||
## Citations
|
## Citations
|
||||||
@@ -1135,8 +1143,9 @@ This library would not have gotten to this working state without the help of
|
|||||||
```bibtex
|
```bibtex
|
||||||
@inproceedings{Tu2022MaxViTMV,
|
@inproceedings{Tu2022MaxViTMV,
|
||||||
title = {MaxViT: Multi-Axis Vision Transformer},
|
title = {MaxViT: Multi-Axis Vision Transformer},
|
||||||
author = {Zhe-Wei Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
|
author = {Zhengzhong Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
|
||||||
year = {2022}
|
year = {2022},
|
||||||
|
url = {https://arxiv.org/abs/2204.01697}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -1190,4 +1199,22 @@ This library would not have gotten to this working state without the help of
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@misc{Saharia2022,
|
||||||
|
title = {Imagen: unprecedented photorealism × deep level of language understanding},
|
||||||
|
author = {Chitwan Saharia*, William Chan*, Saurabh Saxena†, Lala Li†, Jay Whang†, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S. Sara Mahdavi, Rapha Gontijo Lopes, Tim Salimans, Jonathan Ho†, David Fleet†, Mohammad Norouzi*},
|
||||||
|
year = {2022}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@article{Choi2022PerceptionPT,
|
||||||
|
title = {Perception Prioritized Training of Diffusion Models},
|
||||||
|
author = {Jooyoung Choi and Jungbeom Lee and Chaehun Shin and Sungwon Kim and Hyunwoo J. Kim and Sung-Hoon Yoon},
|
||||||
|
journal = {ArXiv},
|
||||||
|
year = {2022},
|
||||||
|
volume = {abs/2204.00227}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
|
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
|
||||||
|
|||||||
@@ -6,9 +6,10 @@ For more complex configuration, we provide the option of using a configuration f
|
|||||||
|
|
||||||
The decoder trainer has 7 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.example.json).
|
The decoder trainer has 7 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.example.json).
|
||||||
|
|
||||||
**<ins>Unets</ins>:**
|
**<ins>Unet</ins>:**
|
||||||
|
|
||||||
|
This is a single unet config, which belongs as an array nested under the decoder config as a list of `unets`
|
||||||
|
|
||||||
Each member of this array defines a single unet that will be added to the decoder.
|
|
||||||
| Option | Required | Default | Description |
|
| Option | Required | Default | Description |
|
||||||
| ------ | -------- | ------- | ----------- |
|
| ------ | -------- | ------- | ----------- |
|
||||||
| `dim` | Yes | N/A | The starting channels of the unet. |
|
| `dim` | Yes | N/A | The starting channels of the unet. |
|
||||||
@@ -22,6 +23,7 @@ Any parameter from the `Unet` constructor can also be given here.
|
|||||||
Defines the configuration options for the decoder model. The unets defined above will automatically be inserted.
|
Defines the configuration options for the decoder model. The unets defined above will automatically be inserted.
|
||||||
| Option | Required | Default | Description |
|
| Option | Required | Default | Description |
|
||||||
| ------ | -------- | ------- | ----------- |
|
| ------ | -------- | ------- | ----------- |
|
||||||
|
| `unets` | Yes | N/A | A list of unets, using the configuration above |
|
||||||
| `image_sizes` | Yes | N/A | The resolution of the image after each upsampling step. The length of this array should be the number of unets defined. |
|
| `image_sizes` | Yes | N/A | The resolution of the image after each upsampling step. The length of this array should be the number of unets defined. |
|
||||||
| `image_size` | Yes | N/A | Not used. Can be any number. |
|
| `image_size` | Yes | N/A | Not used. Can be any number. |
|
||||||
| `timesteps` | No | `1000` | The number of diffusion timesteps used for generation. |
|
| `timesteps` | No | `1000` | The number of diffusion timesteps used for generation. |
|
||||||
@@ -81,7 +83,7 @@ Defines which evaluation metrics will be used to test the model.
|
|||||||
Each metric can be enabled by setting its configuration. The configuration keys for each metric are defined by the torchmetrics constructors which will be linked.
|
Each metric can be enabled by setting its configuration. The configuration keys for each metric are defined by the torchmetrics constructors which will be linked.
|
||||||
| Option | Required | Default | Description |
|
| Option | Required | Default | Description |
|
||||||
| ------ | -------- | ------- | ----------- |
|
| ------ | -------- | ------- | ----------- |
|
||||||
| `n_evalation_samples` | No | `1000` | The number of samples to generate to test the model. |
|
| `n_evaluation_samples` | No | `1000` | The number of samples to generate to test the model. |
|
||||||
| `FID` | No | `None` | Setting to an object enables the [Frechet Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/frechet_inception_distance.html) metric.
|
| `FID` | No | `None` | Setting to an object enables the [Frechet Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/frechet_inception_distance.html) metric.
|
||||||
| `IS` | No | `None` | Setting to an object enables the [Inception Score](https://torchmetrics.readthedocs.io/en/stable/image/inception_score.html) metric.
|
| `IS` | No | `None` | Setting to an object enables the [Inception Score](https://torchmetrics.readthedocs.io/en/stable/image/inception_score.html) metric.
|
||||||
| `KID` | No | `None` | Setting to an object enables the [Kernel Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/kernel_inception_distance.html) metric. |
|
| `KID` | No | `None` | Setting to an object enables the [Kernel Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/kernel_inception_distance.html) metric. |
|
||||||
|
|||||||
@@ -1,16 +1,16 @@
|
|||||||
{
|
{
|
||||||
"unets": [
|
|
||||||
{
|
|
||||||
"dim": 128,
|
|
||||||
"image_embed_dim": 768,
|
|
||||||
"cond_dim": 64,
|
|
||||||
"channels": 3,
|
|
||||||
"dim_mults": [1, 2, 4, 8],
|
|
||||||
"attn_dim_head": 32,
|
|
||||||
"attn_heads": 16
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"decoder": {
|
"decoder": {
|
||||||
|
"unets": [
|
||||||
|
{
|
||||||
|
"dim": 128,
|
||||||
|
"image_embed_dim": 768,
|
||||||
|
"cond_dim": 64,
|
||||||
|
"channels": 3,
|
||||||
|
"dim_mults": [1, 2, 4, 8],
|
||||||
|
"attn_dim_head": 32,
|
||||||
|
"attn_heads": 16
|
||||||
|
}
|
||||||
|
],
|
||||||
"image_sizes": [64],
|
"image_sizes": [64],
|
||||||
"channels": 3,
|
"channels": 3,
|
||||||
"timesteps": 1000,
|
"timesteps": 1000,
|
||||||
|
|||||||
70
configs/train_prior_config.example.json
Normal file
70
configs/train_prior_config.example.json
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
{
|
||||||
|
"prior": {
|
||||||
|
"clip": {
|
||||||
|
"make": "x-clip",
|
||||||
|
"model": "ViT-L/14",
|
||||||
|
"base_model_kwargs": {
|
||||||
|
"dim_text": 768,
|
||||||
|
"dim_image": 768,
|
||||||
|
"dim_latent": 768
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"net": {
|
||||||
|
"dim": 768,
|
||||||
|
"depth": 12,
|
||||||
|
"num_timesteps": 1000,
|
||||||
|
"num_time_embeds": 1,
|
||||||
|
"num_image_embeds": 1,
|
||||||
|
"num_text_embeds": 1,
|
||||||
|
"dim_head": 64,
|
||||||
|
"heads": 12,
|
||||||
|
"ff_mult": 4,
|
||||||
|
"norm_out": true,
|
||||||
|
"attn_dropout": 0.0,
|
||||||
|
"ff_dropout": 0.0,
|
||||||
|
"final_proj": true,
|
||||||
|
"normformer": true,
|
||||||
|
"rotary_emb": true
|
||||||
|
},
|
||||||
|
"image_embed_dim": 768,
|
||||||
|
"image_size": 224,
|
||||||
|
"image_channels": 3,
|
||||||
|
"timesteps": 1000,
|
||||||
|
"cond_drop_prob": 0.1,
|
||||||
|
"loss_type": "l2",
|
||||||
|
"predict_x_start": true,
|
||||||
|
"beta_schedule": "cosine",
|
||||||
|
"condition_on_text_encodings": true
|
||||||
|
},
|
||||||
|
"data": {
|
||||||
|
"image_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/",
|
||||||
|
"text_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/",
|
||||||
|
"meta_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/laion2B-en-metadata/",
|
||||||
|
"batch_size": 256,
|
||||||
|
"splits": {
|
||||||
|
"train": 0.9,
|
||||||
|
"val": 1e-7,
|
||||||
|
"test": 0.0999999
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"train": {
|
||||||
|
"epochs": 1,
|
||||||
|
"lr": 1.1e-4,
|
||||||
|
"wd": 6.02e-2,
|
||||||
|
"max_grad_norm": 0.5,
|
||||||
|
"use_ema": true,
|
||||||
|
"amp": false,
|
||||||
|
"save_every": 10000
|
||||||
|
},
|
||||||
|
"load": {
|
||||||
|
"source": null,
|
||||||
|
"resume": false
|
||||||
|
},
|
||||||
|
"tracker": {
|
||||||
|
"tracker_type": "wandb",
|
||||||
|
"data_path": "./prior_checkpoints",
|
||||||
|
"wandb_entity": "laion",
|
||||||
|
"wandb_project": "diffusion-prior",
|
||||||
|
"verbose": true
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from dalle2_pytorch.version import __version__
|
||||||
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
|
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
|
||||||
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
|
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
|
||||||
from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
|
from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import math
|
import math
|
||||||
|
import random
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from inspect import isfunction
|
|
||||||
from functools import partial, wraps
|
from functools import partial, wraps
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
@@ -11,7 +11,7 @@ import torch.nn.functional as F
|
|||||||
from torch import nn, einsum
|
from torch import nn, einsum
|
||||||
import torchvision.transforms as T
|
import torchvision.transforms as T
|
||||||
|
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat, reduce
|
||||||
from einops.layers.torch import Rearrange
|
from einops.layers.torch import Rearrange
|
||||||
from einops_exts import rearrange_many, repeat_many, check_shape
|
from einops_exts import rearrange_many, repeat_many, check_shape
|
||||||
from einops_exts.torch import EinopsToAndFrom
|
from einops_exts.torch import EinopsToAndFrom
|
||||||
@@ -56,7 +56,7 @@ def maybe(fn):
|
|||||||
def default(val, d):
|
def default(val, d):
|
||||||
if exists(val):
|
if exists(val):
|
||||||
return val
|
return val
|
||||||
return d() if isfunction(d) else d
|
return d() if callable(d) else d
|
||||||
|
|
||||||
def cast_tuple(val, length = 1):
|
def cast_tuple(val, length = 1):
|
||||||
if isinstance(val, list):
|
if isinstance(val, list):
|
||||||
@@ -313,11 +313,6 @@ def extract(a, t, x_shape):
|
|||||||
out = a.gather(-1, t)
|
out = a.gather(-1, t)
|
||||||
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
||||||
|
|
||||||
def noise_like(shape, device, repeat=False):
|
|
||||||
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
|
|
||||||
noise = lambda: torch.randn(shape, device=device)
|
|
||||||
return repeat_noise() if repeat else noise()
|
|
||||||
|
|
||||||
def meanflat(x):
|
def meanflat(x):
|
||||||
return x.mean(dim = tuple(range(1, len(x.shape))))
|
return x.mean(dim = tuple(range(1, len(x.shape))))
|
||||||
|
|
||||||
@@ -372,7 +367,7 @@ def quadratic_beta_schedule(timesteps):
|
|||||||
scale = 1000 / timesteps
|
scale = 1000 / timesteps
|
||||||
beta_start = scale * 0.0001
|
beta_start = scale * 0.0001
|
||||||
beta_end = scale * 0.02
|
beta_end = scale * 0.02
|
||||||
return torch.linspace(beta_start**2, beta_end**2, timesteps, dtype = torch.float64) ** 2
|
return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype = torch.float64) ** 2
|
||||||
|
|
||||||
|
|
||||||
def sigmoid_beta_schedule(timesteps):
|
def sigmoid_beta_schedule(timesteps):
|
||||||
@@ -384,7 +379,7 @@ def sigmoid_beta_schedule(timesteps):
|
|||||||
|
|
||||||
|
|
||||||
class BaseGaussianDiffusion(nn.Module):
|
class BaseGaussianDiffusion(nn.Module):
|
||||||
def __init__(self, *, beta_schedule, timesteps, loss_type):
|
def __init__(self, *, beta_schedule, timesteps, loss_type, p2_loss_weight_gamma = 0., p2_loss_weight_k = 1):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if beta_schedule == "cosine":
|
if beta_schedule == "cosine":
|
||||||
@@ -449,6 +444,11 @@ class BaseGaussianDiffusion(nn.Module):
|
|||||||
register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
||||||
register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
|
register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
|
||||||
|
|
||||||
|
# p2 loss reweighting
|
||||||
|
|
||||||
|
self.has_p2_loss_reweighting = p2_loss_weight_gamma > 0.
|
||||||
|
register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma)
|
||||||
|
|
||||||
def q_posterior(self, x_start, x_t, t):
|
def q_posterior(self, x_start, x_t, t):
|
||||||
posterior_mean = (
|
posterior_mean = (
|
||||||
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
|
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
|
||||||
@@ -890,6 +890,8 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if exists(clip):
|
if exists(clip):
|
||||||
|
assert image_channels == clip.image_channels, f'channels of image ({image_channels}) should be equal to the channels that CLIP accepts ({clip.image_channels})'
|
||||||
|
|
||||||
if isinstance(clip, CLIP):
|
if isinstance(clip, CLIP):
|
||||||
clip = XClipAdapter(clip, **clip_adapter_overrides)
|
clip = XClipAdapter(clip, **clip_adapter_overrides)
|
||||||
elif isinstance(clip, CoCa):
|
elif isinstance(clip, CoCa):
|
||||||
@@ -943,10 +945,10 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
|||||||
return model_mean, posterior_variance, posterior_log_variance
|
return model_mean, posterior_variance, posterior_log_variance
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def p_sample(self, x, t, text_cond = None, clip_denoised = True, repeat_noise = False, cond_scale = 1.):
|
def p_sample(self, x, t, text_cond = None, clip_denoised = True, cond_scale = 1.):
|
||||||
b, *_, device = *x.shape, x.device
|
b, *_, device = *x.shape, x.device
|
||||||
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
|
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
|
||||||
noise = noise_like(x.shape, device, repeat_noise)
|
noise = torch.randn_like(x)
|
||||||
# no noise when t == 0
|
# no noise when t == 0
|
||||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||||
@@ -1082,8 +1084,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
|||||||
def Upsample(dim):
|
def Upsample(dim):
|
||||||
return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
|
return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
|
||||||
|
|
||||||
def Downsample(dim):
|
def Downsample(dim, *, dim_out = None):
|
||||||
return nn.Conv2d(dim, dim, 4, 2, 1)
|
dim_out = default(dim_out, dim)
|
||||||
|
return nn.Conv2d(dim, dim_out, 4, 2, 1)
|
||||||
|
|
||||||
class SinusoidalPosEmb(nn.Module):
|
class SinusoidalPosEmb(nn.Module):
|
||||||
def __init__(self, dim):
|
def __init__(self, dim):
|
||||||
@@ -1105,13 +1108,20 @@ class Block(nn.Module):
|
|||||||
groups = 8
|
groups = 8
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.block = nn.Sequential(
|
self.project = nn.Conv2d(dim, dim_out, 3, padding = 1)
|
||||||
nn.Conv2d(dim, dim_out, 3, padding = 1),
|
self.norm = nn.GroupNorm(groups, dim_out)
|
||||||
nn.GroupNorm(groups, dim_out),
|
self.act = nn.SiLU()
|
||||||
nn.SiLU()
|
|
||||||
)
|
def forward(self, x, scale_shift = None):
|
||||||
def forward(self, x):
|
x = self.project(x)
|
||||||
return self.block(x)
|
x = self.norm(x)
|
||||||
|
|
||||||
|
if exists(scale_shift):
|
||||||
|
scale, shift = scale_shift
|
||||||
|
x = x * (scale + 1) + shift
|
||||||
|
|
||||||
|
x = self.act(x)
|
||||||
|
return x
|
||||||
|
|
||||||
class ResnetBlock(nn.Module):
|
class ResnetBlock(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -1130,7 +1140,7 @@ class ResnetBlock(nn.Module):
|
|||||||
if exists(time_cond_dim):
|
if exists(time_cond_dim):
|
||||||
self.time_mlp = nn.Sequential(
|
self.time_mlp = nn.Sequential(
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
nn.Linear(time_cond_dim, dim_out)
|
nn.Linear(time_cond_dim, dim_out * 2)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.cross_attn = None
|
self.cross_attn = None
|
||||||
@@ -1150,11 +1160,14 @@ class ResnetBlock(nn.Module):
|
|||||||
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
|
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
|
||||||
|
|
||||||
def forward(self, x, cond = None, time_emb = None):
|
def forward(self, x, cond = None, time_emb = None):
|
||||||
h = self.block1(x)
|
|
||||||
|
|
||||||
|
scale_shift = None
|
||||||
if exists(self.time_mlp) and exists(time_emb):
|
if exists(self.time_mlp) and exists(time_emb):
|
||||||
time_emb = self.time_mlp(time_emb)
|
time_emb = self.time_mlp(time_emb)
|
||||||
h = rearrange(time_emb, 'b c -> b c 1 1') + h
|
time_emb = rearrange(time_emb, 'b c -> b c 1 1')
|
||||||
|
scale_shift = time_emb.chunk(2, dim = 1)
|
||||||
|
|
||||||
|
h = self.block1(x, scale_shift = scale_shift)
|
||||||
|
|
||||||
if exists(self.cross_attn):
|
if exists(self.cross_attn):
|
||||||
assert exists(cond)
|
assert exists(cond)
|
||||||
@@ -1331,12 +1344,15 @@ class Unet(nn.Module):
|
|||||||
cond_on_text_encodings = False,
|
cond_on_text_encodings = False,
|
||||||
max_text_len = 256,
|
max_text_len = 256,
|
||||||
cond_on_image_embeds = False,
|
cond_on_image_embeds = False,
|
||||||
|
add_image_embeds_to_time = True, # alerted by @mhh0318 to a phrase in the paper - "Specifically, we modify the architecture described in Nichol et al. (2021) by projecting and adding CLIP embeddings to the existing timestep embedding"
|
||||||
init_dim = None,
|
init_dim = None,
|
||||||
init_conv_kernel_size = 7,
|
init_conv_kernel_size = 7,
|
||||||
resnet_groups = 8,
|
resnet_groups = 8,
|
||||||
|
num_resnet_blocks = 2,
|
||||||
init_cross_embed_kernel_sizes = (3, 7, 15),
|
init_cross_embed_kernel_sizes = (3, 7, 15),
|
||||||
cross_embed_downsample = False,
|
cross_embed_downsample = False,
|
||||||
cross_embed_downsample_kernel_sizes = (2, 4),
|
cross_embed_downsample_kernel_sizes = (2, 4),
|
||||||
|
memory_efficient = False,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -1356,7 +1372,7 @@ class Unet(nn.Module):
|
|||||||
self.channels_out = default(channels_out, channels)
|
self.channels_out = default(channels_out, channels)
|
||||||
|
|
||||||
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
|
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
|
||||||
init_dim = default(init_dim, dim // 3 * 2)
|
init_dim = default(init_dim, dim)
|
||||||
|
|
||||||
self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1)
|
self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1)
|
||||||
|
|
||||||
@@ -1383,11 +1399,16 @@ class Unet(nn.Module):
|
|||||||
nn.Linear(time_cond_dim, time_cond_dim)
|
nn.Linear(time_cond_dim, time_cond_dim)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.image_to_cond = nn.Sequential(
|
self.image_to_tokens = nn.Sequential(
|
||||||
nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
|
nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
|
||||||
Rearrange('b (n d) -> b n d', n = num_image_tokens)
|
Rearrange('b (n d) -> b n d', n = num_image_tokens)
|
||||||
) if cond_on_image_embeds and image_embed_dim != cond_dim else nn.Identity()
|
) if cond_on_image_embeds and image_embed_dim != cond_dim else nn.Identity()
|
||||||
|
|
||||||
|
self.to_image_hiddens = nn.Sequential(
|
||||||
|
nn.Linear(image_embed_dim, time_cond_dim),
|
||||||
|
nn.GELU()
|
||||||
|
) if cond_on_image_embeds and add_image_embeds_to_time else None
|
||||||
|
|
||||||
self.norm_cond = nn.LayerNorm(cond_dim)
|
self.norm_cond = nn.LayerNorm(cond_dim)
|
||||||
self.norm_mid_cond = nn.LayerNorm(cond_dim)
|
self.norm_mid_cond = nn.LayerNorm(cond_dim)
|
||||||
|
|
||||||
@@ -1408,6 +1429,7 @@ class Unet(nn.Module):
|
|||||||
# for classifier free guidance
|
# for classifier free guidance
|
||||||
|
|
||||||
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
|
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
|
||||||
|
self.null_image_hiddens = nn.Parameter(torch.randn(1, time_cond_dim))
|
||||||
|
|
||||||
self.max_text_len = max_text_len
|
self.max_text_len = max_text_len
|
||||||
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
|
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
|
||||||
@@ -1419,6 +1441,7 @@ class Unet(nn.Module):
|
|||||||
# resnet block klass
|
# resnet block klass
|
||||||
|
|
||||||
resnet_groups = cast_tuple(resnet_groups, len(in_out))
|
resnet_groups = cast_tuple(resnet_groups, len(in_out))
|
||||||
|
num_resnet_blocks = cast_tuple(num_resnet_blocks, len(in_out))
|
||||||
|
|
||||||
assert len(resnet_groups) == len(in_out)
|
assert len(resnet_groups) == len(in_out)
|
||||||
|
|
||||||
@@ -1434,16 +1457,17 @@ class Unet(nn.Module):
|
|||||||
self.ups = nn.ModuleList([])
|
self.ups = nn.ModuleList([])
|
||||||
num_resolutions = len(in_out)
|
num_resolutions = len(in_out)
|
||||||
|
|
||||||
for ind, ((dim_in, dim_out), groups) in enumerate(zip(in_out, resnet_groups)):
|
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks)):
|
||||||
is_first = ind == 0
|
is_first = ind == 0
|
||||||
is_last = ind >= (num_resolutions - 1)
|
is_last = ind >= (num_resolutions - 1)
|
||||||
layer_cond_dim = cond_dim if not is_first else None
|
layer_cond_dim = cond_dim if not is_first else None
|
||||||
|
|
||||||
self.downs.append(nn.ModuleList([
|
self.downs.append(nn.ModuleList([
|
||||||
ResnetBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
|
downsample_klass(dim_in, dim_out = dim_out) if memory_efficient else None,
|
||||||
|
ResnetBlock(dim_out if memory_efficient else dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
|
||||||
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||||
ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
|
nn.ModuleList([ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
|
||||||
downsample_klass(dim_out) if not is_last else nn.Identity()
|
downsample_klass(dim_out) if not is_last and not memory_efficient else None
|
||||||
]))
|
]))
|
||||||
|
|
||||||
mid_dim = dims[-1]
|
mid_dim = dims[-1]
|
||||||
@@ -1452,19 +1476,19 @@ class Unet(nn.Module):
|
|||||||
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
|
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
|
||||||
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
|
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
|
||||||
|
|
||||||
for ind, ((dim_in, dim_out), groups) in enumerate(zip(reversed(in_out[1:]), reversed(resnet_groups))):
|
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(reversed(in_out), reversed(resnet_groups), reversed(num_resnet_blocks))):
|
||||||
is_last = ind >= (num_resolutions - 2)
|
is_last = ind >= (len(in_out) - 1)
|
||||||
layer_cond_dim = cond_dim if not is_last else None
|
layer_cond_dim = cond_dim if not is_last else None
|
||||||
|
|
||||||
self.ups.append(nn.ModuleList([
|
self.ups.append(nn.ModuleList([
|
||||||
ResnetBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
|
ResnetBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
|
||||||
Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||||
ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
|
nn.ModuleList([ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
|
||||||
Upsample(dim_in)
|
Upsample(dim_in) if not is_last or memory_efficient else nn.Identity()
|
||||||
]))
|
]))
|
||||||
|
|
||||||
self.final_conv = nn.Sequential(
|
self.final_conv = nn.Sequential(
|
||||||
ResnetBlock(dim, dim, groups = resnet_groups[0]),
|
ResnetBlock(dim * 2, dim, groups = resnet_groups[0]),
|
||||||
nn.Conv2d(dim, self.channels_out, 1)
|
nn.Conv2d(dim, self.channels_out, 1)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1536,6 +1560,7 @@ class Unet(nn.Module):
|
|||||||
# initial convolution
|
# initial convolution
|
||||||
|
|
||||||
x = self.init_conv(x)
|
x = self.init_conv(x)
|
||||||
|
r = x.clone() # final residual
|
||||||
|
|
||||||
# time conditioning
|
# time conditioning
|
||||||
|
|
||||||
@@ -1549,7 +1574,23 @@ class Unet(nn.Module):
|
|||||||
image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
|
image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
|
||||||
text_keep_mask = prob_mask_like((batch_size,), 1 - text_cond_drop_prob, device = device)
|
text_keep_mask = prob_mask_like((batch_size,), 1 - text_cond_drop_prob, device = device)
|
||||||
|
|
||||||
image_keep_mask, text_keep_mask = rearrange_many((image_keep_mask, text_keep_mask), 'b -> b 1 1')
|
text_keep_mask = rearrange(text_keep_mask, 'b -> b 1 1')
|
||||||
|
|
||||||
|
# image embedding to be summed to time embedding
|
||||||
|
# discovered by @mhh0318 in the paper
|
||||||
|
|
||||||
|
if exists(image_embed) and exists(self.to_image_hiddens):
|
||||||
|
image_hiddens = self.to_image_hiddens(image_embed)
|
||||||
|
image_keep_mask_hidden = rearrange(image_keep_mask, 'b -> b 1')
|
||||||
|
null_image_hiddens = self.null_image_hiddens.to(image_hiddens.dtype)
|
||||||
|
|
||||||
|
image_hiddens = torch.where(
|
||||||
|
image_keep_mask_hidden,
|
||||||
|
image_hiddens,
|
||||||
|
null_image_hiddens
|
||||||
|
)
|
||||||
|
|
||||||
|
t = t + image_hiddens
|
||||||
|
|
||||||
# mask out image embedding depending on condition dropout
|
# mask out image embedding depending on condition dropout
|
||||||
# for classifier free guidance
|
# for classifier free guidance
|
||||||
@@ -1557,11 +1598,12 @@ class Unet(nn.Module):
|
|||||||
image_tokens = None
|
image_tokens = None
|
||||||
|
|
||||||
if self.cond_on_image_embeds:
|
if self.cond_on_image_embeds:
|
||||||
image_tokens = self.image_to_cond(image_embed)
|
image_keep_mask_embed = rearrange(image_keep_mask, 'b -> b 1 1')
|
||||||
|
image_tokens = self.image_to_tokens(image_embed)
|
||||||
null_image_embed = self.null_image_embed.to(image_tokens.dtype) # for some reason pytorch AMP not working
|
null_image_embed = self.null_image_embed.to(image_tokens.dtype) # for some reason pytorch AMP not working
|
||||||
|
|
||||||
image_tokens = torch.where(
|
image_tokens = torch.where(
|
||||||
image_keep_mask,
|
image_keep_mask_embed,
|
||||||
image_tokens,
|
image_tokens,
|
||||||
null_image_embed
|
null_image_embed
|
||||||
)
|
)
|
||||||
@@ -1616,12 +1658,20 @@ class Unet(nn.Module):
|
|||||||
|
|
||||||
hiddens = []
|
hiddens = []
|
||||||
|
|
||||||
for block1, sparse_attn, block2, downsample in self.downs:
|
for pre_downsample, init_block, sparse_attn, resnet_blocks, post_downsample in self.downs:
|
||||||
x = block1(x, c, t)
|
if exists(pre_downsample):
|
||||||
|
x = pre_downsample(x)
|
||||||
|
|
||||||
|
x = init_block(x, c, t)
|
||||||
x = sparse_attn(x)
|
x = sparse_attn(x)
|
||||||
x = block2(x, c, t)
|
|
||||||
|
for resnet_block in resnet_blocks:
|
||||||
|
x = resnet_block(x, c, t)
|
||||||
|
|
||||||
hiddens.append(x)
|
hiddens.append(x)
|
||||||
x = downsample(x)
|
|
||||||
|
if exists(post_downsample):
|
||||||
|
x = post_downsample(x)
|
||||||
|
|
||||||
x = self.mid_block1(x, mid_c, t)
|
x = self.mid_block1(x, mid_c, t)
|
||||||
|
|
||||||
@@ -1630,20 +1680,24 @@ class Unet(nn.Module):
|
|||||||
|
|
||||||
x = self.mid_block2(x, mid_c, t)
|
x = self.mid_block2(x, mid_c, t)
|
||||||
|
|
||||||
for block1, sparse_attn, block2, upsample in self.ups:
|
for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
|
||||||
x = torch.cat((x, hiddens.pop()), dim=1)
|
x = torch.cat((x, hiddens.pop()), dim = 1)
|
||||||
x = block1(x, c, t)
|
x = init_block(x, c, t)
|
||||||
x = sparse_attn(x)
|
x = sparse_attn(x)
|
||||||
x = block2(x, c, t)
|
|
||||||
|
for resnet_block in resnet_blocks:
|
||||||
|
x = resnet_block(x, c, t)
|
||||||
|
|
||||||
x = upsample(x)
|
x = upsample(x)
|
||||||
|
|
||||||
|
x = torch.cat((x, r), dim = 1)
|
||||||
return self.final_conv(x)
|
return self.final_conv(x)
|
||||||
|
|
||||||
class LowresConditioner(nn.Module):
|
class LowresConditioner(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
downsample_first = True,
|
downsample_first = True,
|
||||||
blur_sigma = 0.1,
|
blur_sigma = (0.1, 0.2),
|
||||||
blur_kernel_size = 3,
|
blur_kernel_size = 3,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -1667,6 +1721,18 @@ class LowresConditioner(nn.Module):
|
|||||||
# when training, blur the low resolution conditional image
|
# when training, blur the low resolution conditional image
|
||||||
blur_sigma = default(blur_sigma, self.blur_sigma)
|
blur_sigma = default(blur_sigma, self.blur_sigma)
|
||||||
blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)
|
blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)
|
||||||
|
|
||||||
|
# allow for drawing a random sigma between lo and hi float values
|
||||||
|
if isinstance(blur_sigma, tuple):
|
||||||
|
blur_sigma = tuple(map(float, blur_sigma))
|
||||||
|
blur_sigma = random.uniform(*blur_sigma)
|
||||||
|
|
||||||
|
# allow for drawing a random kernel size between lo and hi int values
|
||||||
|
if isinstance(blur_kernel_size, tuple):
|
||||||
|
blur_kernel_size = tuple(map(int, blur_kernel_size))
|
||||||
|
kernel_size_lo, kernel_size_hi = blur_kernel_size
|
||||||
|
blur_kernel_size = random.randrange(kernel_size_lo, kernel_size_hi + 1)
|
||||||
|
|
||||||
cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
|
cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
|
||||||
|
|
||||||
cond_fmap = resize_image_to(cond_fmap, target_image_size)
|
cond_fmap = resize_image_to(cond_fmap, target_image_size)
|
||||||
@@ -1692,30 +1758,44 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
image_sizes = None, # for cascading ddpm, image size at each stage
|
image_sizes = None, # for cascading ddpm, image size at each stage
|
||||||
random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
|
random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
|
||||||
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
|
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
|
||||||
blur_sigma = 0.1, # cascading ddpm - blur sigma
|
blur_sigma = (0.1, 0.2), # cascading ddpm - blur sigma
|
||||||
blur_kernel_size = 3, # cascading ddpm - blur kernel size
|
blur_kernel_size = 3, # cascading ddpm - blur kernel size
|
||||||
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
|
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
|
||||||
clip_denoised = True,
|
clip_denoised = True,
|
||||||
clip_x_start = True,
|
clip_x_start = True,
|
||||||
clip_adapter_overrides = dict(),
|
clip_adapter_overrides = dict(),
|
||||||
learned_variance = True,
|
learned_variance = True,
|
||||||
|
learned_variance_constrain_frac = False,
|
||||||
vb_loss_weight = 0.001,
|
vb_loss_weight = 0.001,
|
||||||
unconditional = False,
|
unconditional = False,
|
||||||
auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
|
auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
|
||||||
|
use_dynamic_thres = False, # from the Imagen paper
|
||||||
|
dynamic_thres_percentile = 0.9,
|
||||||
|
p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
|
||||||
|
p2_loss_weight_k = 1
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
beta_schedule = beta_schedule,
|
beta_schedule = beta_schedule,
|
||||||
timesteps = timesteps,
|
timesteps = timesteps,
|
||||||
loss_type = loss_type
|
loss_type = loss_type,
|
||||||
|
p2_loss_weight_gamma = p2_loss_weight_gamma,
|
||||||
|
p2_loss_weight_k = p2_loss_weight_k
|
||||||
)
|
)
|
||||||
|
|
||||||
self.unconditional = unconditional
|
self.unconditional = unconditional
|
||||||
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
|
|
||||||
|
|
||||||
assert self.unconditional or (exists(clip) ^ exists(image_size)), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
|
# text conditioning
|
||||||
|
|
||||||
|
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
|
||||||
|
self.condition_on_text_encodings = condition_on_text_encodings
|
||||||
|
|
||||||
|
# clip
|
||||||
|
|
||||||
self.clip = None
|
self.clip = None
|
||||||
if exists(clip):
|
if exists(clip):
|
||||||
|
assert not unconditional, 'clip must not be given if doing unconditional image training'
|
||||||
|
assert channels == clip.image_channels, f'channels of image ({channels}) should be equal to the channels that CLIP accepts ({clip.image_channels})'
|
||||||
|
|
||||||
if isinstance(clip, CLIP):
|
if isinstance(clip, CLIP):
|
||||||
clip = XClipAdapter(clip, **clip_adapter_overrides)
|
clip = XClipAdapter(clip, **clip_adapter_overrides)
|
||||||
elif isinstance(clip, CoCa):
|
elif isinstance(clip, CoCa):
|
||||||
@@ -1725,13 +1805,20 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
assert isinstance(clip, BaseClipAdapter)
|
assert isinstance(clip, BaseClipAdapter)
|
||||||
|
|
||||||
self.clip = clip
|
self.clip = clip
|
||||||
self.clip_image_size = clip.image_size
|
|
||||||
self.channels = clip.image_channels
|
|
||||||
else:
|
|
||||||
self.clip_image_size = image_size
|
|
||||||
self.channels = channels
|
|
||||||
|
|
||||||
self.condition_on_text_encodings = condition_on_text_encodings
|
# determine image size, with image_size and image_sizes taking precedence
|
||||||
|
|
||||||
|
if exists(image_size) or exists(image_sizes):
|
||||||
|
assert exists(image_size) ^ exists(image_sizes), 'only one of image_size or image_sizes must be given'
|
||||||
|
image_size = default(image_size, lambda: image_sizes[-1])
|
||||||
|
elif exists(clip):
|
||||||
|
image_size = clip.image_size
|
||||||
|
else:
|
||||||
|
raise Error('either image_size, image_sizes, or clip must be given to decoder')
|
||||||
|
|
||||||
|
# channels
|
||||||
|
|
||||||
|
self.channels = channels
|
||||||
|
|
||||||
# automatically take care of ensuring that first unet is unconditional
|
# automatically take care of ensuring that first unet is unconditional
|
||||||
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
|
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
|
||||||
@@ -1743,6 +1830,7 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
|
|
||||||
learned_variance = pad_tuple_to_length(cast_tuple(learned_variance), len(unets), fillvalue = False)
|
learned_variance = pad_tuple_to_length(cast_tuple(learned_variance), len(unets), fillvalue = False)
|
||||||
self.learned_variance = learned_variance
|
self.learned_variance = learned_variance
|
||||||
|
self.learned_variance_constrain_frac = learned_variance_constrain_frac # whether to constrain the output of the network (the interpolation fraction) from 0 to 1
|
||||||
self.vb_loss_weight = vb_loss_weight
|
self.vb_loss_weight = vb_loss_weight
|
||||||
|
|
||||||
# construct unets and vaes
|
# construct unets and vaes
|
||||||
@@ -1773,7 +1861,7 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
|
|
||||||
# unet image sizes
|
# unet image sizes
|
||||||
|
|
||||||
image_sizes = default(image_sizes, (self.clip_image_size,))
|
image_sizes = default(image_sizes, (image_size,))
|
||||||
image_sizes = tuple(sorted(set(image_sizes)))
|
image_sizes = tuple(sorted(set(image_sizes)))
|
||||||
|
|
||||||
assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
|
assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
|
||||||
@@ -1810,7 +1898,13 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
self.clip_denoised = clip_denoised
|
self.clip_denoised = clip_denoised
|
||||||
self.clip_x_start = clip_x_start
|
self.clip_x_start = clip_x_start
|
||||||
|
|
||||||
|
# dynamic thresholding settings, if clipping denoised during sampling
|
||||||
|
|
||||||
|
self.use_dynamic_thres = use_dynamic_thres
|
||||||
|
self.dynamic_thres_percentile = dynamic_thres_percentile
|
||||||
|
|
||||||
# normalize and unnormalize image functions
|
# normalize and unnormalize image functions
|
||||||
|
|
||||||
self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
|
self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
|
||||||
self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
|
self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
|
||||||
|
|
||||||
@@ -1851,7 +1945,21 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
|
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
|
||||||
|
|
||||||
if clip_denoised:
|
if clip_denoised:
|
||||||
x_recon.clamp_(-1., 1.)
|
# s is the threshold amount
|
||||||
|
# static thresholding would just be s = 1
|
||||||
|
s = 1.
|
||||||
|
if self.use_dynamic_thres:
|
||||||
|
s = torch.quantile(
|
||||||
|
rearrange(x_recon, 'b ... -> b (...)').abs(),
|
||||||
|
self.dynamic_thres_percentile,
|
||||||
|
dim = -1
|
||||||
|
)
|
||||||
|
|
||||||
|
s.clamp_(min = 1.)
|
||||||
|
s = s.view(-1, *((1,) * (x_recon.ndim - 1)))
|
||||||
|
|
||||||
|
# clip by threshold, depending on whether static or dynamic
|
||||||
|
x_recon = x_recon.clamp(-s, s) / s
|
||||||
|
|
||||||
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
||||||
|
|
||||||
@@ -1863,16 +1971,19 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
max_log = extract(torch.log(self.betas), t, x.shape)
|
max_log = extract(torch.log(self.betas), t, x.shape)
|
||||||
var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)
|
var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)
|
||||||
|
|
||||||
|
if self.learned_variance_constrain_frac:
|
||||||
|
var_interp_frac = var_interp_frac.sigmoid()
|
||||||
|
|
||||||
posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
|
posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
|
||||||
posterior_variance = posterior_log_variance.exp()
|
posterior_variance = posterior_log_variance.exp()
|
||||||
|
|
||||||
return model_mean, posterior_variance, posterior_log_variance
|
return model_mean, posterior_variance, posterior_log_variance
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True, repeat_noise = False):
|
def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True):
|
||||||
b, *_, device = *x.shape, x.device
|
b, *_, device = *x.shape, x.device
|
||||||
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, learned_variance = learned_variance)
|
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, learned_variance = learned_variance)
|
||||||
noise = noise_like(x.shape, device, repeat_noise)
|
noise = torch.randn_like(x)
|
||||||
# no noise when t == 0
|
# no noise when t == 0
|
||||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||||
@@ -1936,7 +2047,13 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
|
|
||||||
target = noise if not predict_x_start else x_start
|
target = noise if not predict_x_start else x_start
|
||||||
|
|
||||||
loss = self.loss_fn(pred, target)
|
loss = self.loss_fn(pred, target, reduction = 'none')
|
||||||
|
loss = reduce(loss, 'b ... -> b (...)', 'mean')
|
||||||
|
|
||||||
|
if self.has_p2_loss_reweighting:
|
||||||
|
loss = loss * extract(self.p2_loss_weight, times, loss.shape)
|
||||||
|
|
||||||
|
loss = loss.mean()
|
||||||
|
|
||||||
if not learned_variance:
|
if not learned_variance:
|
||||||
# return simple loss if not using learned variance
|
# return simple loss if not using learned variance
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embed
|
|||||||
|
|
||||||
# Create a dataloader directly.
|
# Create a dataloader directly.
|
||||||
dataloader = create_image_embedding_dataloader(
|
dataloader = create_image_embedding_dataloader(
|
||||||
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses braket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
|
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses bracket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
|
||||||
embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
|
embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
|
||||||
num_workers=4,
|
num_workers=4,
|
||||||
batch_size=32,
|
batch_size=32,
|
||||||
@@ -39,3 +39,37 @@ dataset = ImageEmbeddingDataset(
|
|||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Diffusion Prior: Prior Embedding Dataset
|
||||||
|
When training the prior it is much more efficient to work with pre-computed embeddings. The `PriorEmbeddingDataset` class enables you to leverage the same script (with minimal modification) for both embedding-only and text-conditioned prior training. This saves you from having to worry about a lot of the boilerplate code.
|
||||||
|
|
||||||
|
To utilize the `PriorEmbeddingDataset`, all you need to do is make a single call to `get_reader()` which will create `EmbeddingReader` object(s) for you. Afterwards, you can utilize `make_splits()` to cleanly create DataLoader objects from for your training run.
|
||||||
|
|
||||||
|
If you are training in a distributed manner, `make_splits()` accepts `rank` and `world_size` arguments to properly distribute to each process. The defaults for these values are `rank=0` and `world_size=1`, so single-process training can safely ignore these parameters.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
```python
|
||||||
|
from dalle2_pytorch.dataloaders import get_reader, make_splits
|
||||||
|
|
||||||
|
# grab embeddings from some specified location
|
||||||
|
IMG_URL = "data/img_emb/"
|
||||||
|
META_URL = "data/meta/"
|
||||||
|
|
||||||
|
reader = get_reader(text_conditioned=True, img_url=IMG_URL, meta_url=META_URL)
|
||||||
|
|
||||||
|
# some config for training
|
||||||
|
TRAIN_ARGS = {
|
||||||
|
"world_size": 3,
|
||||||
|
"text_conditioned": True,
|
||||||
|
"start": 0,
|
||||||
|
"num_data_points": 10000,
|
||||||
|
"batch_size": 2,
|
||||||
|
"train_split": 0.5,
|
||||||
|
"eval_split": 0.25,
|
||||||
|
"image_reader": reader,
|
||||||
|
}
|
||||||
|
|
||||||
|
# specifying a rank will handle allocation internally
|
||||||
|
rank0_train, rank0_eval, rank0_test = make_splits(rank=0, **TRAIN_ARGS)
|
||||||
|
rank1_train, rank1_eval, rank1_test = make_splits(rank=1, **TRAIN_ARGS)
|
||||||
|
rank2_train, rank2_eval, rank2_test = make_splits(rank=2, **TRAIN_ARGS)
|
||||||
|
```
|
||||||
|
|||||||
@@ -1,2 +1,2 @@
|
|||||||
from dalle2_pytorch.dataloaders.decoder_loader import ImageEmbeddingDataset, create_image_embedding_dataloader
|
from dalle2_pytorch.dataloaders.decoder_loader import ImageEmbeddingDataset, create_image_embedding_dataloader
|
||||||
from dalle2_pytorch.dataloaders.embedding_wrapper import make_splits
|
from dalle2_pytorch.dataloaders.prior_loader import make_splits, get_reader, PriorEmbeddingDataset
|
||||||
|
|||||||
@@ -1,180 +0,0 @@
|
|||||||
from torch.utils.data import IterableDataset
|
|
||||||
from torch import from_numpy
|
|
||||||
from clip import tokenize
|
|
||||||
from embedding_reader import EmbeddingReader
|
|
||||||
|
|
||||||
|
|
||||||
class PriorEmbeddingLoader(IterableDataset):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
text_conditioned: bool,
|
|
||||||
batch_size: int,
|
|
||||||
start: int,
|
|
||||||
stop: int,
|
|
||||||
image_reader,
|
|
||||||
text_reader: EmbeddingReader = None,
|
|
||||||
device: str = "cpu",
|
|
||||||
) -> None:
|
|
||||||
super(PriorEmbeddingLoader).__init__()
|
|
||||||
|
|
||||||
self.text_conditioned = text_conditioned
|
|
||||||
|
|
||||||
if not self.text_conditioned:
|
|
||||||
self.text_reader = text_reader
|
|
||||||
|
|
||||||
self.image_reader = image_reader
|
|
||||||
self.batch_size = batch_size
|
|
||||||
self.start = start
|
|
||||||
self.stop = stop
|
|
||||||
self.device = device
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
self.n = 0
|
|
||||||
loader_args = dict(
|
|
||||||
batch_size=self.batch_size,
|
|
||||||
start=self.start,
|
|
||||||
end=self.stop,
|
|
||||||
show_progress=False,
|
|
||||||
)
|
|
||||||
if self.text_conditioned:
|
|
||||||
self.loader = self.image_reader(**loader_args)
|
|
||||||
else:
|
|
||||||
self.loader = zip(
|
|
||||||
self.image_reader(**loader_args), self.text_reader(**loader_args)
|
|
||||||
)
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __next__(self):
|
|
||||||
try:
|
|
||||||
return self.get_sample()
|
|
||||||
except StopIteration:
|
|
||||||
raise StopIteration
|
|
||||||
|
|
||||||
def get_sample(self):
|
|
||||||
"""
|
|
||||||
pre-proocess data from either reader into a common format
|
|
||||||
"""
|
|
||||||
self.n += 1
|
|
||||||
|
|
||||||
if self.text_conditioned:
|
|
||||||
image_embedding, caption = next(self.loader)
|
|
||||||
|
|
||||||
image_embedding = from_numpy(image_embedding).to(self.device)
|
|
||||||
tokenized_caption = tokenize(
|
|
||||||
caption["caption"].to_list(), truncate=True
|
|
||||||
).to(self.device)
|
|
||||||
|
|
||||||
return image_embedding, tokenized_caption
|
|
||||||
|
|
||||||
else:
|
|
||||||
(image_embedding, _), (text_embedding, _) = next(self.loader)
|
|
||||||
|
|
||||||
image_embedding = from_numpy(image_embedding).to(self.device)
|
|
||||||
text_embedding = from_numpy(text_embedding).to(self.device)
|
|
||||||
|
|
||||||
return image_embedding, text_embedding
|
|
||||||
|
|
||||||
|
|
||||||
def make_splits(
    text_conditioned: bool,
    batch_size: int,
    num_data_points: int,
    train_split: float,
    eval_split: float,
    device: str,
    img_url: str,
    meta_url: str = None,
    txt_url: str = None,
):
    """Build train/eval/test PriorEmbeddingLoader splits from embedding urls.

    Parameters:
        text_conditioned: whether captions (meta_url) or text embeddings
            (txt_url) accompany the image embeddings.
        batch_size: samples per yielded batch.
        num_data_points: total points to use (clamped to reader size).
        train_split / eval_split: fractions of num_data_points; the test
            split is whatever remains.
        device: torch device string the loaders move tensors onto.
        img_url: path/url of the image-embedding folder (required).
        meta_url: caption metadata folder (required if text_conditioned).
        txt_url: text-embedding folder (required if not text_conditioned).

    Returns:
        (train_loader, eval_loader, test_loader) tuple.

    NOTE: the split-point computation and the three loader constructions
    were previously duplicated verbatim in both branches; they are now
    done once after the readers are built.
    """
    assert img_url is not None, "Must supply some image embeddings"

    # build the reader(s); only the reader configuration differs per mode
    if text_conditioned:
        assert meta_url is not None, "Must supply metadata url if text-conditioning"
        image_reader = EmbeddingReader(
            embeddings_folder=img_url,
            file_format="parquet_npy",
            # will assume the caption column exists and is the only one requested
            meta_columns=["caption"],
            metadata_folder=meta_url,
        )
        reader_kwargs = dict(image_reader=image_reader)
    else:
        assert (
            txt_url is not None
        ), "Must supply text embedding url if not text-conditioning"
        image_reader = EmbeddingReader(img_url, file_format="npy")
        text_reader = EmbeddingReader(txt_url, file_format="npy")
        reader_kwargs = dict(image_reader=image_reader, text_reader=text_reader)

    # clamp the requested sample count to what the reader actually holds
    if num_data_points > image_reader.count:
        print("Specified point count is larger than the number of points available...defaulting to max length of reader.")
        num_data_points = image_reader.count

    # compute split points
    train_set_size = int(train_split * num_data_points)
    eval_set_size = int(eval_split * num_data_points)
    eval_stop = int(train_set_size + eval_set_size)

    # arguments common to all three loaders
    common_kwargs = dict(
        text_conditioned=text_conditioned,
        batch_size=batch_size,
        device=device,
        **reader_kwargs,
    )

    train_loader = PriorEmbeddingLoader(start=0, stop=train_set_size, **common_kwargs)
    eval_loader = PriorEmbeddingLoader(
        start=train_set_size, stop=eval_stop, **common_kwargs
    )
    test_loader = PriorEmbeddingLoader(
        start=eval_stop, stop=int(num_data_points), **common_kwargs
    )

    return train_loader, eval_loader, test_loader
273
dalle2_pytorch/dataloaders/prior_loader.py
Normal file
273
dalle2_pytorch/dataloaders/prior_loader.py
Normal file
@@ -0,0 +1,273 @@
|
|||||||
|
from math import ceil
|
||||||
|
from clip import tokenize
|
||||||
|
from embedding_reader import EmbeddingReader
|
||||||
|
from torch import from_numpy
|
||||||
|
from torch.utils.data import IterableDataset, DataLoader
|
||||||
|
|
||||||
|
|
||||||
|
class PriorEmbeddingDataset(IterableDataset):
    """
    PriorEmbeddingDataset is a wrapper of EmbeddingReader.

    It simplifies the logic required to yield samples from the different
    EmbeddingReader configurations available.
    """

    def __init__(
        self,
        text_conditioned: bool,
        batch_size: int,
        start: int,
        stop: int,
        image_reader,
        text_reader: EmbeddingReader = None,
    ) -> None:
        super(PriorEmbeddingDataset).__init__()

        self.text_conditioned = text_conditioned

        # a text reader is only consulted when not text-conditioned
        if not self.text_conditioned:
            self.text_reader = text_reader

        self.image_reader = image_reader
        self.start = start
        self.stop = stop
        self.batch_size = batch_size

    def __len__(self):
        # number of raw samples covered by this split
        return self.stop - self.start

    def __iter__(self):
        # keyword arguments shared by every reader invocation (D.R.Y.)
        reader_kwargs = dict(
            batch_size=self.batch_size,
            start=self.start,
            end=self.stop,
            show_progress=False,
        )

        if self.text_conditioned:
            # text-conditioned data only needs the image embeddings
            self.loader = self.image_reader(**reader_kwargs)
        else:
            # otherwise pull text embeddings alongside, bypassing metadata
            self.loader = zip(
                self.image_reader(**reader_kwargs), self.text_reader(**reader_kwargs)
            )

        # return the dataset in its formatted state
        return self

    def __next__(self):
        try:
            return self.get_sample()
        except StopIteration:
            raise StopIteration

    def __str__(self):
        return f"<PriorEmbeddingDataset: start: {self.start}, stop: {self.stop}, len: {self.__len__()}>"

    def get_sample(self):
        """
        Pre-process data from either reader into a common (image, text) format.
        """
        if self.text_conditioned:
            # reader yields raw image embeddings plus caption metadata
            img_emb, meta = next(self.loader)

            img_emb = from_numpy(img_emb)
            tokens = tokenize(meta["caption"].to_list(), truncate=True)

            return img_emb, tokens

        # embedding-conditioned: both streams carry (embedding, metadata) pairs
        (img_emb, _), (txt_emb, _) = next(self.loader)

        return from_numpy(img_emb), from_numpy(txt_emb)
|
||||||
|
|
||||||
|
# helper functions
|
||||||
|
|
||||||
|
|
||||||
|
def distribute_to_rank(start, stop, rank, world_size):
    """
    Distribute data to each rank given the world size.

    Return:
        - New start and stop points for this rank.
    """
    total = int(stop - start)

    # ceil-divide so every sample is assigned to exactly one rank
    chunk = int(ceil(total / float(world_size)))

    assert (
        chunk > 0
    ), f"Number of samples per rank must be larger than 0, (found: {chunk})"

    rank_start = start + rank * chunk
    # the final rank may receive a short chunk; never run past `stop`
    rank_stop = min(rank_start + chunk, stop)

    assert (
        rank_stop - rank_start > 0
    ), "Calculated start and stop points result in a length of zero for this rank."

    return rank_start, rank_stop
|
||||||
|
|
||||||
|
def get_reader(
    text_conditioned: bool, img_url: str, meta_url: str = None, txt_url: str = None
):
    """
    Create an EmbeddingReader object from the specified URLs.

    A url to image embeddings is always required.

    If text-conditioned, a meta_url for the captions must also be given and
    a single image_reader is returned.  Otherwise a txt_url for the matching
    text embeddings is required and (image_reader, text_reader) is returned.
    """
    assert img_url is not None, "Must supply a image url"

    if text_conditioned:
        assert meta_url is not None, "Must supply meta url if text-conditioned"

        # captions are read from the metadata alongside the image embeddings
        return EmbeddingReader(
            embeddings_folder=img_url,
            file_format="parquet_npy",
            # will assume the caption column exists and is the only one requested
            meta_columns=["caption"],
            metadata_folder=meta_url,
        )

    # otherwise text embeddings are required as well; return two readers
    assert (
        txt_url is not None
    ), "Must supply text embedding url if not text-conditioning"

    image_reader = EmbeddingReader(img_url, file_format="npy")
    text_reader = EmbeddingReader(txt_url, file_format="npy")

    return image_reader, text_reader
|
||||||
|
|
||||||
|
def make_splits(
    text_conditioned: bool,
    batch_size: int,
    num_data_points: int,
    train_split: float,
    eval_split: float,
    image_reader: EmbeddingReader,
    text_reader: EmbeddingReader = None,
    start=0,
    rank=0,
    world_size=1,
):
    """
    Split an embedding reader object as needed.

    NOTE: make_splits() will infer the test set size from your train and eval.

    Input:
        - text_conditioned: whether to prepare text-conditioned training data
        - batch_size: the batch size for a single gpu
        - num_data_points: the total number of data points you wish to train on
        - train_split: the percentage of data you wish to train on
        - eval_split: the percentage of data you wish to validate on
        - image_reader: the image_reader you wish to split
        - text_reader: the text_reader you want to split (if !text_conditioned)
        - start: the starting point within your dataset
        - rank: the rank of your worker
        - world_size: the total world size of your distributed training run

    Returns:
        - PyTorch Dataloaders that yield tuples of (img, txt) data.

    The previous version constructed the three datasets and loaders in both
    branches of the text_conditioned condition with identical code; only the
    reader arguments actually differ, so the construction now happens once.
    """
    assert start < image_reader.count, "start position cannot exceed reader count."

    # verify that the num_data_points does not exceed the max points
    if num_data_points > (image_reader.count - start):
        print(
            "Specified count is larger than what's available...defaulting to reader's count."
        )
        num_data_points = image_reader.count

    # compute split points
    train_set_size = int(train_split * num_data_points)
    eval_set_size = int(eval_split * num_data_points)
    eval_stop = int(train_set_size + eval_set_size)

    assert (
        train_split + eval_split
    ) < 1.0, "Specified train and eval split is too large to infer a test split."

    # distribute each split across the ranks of this run's world
    rank_train_start, rank_train_stop = distribute_to_rank(
        start, train_set_size, rank, world_size
    )
    rank_eval_start, rank_eval_stop = distribute_to_rank(
        train_set_size, eval_stop, rank, world_size
    )
    rank_test_start, rank_test_stop = distribute_to_rank(
        eval_stop, num_data_points, rank, world_size
    )

    # reader arguments shared by all three splits
    reader_args = dict(
        text_conditioned=text_conditioned,
        image_reader=image_reader,
    )
    if not text_conditioned:
        # non-conditioned data additionally needs the text-embedding reader
        reader_args["text_reader"] = text_reader

    train = PriorEmbeddingDataset(
        start=rank_train_start,
        stop=rank_train_stop,
        batch_size=batch_size,
        **reader_args,
    )
    val = PriorEmbeddingDataset(
        start=rank_eval_start,
        stop=rank_eval_stop,
        batch_size=batch_size,
        **reader_args,
    )
    test = PriorEmbeddingDataset(
        start=rank_test_start,
        stop=rank_test_stop,
        batch_size=batch_size,
        **reader_args,
    )

    # true batch size is specified in the PriorEmbeddingDataset
    train_loader = DataLoader(train, batch_size=None)
    eval_loader = DataLoader(val, batch_size=None)
    test_loader = DataLoader(test, batch_size=None)

    return train_loader, eval_loader, test_loader
@@ -1,17 +1,20 @@
|
|||||||
from torch.optim import AdamW, Adam
|
from torch.optim import AdamW, Adam
|
||||||
|
|
||||||
def separate_weight_decayable_params(params):
    """Split *params* into (decayable, non-decayable) groups in one pass.

    Tensors with fewer than two dimensions (biases, norm scales) should not
    receive weight decay; everything else should.
    """
    buckets = {True: [], False: []}
    for param in params:
        # bucket by whether the tensor is a weight matrix (ndim >= 2)
        buckets[param.ndim >= 2].append(param)
    return buckets[True], buckets[False]
|
|
||||||
def get_optimizer(
|
def get_optimizer(
|
||||||
params,
|
params,
|
||||||
lr = 1e-4,
|
lr = 1e-4,
|
||||||
wd = 1e-2,
|
wd = 1e-2,
|
||||||
betas = (0.9, 0.999),
|
betas = (0.9, 0.99),
|
||||||
eps = 1e-8,
|
eps = 1e-8,
|
||||||
filter_by_requires_grad = False,
|
filter_by_requires_grad = False,
|
||||||
|
group_wd_params = True,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
if filter_by_requires_grad:
|
if filter_by_requires_grad:
|
||||||
@@ -20,12 +23,12 @@ def get_optimizer(
|
|||||||
if wd == 0:
|
if wd == 0:
|
||||||
return Adam(params, lr = lr, betas = betas, eps = eps)
|
return Adam(params, lr = lr, betas = betas, eps = eps)
|
||||||
|
|
||||||
params = set(params)
|
if group_wd_params:
|
||||||
wd_params, no_wd_params = separate_weight_decayable_params(params)
|
wd_params, no_wd_params = separate_weight_decayable_params(params)
|
||||||
|
|
||||||
param_groups = [
|
params = [
|
||||||
{'params': list(wd_params)},
|
{'params': wd_params},
|
||||||
{'params': list(no_wd_params), 'weight_decay': 0},
|
{'params': no_wd_params, 'weight_decay': 0},
|
||||||
]
|
]
|
||||||
|
|
||||||
return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas, eps = eps)
|
return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
# to give users a quick easy start to training DALL-E without doing BPE
|
# to give users a quick easy start to training DALL-E without doing BPE
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import youtokentome as yttm
|
|
||||||
|
|
||||||
import html
|
import html
|
||||||
import os
|
import os
|
||||||
@@ -11,6 +10,8 @@ import regex as re
|
|||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
from dalle2_pytorch.utils import import_or_print_error
|
||||||
|
|
||||||
# OpenAI simple tokenizer
|
# OpenAI simple tokenizer
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
@@ -156,7 +157,9 @@ class YttmTokenizer:
|
|||||||
bpe_path = Path(bpe_path)
|
bpe_path = Path(bpe_path)
|
||||||
assert bpe_path.exists(), f'BPE json path {str(bpe_path)} does not exist'
|
assert bpe_path.exists(), f'BPE json path {str(bpe_path)} does not exist'
|
||||||
|
|
||||||
tokenizer = yttm.BPE(model = str(bpe_path))
|
self.yttm = import_or_print_error('youtokentome', 'you need to install youtokentome by `pip install youtokentome`')
|
||||||
|
|
||||||
|
tokenizer = self.yttm.BPE(model = str(bpe_path))
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.vocab_size = tokenizer.vocab_size()
|
self.vocab_size = tokenizer.vocab_size()
|
||||||
|
|
||||||
@@ -167,7 +170,7 @@ class YttmTokenizer:
|
|||||||
return self.tokenizer.decode(tokens, ignore_ids = pad_tokens.union({0}))
|
return self.tokenizer.decode(tokens, ignore_ids = pad_tokens.union({0}))
|
||||||
|
|
||||||
def encode(self, texts):
|
def encode(self, texts):
|
||||||
encoded = self.tokenizer.encode(texts, output_type = yttm.OutputType.ID)
|
encoded = self.tokenizer.encode(texts, output_type = self.yttm.OutputType.ID)
|
||||||
return list(map(torch.tensor, encoded))
|
return list(map(torch.tensor, encoded))
|
||||||
|
|
||||||
def tokenize(self, texts, context_length = 256, truncate_text = False):
|
def tokenize(self, texts, context_length = 256, truncate_text = False):
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ from itertools import zip_longest
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
from dalle2_pytorch.utils import import_or_print_error
|
||||||
|
|
||||||
# constants
|
# constants
|
||||||
|
|
||||||
DEFAULT_DATA_PATH = './.tracker-data'
|
DEFAULT_DATA_PATH = './.tracker-data'
|
||||||
@@ -15,14 +17,6 @@ DEFAULT_DATA_PATH = './.tracker-data'
|
|||||||
def exists(val):
|
def exists(val):
|
||||||
return val is not None
|
return val is not None
|
||||||
|
|
||||||
def import_or_print_error(pkg_name, err_str = None):
|
|
||||||
try:
|
|
||||||
return importlib.import_module(pkg_name)
|
|
||||||
except ModuleNotFoundError as e:
|
|
||||||
if exists(err_str):
|
|
||||||
print(err_str)
|
|
||||||
exit()
|
|
||||||
|
|
||||||
# load state dict functions
|
# load state dict functions
|
||||||
|
|
||||||
def load_wandb_state_dict(run_path, file_path, **kwargs):
|
def load_wandb_state_dict(run_path, file_path, **kwargs):
|
||||||
|
|||||||
@@ -3,15 +3,160 @@ from torchvision import transforms as T
|
|||||||
from pydantic import BaseModel, validator, root_validator
|
from pydantic import BaseModel, validator, root_validator
|
||||||
from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
|
from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
|
||||||
|
|
||||||
|
from x_clip import CLIP as XCLIP
|
||||||
|
from coca_pytorch import CoCa
|
||||||
|
|
||||||
|
from dalle2_pytorch.dalle2_pytorch import (
|
||||||
|
CoCaAdapter,
|
||||||
|
OpenAIClipAdapter,
|
||||||
|
Unet,
|
||||||
|
Decoder,
|
||||||
|
DiffusionPrior,
|
||||||
|
DiffusionPriorNetwork,
|
||||||
|
XClipAdapter,
|
||||||
|
)
|
||||||
|
|
||||||
|
# helper functions
|
||||||
|
|
||||||
def exists(val):
    """True when *val* is not None."""
    return val is not None

def default(val, d):
    """Return *val* when present, otherwise the fallback *d*."""
    return d if val is None else val

def ListOrTuple(inner_type):
    """Annotation helper: a list or a tuple of *inner_type*."""
    return Union[List[inner_type], Tuple[inner_type]]

def SingularOrIterable(inner_type):
    """Annotation helper: a bare *inner_type* or a list/tuple of it."""
    return Union[inner_type, ListOrTuple(inner_type)]
|
|
||||||
|
# general pydantic classes
|
||||||
|
|
||||||
|
class TrainSplitConfig(BaseModel):
    # fractions of the dataset assigned to each split; must sum to exactly 1.0
    train: float = 0.75
    val: float = 0.15
    test: float = 0.1

    @root_validator
    def validate_all(cls, fields):
        # reject any configuration whose splits do not partition the data
        total = sum(fields.values())
        if total != 1.0:
            raise ValueError(f"{fields.keys()} must sum to 1.0. Found: {total}")
        return fields
|
|
||||||
|
class TrackerConfig(BaseModel):
    """Experiment-tracking settings shared by the training scripts."""

    tracker_type: str = 'console'  # Decoder currently supports console and wandb
    data_path: str = './models'  # The path where files will be saved locally
    init_config: Dict[str, Any] = None
    wandb_entity: str = ''  # Only needs to be set if tracker_type is wandb
    wandb_project: str = ''
    verbose: bool = False  # Whether to print console logging for non-console trackers
|
|
||||||
|
# diffusion prior pydantic classes
|
||||||
|
|
||||||
|
class AdapterConfig(BaseModel):
    """Selects and configures the CLIP-style adapter used by the prior."""

    make: str = "openai"
    model: str = "ViT-L/14"
    base_model_kwargs: Dict[str, Any] = None

    def create(self):
        """Instantiate the adapter selected by ``make``."""
        # dispatch on the adapter family name
        if self.make == "openai":
            return OpenAIClipAdapter(self.model)
        if self.make == "x-clip":
            return XClipAdapter(XCLIP(**self.base_model_kwargs))
        if self.make == "coca":
            return CoCaAdapter(CoCa(**self.base_model_kwargs))
        raise AttributeError("No adapter with that name is available.")
|
|
||||||
|
class DiffusionPriorNetworkConfig(BaseModel):
    """Hyper-parameters for the prior's transformer network."""

    dim: int
    depth: int
    num_timesteps: int = None
    num_time_embeds: int = 1
    num_image_embeds: int = 1
    num_text_embeds: int = 1
    dim_head: int = 64
    heads: int = 8
    ff_mult: int = 4
    norm_out: bool = True
    attn_dropout: float = 0.
    ff_dropout: float = 0.
    final_proj: bool = True
    normformer: bool = False
    rotary_emb: bool = True

    def create(self):
        """Build a DiffusionPriorNetwork from these fields."""
        return DiffusionPriorNetwork(**self.dict())
|
|
||||||
|
class DiffusionPriorConfig(BaseModel):
    """Top-level configuration for building a DiffusionPrior."""

    clip: AdapterConfig = None
    net: DiffusionPriorNetworkConfig
    image_embed_dim: int
    image_size: int
    image_channels: int = 3
    timesteps: int = 1000
    cond_drop_prob: float = 0.
    loss_type: str = 'l2'
    predict_x_start: bool = True
    beta_schedule: str = 'cosine'
    condition_on_text_encodings: bool = True

    class Config:
        extra = "allow"

    def create(self):
        """Assemble the DiffusionPrior, building the clip adapter when configured."""
        kwargs = self.dict()

        # sub-configs are constructed separately, so drop their raw dicts
        has_clip = exists(kwargs.pop('clip'))
        kwargs.pop('net')

        clip = self.clip.create() if has_clip else None

        prior_network = self.net.create()
        return DiffusionPrior(net = prior_network, clip = clip, **kwargs)
|
|
||||||
|
class DiffusionPriorTrainConfig(BaseModel):
    """Optimizer/training-loop settings for the diffusion prior."""

    epochs: int = 1
    lr: float = 1.1e-4
    wd: float = 6.02e-2
    max_grad_norm: float = 0.5
    use_ema: bool = True
    ema_beta: float = 0.99
    amp: bool = False
    save_every: int = 10000  # what steps to save on

class DiffusionPriorDataConfig(BaseModel):
    """Dataset locations and loading settings for the diffusion prior."""

    image_url: str  # path to embeddings folder
    meta_url: str  # path to metadata (captions) for images
    splits: TrainSplitConfig
    batch_size: int = 64

class DiffusionPriorLoadConfig(BaseModel):
    """Checkpoint resume settings for the diffusion prior."""

    source: str = None
    resume: bool = False

class TrainDiffusionPriorConfig(BaseModel):
    """Bundle of all config sections needed to train the diffusion prior."""

    prior: DiffusionPriorConfig
    data: DiffusionPriorDataConfig
    train: DiffusionPriorTrainConfig
    load: DiffusionPriorLoadConfig
    tracker: TrackerConfig

    @classmethod
    def from_json_path(cls, json_path):
        """Load and validate a full training config from a JSON file."""
        with open(json_path) as f:
            config = json.load(f)
        return cls(**config)
|
|
||||||
|
# decoder pydantic classes
|
||||||
|
|
||||||
class UnetConfig(BaseModel):
|
class UnetConfig(BaseModel):
|
||||||
dim: int
|
dim: int
|
||||||
dim_mults: List[int]
|
dim_mults: ListOrTuple(int)
|
||||||
image_embed_dim: int = None
|
image_embed_dim: int = None
|
||||||
cond_dim: int = None
|
cond_dim: int = None
|
||||||
channels: int = 3
|
channels: int = 3
|
||||||
@@ -22,13 +167,22 @@ class UnetConfig(BaseModel):
|
|||||||
extra = "allow"
|
extra = "allow"
|
||||||
|
|
||||||
class DecoderConfig(BaseModel):
|
class DecoderConfig(BaseModel):
|
||||||
|
unets: ListOrTuple(UnetConfig)
|
||||||
image_size: int = None
|
image_size: int = None
|
||||||
image_sizes: Union[List[int], Tuple[int]] = None
|
image_sizes: ListOrTuple(int) = None
|
||||||
channels: int = 3
|
channels: int = 3
|
||||||
timesteps: int = 1000
|
timesteps: int = 1000
|
||||||
loss_type: str = 'l2'
|
loss_type: str = 'l2'
|
||||||
beta_schedule: str = 'cosine'
|
beta_schedule: str = 'cosine'
|
||||||
learned_variance: bool = True
|
learned_variance: bool = True
|
||||||
|
image_cond_drop_prob: float = 0.1
|
||||||
|
text_cond_drop_prob: float = 0.5
|
||||||
|
|
||||||
|
def create(self):
|
||||||
|
decoder_kwargs = self.dict()
|
||||||
|
unet_configs = decoder_kwargs.pop('unets')
|
||||||
|
unets = [Unet(**config) for config in unet_configs]
|
||||||
|
return Decoder(unets, **decoder_kwargs)
|
||||||
|
|
||||||
@validator('image_sizes')
|
@validator('image_sizes')
|
||||||
def check_image_sizes(cls, image_sizes, values):
|
def check_image_sizes(cls, image_sizes, values):
|
||||||
@@ -39,17 +193,6 @@ class DecoderConfig(BaseModel):
|
|||||||
class Config:
|
class Config:
|
||||||
extra = "allow"
|
extra = "allow"
|
||||||
|
|
||||||
class TrainSplitConfig(BaseModel):
|
|
||||||
train: float = 0.75
|
|
||||||
val: float = 0.15
|
|
||||||
test: float = 0.1
|
|
||||||
|
|
||||||
@root_validator
|
|
||||||
def validate_all(cls, fields):
|
|
||||||
if sum([*fields.values()]) != 1.:
|
|
||||||
raise ValueError(f'{fields.keys()} must sum to 1.0')
|
|
||||||
return fields
|
|
||||||
|
|
||||||
class DecoderDataConfig(BaseModel):
|
class DecoderDataConfig(BaseModel):
|
||||||
webdataset_base_url: str # path to a webdataset with jpg images
|
webdataset_base_url: str # path to a webdataset with jpg images
|
||||||
embeddings_url: str # path to .npy files with embeddings
|
embeddings_url: str # path to .npy files with embeddings
|
||||||
@@ -64,23 +207,39 @@ class DecoderDataConfig(BaseModel):
|
|||||||
resample_train: bool = False
|
resample_train: bool = False
|
||||||
preprocessing: Dict[str, Any] = {'ToTensor': True}
|
preprocessing: Dict[str, Any] = {'ToTensor': True}
|
||||||
|
|
||||||
|
@property
def img_preproc(self):
    """Compose the configured torchvision preprocessing pipeline."""

    def _build(transformation_name, **kwargs):
        # map a config key onto the matching torchvision transform
        if transformation_name == "RandomResizedCrop":
            return T.RandomResizedCrop(**kwargs)
        elif transformation_name == "RandomHorizontalFlip":
            return T.RandomHorizontalFlip()
        elif transformation_name == "ToTensor":
            return T.ToTensor()

    pipeline = []
    for name, kwargs_or_flag in self.preprocessing.items():
        # a bare truthy value means "use the transform with default args"
        kwargs = kwargs_or_flag if isinstance(kwargs_or_flag, dict) else {}
        pipeline.append(_build(name, **kwargs))
    return T.Compose(pipeline)
|
|
||||||
class DecoderTrainConfig(BaseModel):
    """Training-loop settings for the decoder; per-unet values may be lists."""

    epochs: int = 20
    lr: SingularOrIterable(float) = 1e-4
    wd: SingularOrIterable(float) = 0.01
    max_grad_norm: SingularOrIterable(float) = 0.5
    save_every_n_samples: int = 100000
    n_sample_images: int = 6  # The number of example images to produce when sampling the train and test dataset
    device: str = 'cuda:0'
    epoch_samples: int = None  # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
    validation_samples: int = None  # Same as above but for validation.
    use_ema: bool = True
    ema_beta: float = 0.999
    amp: bool = False
    save_all: bool = False  # Whether to preserve all checkpoints
    save_latest: bool = True  # Whether to always save the latest checkpoint
    save_best: bool = True  # Whether to save the best checkpoint
    unet_training_mask: ListOrTuple(bool) = None  # If None, use all unets
|
|
||||||
class DecoderEvaluateConfig(BaseModel):
|
class DecoderEvaluateConfig(BaseModel):
|
||||||
n_evaluation_samples: int = 1000
|
n_evaluation_samples: int = 1000
|
||||||
@@ -89,14 +248,6 @@ class DecoderEvaluateConfig(BaseModel):
|
|||||||
KID: Dict[str, Any] = None
|
KID: Dict[str, Any] = None
|
||||||
LPIPS: Dict[str, Any] = None
|
LPIPS: Dict[str, Any] = None
|
||||||
|
|
||||||
class TrackerConfig(BaseModel):
|
|
||||||
tracker_type: str = 'console' # Decoder currently supports console and wandb
|
|
||||||
data_path: str = './models' # The path where files will be saved locally
|
|
||||||
init_config: Dict[str, Any] = None
|
|
||||||
wandb_entity: str = '' # Only needs to be set if tracker_type is wandb
|
|
||||||
wandb_project: str = ''
|
|
||||||
verbose: bool = False # Whether to print console logging for non-console trackers
|
|
||||||
|
|
||||||
class DecoderLoadConfig(BaseModel):
|
class DecoderLoadConfig(BaseModel):
|
||||||
source: str = None # Supports file and wandb
|
source: str = None # Supports file and wandb
|
||||||
run_path: str = '' # Used only if source is wandb
|
run_path: str = '' # Used only if source is wandb
|
||||||
@@ -104,7 +255,6 @@ class DecoderLoadConfig(BaseModel):
|
|||||||
resume: bool = False # If using wandb, whether to resume the run
|
resume: bool = False # If using wandb, whether to resume the run
|
||||||
|
|
||||||
class TrainDecoderConfig(BaseModel):
|
class TrainDecoderConfig(BaseModel):
|
||||||
unets: List[UnetConfig]
|
|
||||||
decoder: DecoderConfig
|
decoder: DecoderConfig
|
||||||
data: DecoderDataConfig
|
data: DecoderDataConfig
|
||||||
train: DecoderTrainConfig
|
train: DecoderTrainConfig
|
||||||
@@ -117,19 +267,3 @@ class TrainDecoderConfig(BaseModel):
|
|||||||
with open(json_path) as f:
|
with open(json_path) as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
return cls(**config)
|
return cls(**config)
|
||||||
|
|
||||||
@property
|
|
||||||
def img_preproc(self):
|
|
||||||
def _get_transformation(transformation_name, **kwargs):
|
|
||||||
if transformation_name == "RandomResizedCrop":
|
|
||||||
return T.RandomResizedCrop(**kwargs)
|
|
||||||
elif transformation_name == "RandomHorizontalFlip":
|
|
||||||
return T.RandomHorizontalFlip()
|
|
||||||
elif transformation_name == "ToTensor":
|
|
||||||
return T.ToTensor()
|
|
||||||
|
|
||||||
transforms = []
|
|
||||||
for transform_name, transform_kwargs_or_bool in self.data.preprocessing.items():
|
|
||||||
transform_kwargs = {} if not isinstance(transform_kwargs_or_bool, dict) else transform_kwargs_or_bool
|
|
||||||
transforms.append(_get_transformation(transform_name, **transform_kwargs))
|
|
||||||
return T.Compose(transforms)
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import time
|
import time
|
||||||
import copy
|
import copy
|
||||||
|
from pathlib import Path
|
||||||
from math import ceil
|
from math import ceil
|
||||||
from functools import partial, wraps
|
from functools import partial, wraps
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
@@ -10,6 +11,8 @@ from torch.cuda.amp import autocast, GradScaler
|
|||||||
|
|
||||||
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
|
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
|
||||||
from dalle2_pytorch.optimizer import get_optimizer
|
from dalle2_pytorch.optimizer import get_optimizer
|
||||||
|
from dalle2_pytorch.version import __version__
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -55,6 +58,16 @@ def num_to_groups(num, divisor):
|
|||||||
arr.append(remainder)
|
arr.append(remainder)
|
||||||
return arr
|
return arr
|
||||||
|
|
||||||
|
def clamp(value, min_value = None, max_value = None):
|
||||||
|
assert exists(min_value) or exists(max_value)
|
||||||
|
if exists(min_value):
|
||||||
|
value = max(value, min_value)
|
||||||
|
|
||||||
|
if exists(max_value):
|
||||||
|
value = min(value, max_value)
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
# decorators
|
# decorators
|
||||||
|
|
||||||
def cast_torch_tensor(fn):
|
def cast_torch_tensor(fn):
|
||||||
@@ -128,12 +141,6 @@ def split_args_and_kwargs(*args, split_size = None, **kwargs):
|
|||||||
chunk_size_frac = chunk_size / batch_size
|
chunk_size_frac = chunk_size / batch_size
|
||||||
yield chunk_size_frac, (chunked_args, chunked_kwargs)
|
yield chunk_size_frac, (chunked_args, chunked_kwargs)
|
||||||
|
|
||||||
# print helpers
|
|
||||||
|
|
||||||
def print_ribbon(s, symbol = '=', repeat = 40):
|
|
||||||
flank = symbol * repeat
|
|
||||||
return f'{flank} {s} {flank}'
|
|
||||||
|
|
||||||
# saving and loading functions
|
# saving and loading functions
|
||||||
|
|
||||||
# for diffusion prior
|
# for diffusion prior
|
||||||
@@ -175,12 +182,34 @@ def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embe
|
|||||||
# exponential moving average wrapper
|
# exponential moving average wrapper
|
||||||
|
|
||||||
class EMA(nn.Module):
|
class EMA(nn.Module):
|
||||||
|
"""
|
||||||
|
Implements exponential moving average shadowing for your model.
|
||||||
|
|
||||||
|
Utilizes an inverse decay schedule to manage longer term training runs.
|
||||||
|
By adjusting the power, you can control how fast EMA will ramp up to your specified beta.
|
||||||
|
|
||||||
|
@crowsonkb's notes on EMA Warmup:
|
||||||
|
|
||||||
|
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are
|
||||||
|
good values for models you plan to train for a million or more steps (reaches decay
|
||||||
|
factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models
|
||||||
|
you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
|
||||||
|
215.4k steps).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
|
||||||
|
power (float): Exponential factor of EMA warmup. Default: 1.
|
||||||
|
min_value (float): The minimum EMA decay rate. Default: 0.
|
||||||
|
"""
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model,
|
model,
|
||||||
beta = 0.9999,
|
beta = 0.9999,
|
||||||
update_after_step = 1000,
|
update_after_step = 10000,
|
||||||
update_every = 10,
|
update_every = 10,
|
||||||
|
inv_gamma = 1.0,
|
||||||
|
power = 2/3,
|
||||||
|
min_value = 0.0,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.beta = beta
|
self.beta = beta
|
||||||
@@ -188,47 +217,65 @@ class EMA(nn.Module):
|
|||||||
self.ema_model = copy.deepcopy(model)
|
self.ema_model = copy.deepcopy(model)
|
||||||
|
|
||||||
self.update_every = update_every
|
self.update_every = update_every
|
||||||
self.update_after_step = update_after_step // update_every # only start EMA after this step number, starting at 0
|
self.update_after_step = update_after_step
|
||||||
|
|
||||||
|
self.inv_gamma = inv_gamma
|
||||||
|
self.power = power
|
||||||
|
self.min_value = min_value
|
||||||
|
|
||||||
self.register_buffer('initted', torch.Tensor([False]))
|
self.register_buffer('initted', torch.Tensor([False]))
|
||||||
self.register_buffer('step', torch.tensor([0.]))
|
self.register_buffer('step', torch.tensor([0]))
|
||||||
|
|
||||||
def restore_ema_model_device(self):
|
def restore_ema_model_device(self):
|
||||||
device = self.initted.device
|
device = self.initted.device
|
||||||
self.ema_model.to(device)
|
self.ema_model.to(device)
|
||||||
|
|
||||||
def copy_params_from_model_to_ema(self):
|
def copy_params_from_model_to_ema(self):
|
||||||
self.ema_model.state_dict(self.online_model.state_dict())
|
for ma_param, current_param in zip(list(self.ema_model.parameters()), list(self.online_model.parameters())):
|
||||||
|
ma_param.data.copy_(current_param.data)
|
||||||
|
|
||||||
|
for ma_buffer, current_buffer in zip(list(self.ema_model.buffers()), list(self.online_model.buffers())):
|
||||||
|
ma_buffer.data.copy_(current_buffer.data)
|
||||||
|
|
||||||
|
def get_current_decay(self):
|
||||||
|
epoch = clamp(self.step.item() - self.update_after_step - 1, min_value = 0)
|
||||||
|
value = 1 - (1 + epoch / self.inv_gamma) ** - self.power
|
||||||
|
|
||||||
|
if epoch <= 0:
|
||||||
|
return 0.
|
||||||
|
|
||||||
|
return clamp(value, min_value = self.min_value, max_value = self.beta)
|
||||||
|
|
||||||
def update(self):
|
def update(self):
|
||||||
|
step = self.step.item()
|
||||||
self.step += 1
|
self.step += 1
|
||||||
|
|
||||||
if (self.step % self.update_every) != 0:
|
if (step % self.update_every) != 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.step <= self.update_after_step:
|
if step <= self.update_after_step:
|
||||||
self.copy_params_from_model_to_ema()
|
self.copy_params_from_model_to_ema()
|
||||||
return
|
return
|
||||||
|
|
||||||
if not self.initted:
|
if not self.initted.item():
|
||||||
self.copy_params_from_model_to_ema()
|
self.copy_params_from_model_to_ema()
|
||||||
self.initted.data.copy_(torch.Tensor([True]))
|
self.initted.data.copy_(torch.Tensor([True]))
|
||||||
|
|
||||||
self.update_moving_average(self.ema_model, self.online_model)
|
self.update_moving_average(self.ema_model, self.online_model)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def update_moving_average(self, ma_model, current_model):
|
def update_moving_average(self, ma_model, current_model):
|
||||||
def calculate_ema(beta, old, new):
|
current_decay = self.get_current_decay()
|
||||||
if not exists(old):
|
|
||||||
return new
|
|
||||||
return old * beta + (1 - beta) * new
|
|
||||||
|
|
||||||
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
|
for current_params, ma_params in zip(list(current_model.parameters()), list(ma_model.parameters())):
|
||||||
old_weight, up_weight = ma_params.data, current_params.data
|
difference = ma_params.data - current_params.data
|
||||||
ma_params.data = calculate_ema(self.beta, old_weight, up_weight)
|
difference.mul_(1.0 - current_decay)
|
||||||
|
ma_params.sub_(difference)
|
||||||
|
|
||||||
for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
|
for current_buffer, ma_buffer in zip(list(current_model.buffers()), list(ma_model.buffers())):
|
||||||
new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer)
|
difference = ma_buffer - current_buffer
|
||||||
ma_buffer.copy_(new_buffer_value)
|
difference.mul_(1.0 - current_decay)
|
||||||
|
ma_buffer.sub_(difference)
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
return self.ema_model(*args, **kwargs)
|
return self.ema_model(*args, **kwargs)
|
||||||
@@ -255,6 +302,7 @@ class DiffusionPriorTrainer(nn.Module):
|
|||||||
eps = 1e-6,
|
eps = 1e-6,
|
||||||
max_grad_norm = None,
|
max_grad_norm = None,
|
||||||
amp = False,
|
amp = False,
|
||||||
|
group_wd_params = True,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -280,6 +328,7 @@ class DiffusionPriorTrainer(nn.Module):
|
|||||||
lr = lr,
|
lr = lr,
|
||||||
wd = wd,
|
wd = wd,
|
||||||
eps = eps,
|
eps = eps,
|
||||||
|
group_wd_params = group_wd_params,
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -287,7 +336,50 @@ class DiffusionPriorTrainer(nn.Module):
|
|||||||
|
|
||||||
self.max_grad_norm = max_grad_norm
|
self.max_grad_norm = max_grad_norm
|
||||||
|
|
||||||
self.register_buffer('step', torch.tensor([0.]))
|
self.register_buffer('step', torch.tensor([0]))
|
||||||
|
|
||||||
|
def save(self, path, overwrite = True, **kwargs):
|
||||||
|
path = Path(path)
|
||||||
|
assert not (path.exists() and not overwrite)
|
||||||
|
path.parent.mkdir(parents = True, exist_ok = True)
|
||||||
|
|
||||||
|
save_obj = dict(
|
||||||
|
scaler = self.scaler.state_dict(),
|
||||||
|
optimizer = self.optimizer.state_dict(),
|
||||||
|
model = self.diffusion_prior.state_dict(),
|
||||||
|
version = __version__,
|
||||||
|
step = self.step.item(),
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.use_ema:
|
||||||
|
save_obj = {**save_obj, 'ema': self.ema_diffusion_prior.state_dict()}
|
||||||
|
|
||||||
|
torch.save(save_obj, str(path))
|
||||||
|
|
||||||
|
def load(self, path, only_model = False, strict = True):
|
||||||
|
path = Path(path)
|
||||||
|
assert path.exists()
|
||||||
|
|
||||||
|
loaded_obj = torch.load(str(path))
|
||||||
|
|
||||||
|
if version.parse(__version__) != loaded_obj['version']:
|
||||||
|
print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {__version__}')
|
||||||
|
|
||||||
|
self.diffusion_prior.load_state_dict(loaded_obj['model'], strict = strict)
|
||||||
|
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
|
||||||
|
|
||||||
|
if only_model:
|
||||||
|
return loaded_obj
|
||||||
|
|
||||||
|
self.scaler.load_state_dict(loaded_obj['scaler'])
|
||||||
|
self.optimizer.load_state_dict(loaded_obj['optimizer'])
|
||||||
|
|
||||||
|
if self.use_ema:
|
||||||
|
assert 'ema' in loaded_obj
|
||||||
|
self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
|
||||||
|
|
||||||
|
return loaded_obj
|
||||||
|
|
||||||
def update(self):
|
def update(self):
|
||||||
if exists(self.max_grad_norm):
|
if exists(self.max_grad_norm):
|
||||||
@@ -368,6 +460,7 @@ class DecoderTrainer(nn.Module):
|
|||||||
eps = 1e-8,
|
eps = 1e-8,
|
||||||
max_grad_norm = 0.5,
|
max_grad_norm = 0.5,
|
||||||
amp = False,
|
amp = False,
|
||||||
|
group_wd_params = True,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -393,6 +486,7 @@ class DecoderTrainer(nn.Module):
|
|||||||
lr = unet_lr,
|
lr = unet_lr,
|
||||||
wd = unet_wd,
|
wd = unet_wd,
|
||||||
eps = unet_eps,
|
eps = unet_eps,
|
||||||
|
group_wd_params = group_wd_params,
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -410,6 +504,60 @@ class DecoderTrainer(nn.Module):
|
|||||||
|
|
||||||
self.register_buffer('step', torch.tensor([0.]))
|
self.register_buffer('step', torch.tensor([0.]))
|
||||||
|
|
||||||
|
def save(self, path, overwrite = True, **kwargs):
|
||||||
|
path = Path(path)
|
||||||
|
assert not (path.exists() and not overwrite)
|
||||||
|
path.parent.mkdir(parents = True, exist_ok = True)
|
||||||
|
|
||||||
|
save_obj = dict(
|
||||||
|
model = self.decoder.state_dict(),
|
||||||
|
version = __version__,
|
||||||
|
step = self.step.item(),
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
for ind in range(0, self.num_unets):
|
||||||
|
scaler_key = f'scaler{ind}'
|
||||||
|
optimizer_key = f'scaler{ind}'
|
||||||
|
scaler = getattr(self, scaler_key)
|
||||||
|
optimizer = getattr(self, optimizer_key)
|
||||||
|
save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()}
|
||||||
|
|
||||||
|
if self.use_ema:
|
||||||
|
save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
|
||||||
|
|
||||||
|
torch.save(save_obj, str(path))
|
||||||
|
|
||||||
|
def load(self, path, only_model = False, strict = True):
|
||||||
|
path = Path(path)
|
||||||
|
assert path.exists()
|
||||||
|
|
||||||
|
loaded_obj = torch.load(str(path))
|
||||||
|
|
||||||
|
if version.parse(__version__) != loaded_obj['version']:
|
||||||
|
print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}')
|
||||||
|
|
||||||
|
self.decoder.load_state_dict(loaded_obj['model'], strict = strict)
|
||||||
|
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
|
||||||
|
|
||||||
|
if only_model:
|
||||||
|
return loaded_obj
|
||||||
|
|
||||||
|
for ind in range(0, self.num_unets):
|
||||||
|
scaler_key = f'scaler{ind}'
|
||||||
|
optimizer_key = f'scaler{ind}'
|
||||||
|
scaler = getattr(self, scaler_key)
|
||||||
|
optimizer = getattr(self, optimizer_key)
|
||||||
|
|
||||||
|
scaler.load_state_dict(loaded_obj[scaler_key])
|
||||||
|
optimizer.load_state_dict(loaded_obj[optimizer_key])
|
||||||
|
|
||||||
|
if self.use_ema:
|
||||||
|
assert 'ema' in loaded_obj
|
||||||
|
self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
|
||||||
|
|
||||||
|
return loaded_obj
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def unets(self):
|
def unets(self):
|
||||||
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
|
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
|
# time helpers
|
||||||
|
|
||||||
class Timer:
|
class Timer:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.reset()
|
self.reset()
|
||||||
@@ -9,3 +11,19 @@ class Timer:
|
|||||||
|
|
||||||
def elapsed(self):
|
def elapsed(self):
|
||||||
return time.time() - self.last_time
|
return time.time() - self.last_time
|
||||||
|
|
||||||
|
# print helpers
|
||||||
|
|
||||||
|
def print_ribbon(s, symbol = '=', repeat = 40):
|
||||||
|
flank = symbol * repeat
|
||||||
|
return f'{flank} {s} {flank}'
|
||||||
|
|
||||||
|
# import helpers
|
||||||
|
|
||||||
|
def import_or_print_error(pkg_name, err_str = None):
|
||||||
|
try:
|
||||||
|
return importlib.import_module(pkg_name)
|
||||||
|
except ModuleNotFoundError as e:
|
||||||
|
if exists(err_str):
|
||||||
|
print(err_str)
|
||||||
|
exit()
|
||||||
|
|||||||
1
dalle2_pytorch/version.py
Normal file
1
dalle2_pytorch/version.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
__version__ = '0.9.2'
|
||||||
@@ -68,8 +68,8 @@ def group_dict_by_key(cond, d):
|
|||||||
return_val[ind][key] = d[key]
|
return_val[ind][key] = d[key]
|
||||||
return (*return_val,)
|
return (*return_val,)
|
||||||
|
|
||||||
def string_begins_with(prefix, str):
|
def string_begins_with(prefix, string_input):
|
||||||
return str.startswith(prefix)
|
return string_input.startswith(prefix)
|
||||||
|
|
||||||
def group_by_key_prefix(prefix, d):
|
def group_by_key_prefix(prefix, d):
|
||||||
return group_dict_by_key(partial(string_begins_with, prefix), d)
|
return group_dict_by_key(partial(string_begins_with, prefix), d)
|
||||||
|
|||||||
5
setup.py
5
setup.py
@@ -1,4 +1,5 @@
|
|||||||
from setuptools import setup, find_packages
|
from setuptools import setup, find_packages
|
||||||
|
exec(open('dalle2_pytorch/version.py').read())
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name = 'dalle2-pytorch',
|
name = 'dalle2-pytorch',
|
||||||
@@ -10,7 +11,7 @@ setup(
|
|||||||
'dream = dalle2_pytorch.cli:dream'
|
'dream = dalle2_pytorch.cli:dream'
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
version = '0.4.2',
|
version = __version__,
|
||||||
license='MIT',
|
license='MIT',
|
||||||
description = 'DALL-E 2',
|
description = 'DALL-E 2',
|
||||||
author = 'Phil Wang',
|
author = 'Phil Wang',
|
||||||
@@ -31,6 +32,7 @@ setup(
|
|||||||
'embedding-reader',
|
'embedding-reader',
|
||||||
'kornia>=0.5.4',
|
'kornia>=0.5.4',
|
||||||
'numpy',
|
'numpy',
|
||||||
|
'packaging',
|
||||||
'pillow',
|
'pillow',
|
||||||
'pydantic',
|
'pydantic',
|
||||||
'resize-right>=0.0.2',
|
'resize-right>=0.0.2',
|
||||||
@@ -40,7 +42,6 @@ setup(
|
|||||||
'tqdm',
|
'tqdm',
|
||||||
'vector-quantize-pytorch',
|
'vector-quantize-pytorch',
|
||||||
'x-clip>=0.4.4',
|
'x-clip>=0.4.4',
|
||||||
'youtokentome',
|
|
||||||
'webdataset>=0.2.5',
|
'webdataset>=0.2.5',
|
||||||
'fsspec>=2022.1.0',
|
'fsspec>=2022.1.0',
|
||||||
'torchmetrics[image]>=0.8.0'
|
'torchmetrics[image]>=0.8.0'
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
from dalle2_pytorch import Unet, Decoder
|
from dalle2_pytorch import Unet, Decoder
|
||||||
from dalle2_pytorch.trainer import DecoderTrainer, print_ribbon
|
from dalle2_pytorch.trainer import DecoderTrainer
|
||||||
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
|
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
|
||||||
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
|
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
|
||||||
from dalle2_pytorch.train_configs import TrainDecoderConfig
|
from dalle2_pytorch.train_configs import TrainDecoderConfig
|
||||||
from dalle2_pytorch.utils import Timer
|
from dalle2_pytorch.utils import Timer, print_ribbon
|
||||||
|
from dalle2_pytorch.dalle2_pytorch import resize_image_to
|
||||||
|
|
||||||
import torchvision
|
import torchvision
|
||||||
import torch
|
import torch
|
||||||
@@ -85,20 +86,6 @@ def create_dataloaders(
|
|||||||
"test_sampling": test_sampling_dataloader
|
"test_sampling": test_sampling_dataloader
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def create_decoder(device, decoder_config, unets_config):
|
|
||||||
"""Creates a sample decoder"""
|
|
||||||
|
|
||||||
unets = [Unet(**config.dict()) for config in unets_config]
|
|
||||||
|
|
||||||
decoder = Decoder(
|
|
||||||
unet=unets,
|
|
||||||
**decoder_config.dict()
|
|
||||||
)
|
|
||||||
|
|
||||||
decoder.to(device=device)
|
|
||||||
return decoder
|
|
||||||
|
|
||||||
def get_dataset_keys(dataloader):
|
def get_dataset_keys(dataloader):
|
||||||
"""
|
"""
|
||||||
It is sometimes neccesary to get the keys the dataloader is returning. Since the dataset is burried in the dataloader, we need to do a process to recover it.
|
It is sometimes neccesary to get the keys the dataloader is returning. Since the dataset is burried in the dataloader, we need to do a process to recover it.
|
||||||
@@ -150,6 +137,14 @@ def generate_grid_samples(trainer, examples, text_prepend=""):
|
|||||||
Generates samples and uses torchvision to put them in a side by side grid for easy viewing
|
Generates samples and uses torchvision to put them in a side by side grid for easy viewing
|
||||||
"""
|
"""
|
||||||
real_images, generated_images, captions = generate_samples(trainer, examples, text_prepend)
|
real_images, generated_images, captions = generate_samples(trainer, examples, text_prepend)
|
||||||
|
|
||||||
|
real_image_size = real_images[0].shape[-1]
|
||||||
|
generated_image_size = generated_images[0].shape[-1]
|
||||||
|
|
||||||
|
# training images may be larger than the generated one
|
||||||
|
if real_image_size > generated_image_size:
|
||||||
|
real_images = [resize_image_to(image, generated_image_size) for image in real_images]
|
||||||
|
|
||||||
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
|
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
|
||||||
return grid_images, captions
|
return grid_images, captions
|
||||||
|
|
||||||
@@ -216,7 +211,7 @@ def recall_trainer(tracker, trainer, recall_source=None, **load_config):
|
|||||||
Loads the model with an appropriate method depending on the tracker
|
Loads the model with an appropriate method depending on the tracker
|
||||||
"""
|
"""
|
||||||
print(print_ribbon(f"Loading model from {recall_source}"))
|
print(print_ribbon(f"Loading model from {recall_source}"))
|
||||||
state_dict = tracker.recall_state_dict(recall_source, **load_config)
|
state_dict = tracker.recall_state_dict(recall_source, **load_config.dict())
|
||||||
trainer.load_state_dict(state_dict["trainer"])
|
trainer.load_state_dict(state_dict["trainer"])
|
||||||
print("Model loaded")
|
print("Model loaded")
|
||||||
return state_dict["epoch"], state_dict["step"], state_dict["validation_losses"]
|
return state_dict["epoch"], state_dict["step"], state_dict["validation_losses"]
|
||||||
@@ -336,7 +331,7 @@ def train(
|
|||||||
sample = 0
|
sample = 0
|
||||||
average_loss = 0
|
average_loss = 0
|
||||||
timer = Timer()
|
timer = Timer()
|
||||||
for i, (img, emb, txt) in enumerate(dataloaders["val"]):
|
for i, (img, emb, *_) in enumerate(dataloaders["val"]):
|
||||||
sample += img.shape[0]
|
sample += img.shape[0]
|
||||||
img, emb = send_to_device((img, emb))
|
img, emb = send_to_device((img, emb))
|
||||||
|
|
||||||
@@ -361,7 +356,7 @@ def train(
|
|||||||
# Compute evaluation metrics
|
# Compute evaluation metrics
|
||||||
if exists(evaluate_config):
|
if exists(evaluate_config):
|
||||||
print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
|
print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
|
||||||
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config)
|
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict())
|
||||||
tracker.log(evaluation, step=step, verbose=True)
|
tracker.log(evaluation, step=step, verbose=True)
|
||||||
|
|
||||||
# Generate sample images
|
# Generate sample images
|
||||||
@@ -420,7 +415,7 @@ def initialize_training(config):
|
|||||||
|
|
||||||
dataloaders = create_dataloaders (
|
dataloaders = create_dataloaders (
|
||||||
available_shards=all_shards,
|
available_shards=all_shards,
|
||||||
img_preproc = config.img_preproc,
|
img_preproc = config.data.img_preproc,
|
||||||
train_prop = config.data.splits.train,
|
train_prop = config.data.splits.train,
|
||||||
val_prop = config.data.splits.val,
|
val_prop = config.data.splits.val,
|
||||||
test_prop = config.data.splits.test,
|
test_prop = config.data.splits.test,
|
||||||
@@ -428,7 +423,7 @@ def initialize_training(config):
|
|||||||
**config.data.dict()
|
**config.data.dict()
|
||||||
)
|
)
|
||||||
|
|
||||||
decoder = create_decoder(device, config.decoder, config.unets)
|
decoder = config.decoder.create().to(device = device)
|
||||||
num_parameters = sum(p.numel() for p in decoder.parameters())
|
num_parameters = sum(p.numel() for p in decoder.parameters())
|
||||||
print(print_ribbon("Loaded Config", repeat=40))
|
print(print_ribbon("Loaded Config", repeat=40))
|
||||||
print(f"Number of parameters: {num_parameters}")
|
print(f"Number of parameters: {num_parameters}")
|
||||||
|
|||||||
@@ -7,14 +7,12 @@ import torch
|
|||||||
import clip
|
import clip
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from dalle2_pytorch.dataloaders import make_splits
|
from dalle2_pytorch.dataloaders import make_splits, get_reader
|
||||||
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
|
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
|
||||||
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon
|
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model
|
||||||
|
|
||||||
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
|
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
|
||||||
from dalle2_pytorch.utils import Timer
|
from dalle2_pytorch.utils import Timer, print_ribbon
|
||||||
|
|
||||||
from embedding_reader import EmbeddingReader
|
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
@@ -31,7 +29,7 @@ def exists(val):
|
|||||||
|
|
||||||
# functions
|
# functions
|
||||||
|
|
||||||
def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation"):
|
def eval_model(model, dataloader, text_conditioned, loss_type, device, phase="Validation",):
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -39,6 +37,8 @@ def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation
|
|||||||
total_samples = 0.
|
total_samples = 0.
|
||||||
|
|
||||||
for image_embeddings, text_data in tqdm(dataloader):
|
for image_embeddings, text_data in tqdm(dataloader):
|
||||||
|
image_embeddings = image_embeddings.to(device)
|
||||||
|
text_data = text_data.to(device)
|
||||||
|
|
||||||
batches = image_embeddings.shape[0]
|
batches = image_embeddings.shape[0]
|
||||||
|
|
||||||
@@ -57,12 +57,14 @@ def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation
|
|||||||
|
|
||||||
tracker.log({f'{phase} {loss_type}': avg_loss})
|
tracker.log({f'{phase} {loss_type}': avg_loss})
|
||||||
|
|
||||||
def report_cosine_sims(diffusion_prior, dataloader, text_conditioned):
|
def report_cosine_sims(diffusion_prior, dataloader, text_conditioned, device):
|
||||||
diffusion_prior.eval()
|
diffusion_prior.eval()
|
||||||
|
|
||||||
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
|
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
|
||||||
|
|
||||||
for test_image_embeddings, text_data in tqdm(dataloader):
|
for test_image_embeddings, text_data in tqdm(dataloader):
|
||||||
|
test_image_embeddings = test_image_embeddings.to(device)
|
||||||
|
text_data = text_data.to(device)
|
||||||
|
|
||||||
# we are text conditioned, we produce an embedding from the tokenized text
|
# we are text conditioned, we produce an embedding from the tokenized text
|
||||||
if text_conditioned:
|
if text_conditioned:
|
||||||
@@ -296,15 +298,31 @@ def train(
|
|||||||
|
|
||||||
# Utilize wrapper to abstract away loader logic
|
# Utilize wrapper to abstract away loader logic
|
||||||
print_ribbon("Downloading Embeddings")
|
print_ribbon("Downloading Embeddings")
|
||||||
loader_args = dict(text_conditioned=dp_condition_on_text_encodings, batch_size=batch_size, num_data_points=num_data_points,
|
reader_args = dict(text_conditioned=dp_condition_on_text_encodings, img_url=image_embed_url)
|
||||||
train_split=train_percent, eval_split=val_percent, device=device, img_url=image_embed_url)
|
|
||||||
|
|
||||||
if dp_condition_on_text_encodings:
|
if dp_condition_on_text_encodings:
|
||||||
loader_args = dict(**loader_args, meta_url=meta_url)
|
reader_args = dict(**reader_args, meta_url=meta_url)
|
||||||
|
img_reader = get_reader(**reader_args)
|
||||||
|
train_loader, eval_loader, test_loader = make_splits(
|
||||||
|
text_conditioned=dp_condition_on_text_encodings,
|
||||||
|
batch_size=batch_size,
|
||||||
|
num_data_points=num_data_points,
|
||||||
|
train_split=train_percent,
|
||||||
|
eval_split=val_percent,
|
||||||
|
image_reader=img_reader
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
loader_args = dict(**loader_args, txt_url=text_embed_url)
|
reader_args = dict(**reader_args, txt_url=text_embed_url)
|
||||||
|
img_reader, txt_reader = get_reader(**reader_args)
|
||||||
train_loader, eval_loader, test_loader = make_splits(**loader_args)
|
train_loader, eval_loader, test_loader = make_splits(
|
||||||
|
text_conditioned=dp_condition_on_text_encodings,
|
||||||
|
batch_size=batch_size,
|
||||||
|
num_data_points=num_data_points,
|
||||||
|
train_split=train_percent,
|
||||||
|
eval_split=val_percent,
|
||||||
|
image_reader=img_reader,
|
||||||
|
text_reader=txt_reader
|
||||||
|
)
|
||||||
|
|
||||||
### Training code ###
|
### Training code ###
|
||||||
|
|
||||||
@@ -315,9 +333,11 @@ def train(
|
|||||||
for _ in range(epochs):
|
for _ in range(epochs):
|
||||||
|
|
||||||
for image, text in tqdm(train_loader):
|
for image, text in tqdm(train_loader):
|
||||||
|
|
||||||
diffusion_prior.train()
|
diffusion_prior.train()
|
||||||
|
|
||||||
|
image = image.to(device)
|
||||||
|
text = text.to(device)
|
||||||
|
|
||||||
input_args = dict(image_embed=image)
|
input_args = dict(image_embed=image)
|
||||||
if dp_condition_on_text_encodings:
|
if dp_condition_on_text_encodings:
|
||||||
input_args = dict(**input_args, text = text)
|
input_args = dict(**input_args, text = text)
|
||||||
@@ -350,9 +370,9 @@ def train(
|
|||||||
# Use NUM_TEST_EMBEDDINGS samples from the test set each time
|
# Use NUM_TEST_EMBEDDINGS samples from the test set each time
|
||||||
# Get embeddings from the most recently saved model
|
# Get embeddings from the most recently saved model
|
||||||
if(step % REPORT_METRICS_EVERY) == 0:
|
if(step % REPORT_METRICS_EVERY) == 0:
|
||||||
report_cosine_sims(diffusion_prior, eval_loader, dp_condition_on_text_encodings)
|
report_cosine_sims(diffusion_prior, eval_loader, dp_condition_on_text_encodings, device=device)
|
||||||
### Evaluate model(validation run) ###
|
### Evaluate model(validation run) ###
|
||||||
eval_model(diffusion_prior, eval_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Validation")
|
eval_model(diffusion_prior, eval_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Validation", device=device)
|
||||||
|
|
||||||
step += 1
|
step += 1
|
||||||
trainer.update()
|
trainer.update()
|
||||||
|
|||||||
Reference in New Issue
Block a user