Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2026-02-12 11:34:29 +01:00
Compare commits
356 Commits
SHA1: 6161b61c55, 1ed0f9d80b, f326a95e26, d7a0a2ce4b, f23fab7ef7, 857b9fbf1e, 8864fd0aa7, 72bf159331, e5e47cfecb, fa533962bd,
276abf337b, ae42d03006, 4d346e98d9, 2b1fd1ad2e, 82a2ef37d9, 5c397c9d66, 0f4edff214, 501a8c7c46, 4e49373fc5, 49de72040c,
271a376eaf, e527002472, c12e067178, c6629c431a, 7ac2fc79f2, a1ef023193, d49eca62fa, 8aab69b91e, b432df2f7b, ebaa0d28c2,
8b0d459b25, 0064661729, b895f52843, 80497e9839, f526f14d7c, 8997f178d6, 022c94e443, 430961cb97, 721f9687c1, e0524a6aff,
c85e0d5c35, db0642c4cd, bb86ab2404, ae056dd67c, 033d6b0ce8, c7ea8748db, 13382885d9, c3d4a7ffe4, 164d9be444, 5562ec6be2,
89ff04cfe2, f4016f6302, 1212f7058d, dab106d4e5, bb151ca6b1, 4a59dea4cf, ecf9e8027d, 36c5079bd7, 4a4c7ac9e6, fad7481479,
123658d082, 11d4e11f10, 99778e12de, 0f0011caf0, 7b7a62044a, 156fe5ed9f, 5ec34bebe1, 8eaacf1ac1, e66c7b0249, f7cd4a0992,
68e7d2f241, 74f222596a, aa6772dcff, 71d0c4edae, f7eee09d8b, 89de5af63e, 4ec6d0ba81, aee92dba4a, b0cd5f24b6, b494ed81d4,
ff3474f05c, d5293f19f1, e697183849, 591d37e266, d1f02e8f49, 9faab59b23, 5d27029e98, 3115fa17b3, 124d8577c8, 2db0c9794c,
2277b47ffd, 28b58e568c, 924455d97d, 6021945fc8, 6f76652d11, 3dda2570ed, 2f3c02dba8, 908088cfea, 8dc8a3de0d, 35f89556ba,
2b55f753b9, fc8fce38fb, a1bfb03ba4, b1e7b5f6bb, 10b905b445, 9b322ea634, ba64ea45cc, 64f7be1926, db805e73e1, cb07b37970,
a774bfefe2, 2ae57f0cf5, e46eaec817, 8647cb5e76, 53c189e46a, dde51fd362, 2eac7996fa, 4010aec033, c87b84a259, 8b05468653,
830afd3c15, 8f93729d19, cd5f2c1de4, 85ed77d512, fd53fa17db, 3676ef4d49, 28e944f328, 14e63a3f67, 09e9eaa5a6, e6d752cf4a,
ad20a14a4d, 0be1e0d64c, 98df1ba51e, 878b555ef7, 63029f7388, c76a964fd6, 79fabc4341, f7ef4bde38, 93ba019069, 8518684ae9,
1d5dc08810, d8d8b6caf1, 15acc03bd4, 896f19786d, aec5575d09, 9773f10d6c, a6bf8ddef6, 86e692d24f, 97b751209f, 74103fd8d6,
1992d25cad, 5b619c2fd5, 9359ad2e91, 9ff228188b, 2d9963d30e, 58d9b422f3, 44b319cb57, c30f380689, e4e884bb8b, 803ad9c17d,
a88dd6a9c0, 72c16b496e, 81d83dd7f2, fa66f7e1e9, aa8d135245, 70282de23b, 83f761847e, 11469dc0c6, 2d25c89f35, 3fe96c208a,
0fc6c9cdf3, 7ee0ecc388, 1924c7cc3d, f7df3caaf3, fc954ee788, c1db2753f5, ad87bfe28f, 76c767b1ce, d991b8c39c, 902693e271,
35cd63982d, 53ce6dfdf6, ad8d7a368b, b8cf1e5c20, 94aaa08d97, 8b9bbec7d1, 1bb9fc9829, 5e421bd5bb, 67fcab1122, 5bfbccda22,
989275ff59, 56408f4a40, d1a697ac23, ebe01749ed, 63195cc2cb, a2ef69af66, 5fff22834e, a9421f49ec, 77fa34eae9, 1c1e508369,
f19c99ecb0, 721a444686, 63450b466d, 20e7eb5a9b, e2f9615afa, 0d1c07c803, a389f81138, 0283556608, 5063d192b6, f4a54e475e,
fb662a62f3, 587c8c9b44, aa900213e7, cb26187450, 625ce23f6b, dbf4a281f1, 4ab527e779, d0cdeb3247, 8c610aad9a, 6700381a37,
20377f889a, 6edb1c5dd0, b093f92182, fa3bb6ba5c, 2705e7c9b0, 77141882c8, 4075d02139, de0296106b, eafb136214, bfbcc283a3,
c30544b73a, bdf5e9c009, 9878be760b, 7ba6357c05, 76e063e8b7, 4d25976f33, 0b28ee0d01, 45262a4bb7, 13a58a78c4, f75d49c781,
3b520dfa85, 79198c6ae4, 77a246b1b9, f93a3f6ed8, 8f2a0c7e00, 863f4ef243, fb8a66a2de, 579d4b42dd, 473808850a, d5318aef4f,
f82917e1fd, 05b74be69a, a8b5d5d753, 976ef7f87c, fd175bcc0e, 76b32f18b3, f2d5b87677, 461347c171, 46cef31c86, 59b1a77d4d,
7f338319fd, 2c6c91829d, ad17c69ab6, 0b4ec34efb, f027b82e38, 8cc9016cb0, 1d8f37befe, faebf4c8b8, b8e8d3c164, 8e2416b49b,
f37c26e856, 27a33e1b20, 6f941a219a, ddde8ca1bf, c26b77ad20, c5b4aab8e5, a35c309b5f, 55bdcb98b9, 82328f16cd, 6fee4fce6e,
a54e309269, c6bfd7fdc8, 960a79857b, 7214df472d, 00ae50999b, 6cddefad26, 0332eaa6ff, 1cce4225eb, 5ab0700bab, b0f2fbaa95,
51361c2d15, 42d6e47387, 1e939153fb, 1abeb8918e, b423855483, c400d8758c, bece206699, 5b4ee09625, 6e27f617f1, 9f55c24db6,
69e822b7f8, 23c401a5d5, 68e9883f59, 95b018374a, 8b5c2385b0, f2c52d8239, 97e951221b, e1b0c140f1, 5989569a44, 82464d7bd3,
7fb3f695d5, 7e93b9d3c8, 4c827ba94f, cb3923a90f, cc30676a3f, c7fb327618, 14ddbc159c, 0692f1699f, 26c4534bc3, 5e06cde4cb,
a1a8a78f21, e5e415297c, c9377efc93, 2a424b6a28, d3cded3c6c, d573c82f8c, 3aa6f91e7a, 1bf071af78, 9f1fe6c7ae, 791d27326a,
6d4e9c97bf, 40140b54d6, 33d69d3859, 862e5ba50e, 25d980ebbf, d546a615c0, d4c8373635, c814b2b278, 74aec9d8ca, 7647be2569,
59b8abe09e, 46dde54948, 40aa304b7e, fd38eb83c4, 83aabd42ca, cf22affcbb, 522f42f582, 0a60818965, 604765b563, 7bbc62f3d5,
771fe0d0d2, de75a8af76, df4dac4f5a, 24b428bdfc, 2ab042b862, b93ad8b7a2
1 .github/FUNDING.yml (vendored, new file)
@@ -0,0 +1 @@
github: [lucidrains]
9 .gitignore (vendored)
@@ -1,3 +1,12 @@
# default experiment tracker data
.tracker-data/

# Configuration Files
configs/*
!configs/*.example
!configs/*_defaults.py
!configs/README.md

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
111 configs/README.md (new file)
@@ -0,0 +1,111 @@
## DALLE2 Training Configurations

For more complex configuration, we provide the option of using a configuration file instead of command line arguments.

### Decoder Trainer

The decoder trainer has 7 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.example.json).

**<ins>Unet</ins>:**

This is the configuration for a single unet. Unets are nested under the decoder config as a list under the `unets` key.

| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `dim` | Yes | N/A | The starting channels of the unet. |
| `image_embed_dim` | Yes | N/A | The dimension of the image embeddings. |
| `dim_mults` | No | `(1, 2, 4, 8)` | The growth factors of the channels. |

Any parameter from the `Unet` constructor can also be given here.

**<ins>Decoder</ins>:**

Defines the configuration options for the decoder model. The unets defined above will automatically be inserted.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `unets` | Yes | N/A | A list of unets, using the configuration above. |
| `image_sizes` | Yes | N/A | The resolution of the image after each upsampling step. The length of this array should equal the number of unets defined. |
| `image_size` | Yes | N/A | Not used. Can be any number. |
| `timesteps` | No | `1000` | The number of diffusion timesteps used for generation. |
| `loss_type` | No | `l2` | The loss function. Options are `l1`, `huber`, or `l2`. |
| `beta_schedule` | No | `cosine` | The noising schedule. Options are `cosine`, `linear`, `quadratic`, `jsd`, or `sigmoid`. |
| `learned_variance` | No | `True` | Whether to learn the variance. |

Any parameter from the `Decoder` constructor can also be given here.

**<ins>Data</ins>:**

Settings for creation of the dataloaders.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `webdataset_base_url` | Yes | N/A | The url of a shard in the webdataset with the shard number replaced with `{}`[^1]. |
| `embeddings_url` | No | N/A | The url of the folder containing embedding shards. Not required if embeddings are in the webdataset. |
| `num_workers` | No | `4` | The number of workers used in the dataloader. |
| `batch_size` | No | `64` | The batch size. |
| `start_shard` | No | `0` | Defines the start of the shard range the dataset will recall. |
| `end_shard` | No | `9999999` | Defines the end of the shard range the dataset will recall. |
| `shard_width` | No | `6` | Defines the width of one webdataset shard number[^2]. |
| `index_width` | No | `4` | Defines the width of the index of a file inside a shard[^3]. |
| `splits` | No | `{ "train": 0.75, "val": 0.15, "test": 0.1 }` | Defines the proportion of shards that will be allocated to the training, validation, and testing datasets. |
| `shuffle_train` | No | `True` | Whether to shuffle the shards of the training dataset. |
| `resample_train` | No | `False` | If true, shards will be randomly sampled with replacement from the datasets, making the epoch length infinite if a limit is not set. Cannot be enabled if `shuffle_train` is enabled. |
| `preprocessing` | No | `{ "ToTensor": True }` | Defines preprocessing applied to images from the datasets. |

[^1]: If your shard files have the paths `protocol://path/to/shard/00104.tar`, then the base url would be `protocol://path/to/shard/{}.tar`. If you are using a protocol like `s3`, you need to pipe the tars. For example `pipe:s3cmd get s3://bucket/path/{}.tar -`.

[^2]: This refers to the string length of the shard number for your webdataset shards. For instance, if your webdataset shard has the filename `00104.tar`, your shard width is 5.

[^3]: Inside the webdataset `tar`, you have files named something like `001045945.jpg`. 5 of these characters refer to the shard and 4 refer to the index of the file in the webdataset (the shard is `00104` and the index is `5945`). The `index_width` in this case is 4.

**<ins>Train</ins>:**

Settings for controlling the training hyperparameters.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `epochs` | No | `20` | The number of epochs in the training run. |
| `lr` | No | `1e-4` | The learning rate. |
| `wd` | No | `0.01` | The weight decay. |
| `max_grad_norm` | No | `0.5` | The gradient norm clipping value. |
| `save_every_n_samples` | No | `100000` | Samples will be generated and a checkpoint will be saved every `save_every_n_samples` samples. |
| `device` | No | `cuda:0` | The device to train on. |
| `epoch_samples` | No | `None` | Limits the number of samples iterated through in each epoch. This must be set if resampling. `None` means no limit. |
| `validation_samples` | No | `None` | The number of samples to use for validation. `None` means the entire validation set. |
| `use_ema` | No | `True` | Whether to use exponential moving average models for sampling. |
| `ema_beta` | No | `0.99` | The ema coefficient. |
| `save_all` | No | `False` | If True, preserves a checkpoint for every epoch. |
| `save_latest` | No | `True` | If True, overwrites `latest.pth` every time the model is saved. |
| `save_best` | No | `True` | If True, overwrites `best.pth` every time the model has a lower validation loss than all previous models. |
| `unet_training_mask` | No | `None` | A boolean array of the same length as the number of unets. If false, the unet is frozen. A value of `None` trains all unets. |

**<ins>Evaluate</ins>:**

Defines which evaluation metrics will be used to test the model.
Each metric can be enabled by setting its configuration. The configuration keys for each metric are defined by the torchmetrics constructors, which are linked below.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `n_evaluation_samples` | No | `1000` | The number of samples to generate to test the model. |
| `FID` | No | `None` | Setting to an object enables the [Frechet Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/frechet_inception_distance.html) metric. |
| `IS` | No | `None` | Setting to an object enables the [Inception Score](https://torchmetrics.readthedocs.io/en/stable/image/inception_score.html) metric. |
| `KID` | No | `None` | Setting to an object enables the [Kernel Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/kernel_inception_distance.html) metric. |
| `LPIPS` | No | `None` | Setting to an object enables the [Learned Perceptual Image Patch Similarity](https://torchmetrics.readthedocs.io/en/stable/image/learned_perceptual_image_patch_similarity.html) metric. |

**<ins>Tracker</ins>:**

Selects which tracker to use and configures it.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `tracker_type` | No | `console` | Which tracker to use. Currently accepts `console` or `wandb`. |
| `data_path` | No | `./models` | Where the tracker will store local data. |
| `verbose` | No | `False` | Enables console logging for non-console trackers. |

Other configuration options are required for the specific trackers. To see which are required, reference the initializer parameters of each [tracker](../dalle2_pytorch/trackers.py).

**<ins>Load</ins>:**

Selects where to load a pretrained model from.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `source` | No | `None` | Supports `file` or `wandb`. |
| `resume` | No | `False` | If the tracker supports resuming the run, resume it. |

Other configuration options are required for loading from a specific source. To see which are required, reference the load methods at the top of the [tracker file](../dalle2_pytorch/trackers.py).
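Once written, a config file like the one above is parsed into the pydantic models defined in [`train_configs.py`](../dalle2_pytorch/train_configs.py). A rough sketch of the flow, assuming `TrainDecoderConfig` exposes the same `from_json_path` helper that `TrainDiffusionPriorConfig` defines there:

```python
from dalle2_pytorch.train_configs import TrainDecoderConfig

# Parse and validate the JSON into nested pydantic configs
config = TrainDecoderConfig.from_json_path('configs/train_decoder_config.example.json')

decoder = config.decoder.create()       # instantiates the Unet(s) and the Decoder
img_preproc = config.data.img_preproc   # torchvision transform built from the `preprocessing` dict
```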
99 configs/train_decoder_config.example.json (new file)
@@ -0,0 +1,99 @@
{
    "decoder": {
        "unets": [
            {
                "dim": 128,
                "image_embed_dim": 768,
                "cond_dim": 64,
                "channels": 3,
                "dim_mults": [1, 2, 4, 8],
                "attn_dim_head": 32,
                "attn_heads": 16
            }
        ],
        "image_sizes": [64],
        "channels": 3,
        "timesteps": 1000,
        "loss_type": "l2",
        "beta_schedule": "cosine",
        "learned_variance": true
    },
    "data": {
        "webdataset_base_url": "pipe:s3cmd get s3://bucket/path/{}.tar -",
        "embeddings_url": "s3://bucket/embeddings/path/",
        "num_workers": 4,
        "batch_size": 64,
        "start_shard": 0,
        "end_shard": 9999999,
        "shard_width": 6,
        "index_width": 4,
        "splits": {
            "train": 0.75,
            "val": 0.15,
            "test": 0.1
        },
        "shuffle_train": true,
        "resample_train": false,
        "preprocessing": {
            "RandomResizedCrop": {
                "size": [128, 128],
                "scale": [0.75, 1.0],
                "ratio": [1.0, 1.0]
            },
            "ToTensor": true
        }
    },
    "train": {
        "epochs": 20,
        "lr": 1e-4,
        "wd": 0.01,
        "max_grad_norm": 0.5,
        "save_every_n_samples": 100000,
        "n_sample_images": 6,
        "device": "cuda:0",
        "epoch_samples": null,
        "validation_samples": null,
        "use_ema": true,
        "ema_beta": 0.99,
        "amp": false,
        "save_all": false,
        "save_latest": true,
        "save_best": true,
        "unet_training_mask": [true]
    },
    "evaluate": {
        "n_evaluation_samples": 1000,
        "FID": {
            "feature": 64
        },
        "IS": {
            "feature": 64,
            "splits": 10
        },
        "KID": {
            "feature": 64,
            "subset_size": 10
        },
        "LPIPS": {
            "net_type": "vgg",
            "reduction": "mean"
        }
    },
    "tracker": {
        "tracker_type": "console",
        "data_path": "./models",

        "wandb_entity": "",
        "wandb_project": "",

        "verbose": false
    },
    "load": {
        "source": null,

        "run_path": "",
        "file_path": "",

        "resume": false
    }
}
70 configs/train_prior_config.example.json (new file)
@@ -0,0 +1,70 @@
{
    "prior": {
        "clip": {
            "make": "x-clip",
            "model": "ViT-L/14",
            "base_model_kwargs": {
                "dim_text": 768,
                "dim_image": 768,
                "dim_latent": 768
            }
        },
        "net": {
            "dim": 768,
            "depth": 12,
            "num_timesteps": 1000,
            "num_time_embeds": 1,
            "num_image_embeds": 1,
            "num_text_embeds": 1,
            "dim_head": 64,
            "heads": 12,
            "ff_mult": 4,
            "norm_out": true,
            "attn_dropout": 0.0,
            "ff_dropout": 0.0,
            "final_proj": true,
            "normformer": true,
            "rotary_emb": true
        },
        "image_embed_dim": 768,
        "image_size": 224,
        "image_channels": 3,
        "timesteps": 1000,
        "cond_drop_prob": 0.1,
        "loss_type": "l2",
        "predict_x_start": true,
        "beta_schedule": "cosine",
        "condition_on_text_encodings": true
    },
    "data": {
        "image_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/",
        "text_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/",
        "meta_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/laion2B-en-metadata/",
        "batch_size": 256,
        "splits": {
            "train": 0.9,
            "val": 1e-7,
            "test": 0.0999999
        }
    },
    "train": {
        "epochs": 1,
        "lr": 1.1e-4,
        "wd": 6.02e-2,
        "max_grad_norm": 0.5,
        "use_ema": true,
        "amp": false,
        "save_every": 10000
    },
    "load": {
        "source": null,
        "resume": false
    },
    "tracker": {
        "tracker_type": "wandb",
        "data_path": "./prior_checkpoints",
        "wandb_entity": "laion",
        "wandb_project": "diffusion-prior",
        "verbose": true
    }
}
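This prior config maps onto the `TrainDiffusionPriorConfig` pydantic model that appears later in this diff (`dalle2_pytorch/train_configs.py`); a minimal loading sketch:

```python
from dalle2_pytorch.train_configs import TrainDiffusionPriorConfig

config = TrainDiffusionPriorConfig.from_json_path('configs/train_prior_config.example.json')
prior = config.prior.create()  # builds the clip adapter, the DiffusionPriorNetwork, and the DiffusionPrior
```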
dalle2_pytorch/__init__.py
@@ -1 +1,6 @@
-from dalle2_pytorch.dalle2_pytorch import DALLE2
+from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
+from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
+from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
+
+from dalle2_pytorch.vqgan_vae import VQGanVAE
+from x_clip import CLIP
52 dalle2_pytorch/cli.py (new file)
@@ -0,0 +1,52 @@
import click
import torch
import torchvision.transforms as T
from functools import reduce
from pathlib import Path

from dalle2_pytorch import DALLE2, Decoder, DiffusionPrior

def safeget(dictionary, keys, default = None):
    return reduce(lambda d, key: d.get(key, default) if isinstance(d, dict) else default, keys.split('.'), dictionary)

def simple_slugify(text, max_length = 255):
    return text.replace("-", "_").replace(",", "").replace(" ", "_").replace("|", "--").strip('-_')[:max_length]

def get_pkg_version():
    from pkg_resources import get_distribution
    return get_distribution('dalle2_pytorch').version

def main():
    pass

@click.command()
@click.option('--model', default = './dalle2.pt', help = 'path to trained DALL-E2 model')
@click.option('--cond_scale', default = 2, help = 'conditioning scale (classifier free guidance) in decoder')
@click.argument('text')
def dream(
    model,
    cond_scale,
    text
):
    model_path = Path(model)
    full_model_path = str(model_path.resolve())
    assert model_path.exists(), f'model not found at {full_model_path}'
    loaded = torch.load(str(model_path))

    version = safeget(loaded, 'version')
    print(f'loading DALL-E2 from {full_model_path}, saved at version {version} - current package version is {get_pkg_version()}')

    prior_init_params = safeget(loaded, 'init_params.prior')
    decoder_init_params = safeget(loaded, 'init_params.decoder')
    model_params = safeget(loaded, 'model_params')

    prior = DiffusionPrior(**prior_init_params)
    decoder = Decoder(**decoder_init_params)

    dalle2 = DALLE2(prior, decoder)
    dalle2.load_state_dict(model_params)

    image = dalle2(text, cond_scale = cond_scale)

    pil_image = T.ToPILImage()(image)
    return pil_image.save(f'./{simple_slugify(text)}.png')
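As an illustration only (not part of this diff), the `dream` command above can be exercised programmatically through click's test runner; the checkpoint path and prompt are placeholders, and a trained `./dalle2.pt` is required for it to actually produce an image:

```python
from click.testing import CliRunner
from dalle2_pytorch.cli import dream

runner = CliRunner()
# Placeholder arguments; `dream` loads the checkpoint, runs prior + decoder, and saves a png.
result = runner.invoke(dream, ['--model', './dalle2.pt', '--cond_scale', '2', 'a corgi wearing a party hat'])
print(result.output)
```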
File diff suppressed because it is too large
41 dalle2_pytorch/dataloaders/README.md (new file)
@@ -0,0 +1,41 @@
## Dataloaders
In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.

### Decoder: Image Embedding Dataset
When training the decoder (and upsamplers if training together) in isolation, you will need to load images and corresponding image embeddings. This dataset can read two similar types of datasets. First, it can read a [webdataset](https://github.com/webdataset/webdataset) that contains `.jpg` and `.npy` files in the `.tar`s, which contain the images and associated image embeddings respectively. Alternatively, you can also specify a source for the embeddings outside of the webdataset. In this case, the path to the embeddings should contain `.npy` files with the same shard numbers as the webdataset, and there should be a correspondence between the filename of the `.jpg` and the index of the embedding in the `.npy`. So, for example, `0001.tar` from the webdataset with image `00010509.jpg` (the first 4 digits are the shard number and the last 4 are the index) in it should be paralleled by an `img_emb_0001.npy` which contains a NumPy array with the embedding at index 509.

Generating a dataset of this type:
1. Use [img2dataset](https://github.com/rom1504/img2dataset) to generate a webdataset.
2. Use [clip-retrieval](https://github.com/rom1504/clip-retrieval) to convert the images to embeddings.
3. Use [embedding-dataset-reordering](https://github.com/Veldrovive/embedding-dataset-reordering) to reorder the embeddings into the expected format.

Usage:
```python
from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embedding_dataloader

# Create a dataloader directly.
dataloader = create_image_embedding_dataloader(
    tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses bracket expansion notation. This specifies to read all tars from 0000.tar to 9999.tar
    embeddings_url="path/or/url/to/embeddings/folder",     # Included if .npy files are not in webdataset. Left out or set to None otherwise
    num_workers=4,
    batch_size=32,
    index_width=3,         # If a file in the webdataset shard 3 is named 0003039.jpg, the shard width is 4 and the last three digits (039) are the index
    shuffle_num=200,       # Does a shuffle of the data with a buffer size of 200
    shuffle_shards=True,   # Shuffle the order the shards are read in
    resample_shards=False, # Sample shards with replacement. If true, an epoch will be infinite unless stopped manually
)
for img, emb in dataloader:
    print(img.shape)  # torch.Size([32, 3, 256, 256])
    print(emb.shape)  # torch.Size([32, 512])
    # Train decoder only as shown above

# Or create a dataset without a loader so you can configure it manually
dataset = ImageEmbeddingDataset(
    urls="/path/or/url/to/webdataset/{0000..9999}.tar",
    embedding_folder_url="path/or/url/to/embeddings/folder",
    index_width=3,
    shuffle_shards=True,
    resample=False
)
```
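To make the shard/index filename convention above concrete, a small sketch (the helper is ours, purely illustrative, not part of the package):

```python
# Hypothetical helper: split a webdataset key like "00010509" into shard and index.
def split_key(key: str, index_width: int):
    shard, index = key[:-index_width], key[-index_width:]
    return shard, int(index)

# "0001.tar" + image "00010509.jpg" -> embedding index 509 in img_emb_0001.npy
assert split_key("00010509", index_width=4) == ("0001", 509)
```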
2 dalle2_pytorch/dataloaders/__init__.py (new file)
@@ -0,0 +1,2 @@
from dalle2_pytorch.dataloaders.decoder_loader import ImageEmbeddingDataset, create_image_embedding_dataloader
from dalle2_pytorch.dataloaders.embedding_wrapper import make_splits
234 dalle2_pytorch/dataloaders/decoder_loader.py (new file)
@@ -0,0 +1,234 @@
import os
import webdataset as wds
import torch
import numpy as np
import fsspec
import shutil

def get_shard(filename):
    """
    Filenames with shards in them have a consistent structure that we can take advantage of
    Standard structure: path/to/file/prefix_string_00001.ext
    """
    try:
        return filename.split("_")[-1].split(".")[0]
    except ValueError:
        raise RuntimeError(f"Could not find shard for filename {filename}")

def get_example_file(fs, path, file_format):
    """
    Given a file system and a file extension, return the example file
    """
    return fs.glob(os.path.join(path, f"*.{file_format}"))[0]

def embedding_inserter(samples, embeddings_url, index_width, handler=wds.handlers.reraise_exception):
    """Given a datum of {"__key__": str, "__url__": str, ...} adds the corresponding embedding and yields"""
    previous_tar_url = None
    current_embeddings = None
    # Get a reference to an abstract file system where the embeddings are stored
    embeddings_fs, embeddings_path = fsspec.core.url_to_fs(embeddings_url)
    example_embedding_file = get_example_file(embeddings_fs, embeddings_path, "npy")
    example_embedding_shard = get_shard(example_embedding_file)
    emb_shard_width = len(example_embedding_shard)
    # Easier to get the basename without the shard once than search through for the correct file every time
    embedding_file_basename = '_'.join(example_embedding_file.split("_")[:-1]) + "_"

    def load_corresponding_embeds(tar_url):
        """Finds and reads the npy file that contains embeddings for the given webdataset tar"""
        shard = int(tar_url.split("/")[-1].split(".")[0])
        embedding_url = embedding_file_basename + str(shard).zfill(emb_shard_width) + '.npy'
        with embeddings_fs.open(embedding_url) as f:
            data = np.load(f)
        return torch.from_numpy(data)

    for sample in samples:
        try:
            tar_url = sample["__url__"]
            key = sample["__key__"]
            if tar_url != previous_tar_url:
                # If the tar changed, we need to download new embeddings
                # This means if we shuffle before inserting it will load many more files than we expect and be very inefficient.
                previous_tar_url = tar_url
                current_embeddings = load_corresponding_embeds(tar_url)

            embedding_index = int(key[-index_width:])
            embedding = current_embeddings[embedding_index]
            # We need to check that the embedding is nonzero. If it is all zeros, no real embedding was stored at this index and the sample is invalid.
            if torch.count_nonzero(embedding) == 0:
                raise RuntimeError(f"Webdataset had a sample, but no embedding was found. ImgShard: {key[:-index_width]} - Index: {key[-index_width:]}")
            sample["npy"] = embedding
            yield sample
        except Exception as exn:  # From wds implementation
            if handler(exn):
                continue
            else:
                break
insert_embedding = wds.filters.pipelinefilter(embedding_inserter)

def unassociated_shard_skipper(tarfiles, embeddings_url, handler=wds.handlers.reraise_exception):
    """Finds if there is a corresponding embedding file for the tarfile at { url: [URL] }"""
    embeddings_fs, embeddings_path = fsspec.core.url_to_fs(embeddings_url)
    embedding_files = embeddings_fs.ls(embeddings_path)
    get_embedding_shard = lambda embedding_file: int(embedding_file.split("_")[-1].split(".")[0])
    embedding_shards = set([get_embedding_shard(filename) for filename in embedding_files])  # Sets have O(1) membership check

    get_tar_shard = lambda tar_file: int(tar_file.split("/")[-1].split(".")[0])
    for tarfile in tarfiles:
        try:
            webdataset_shard = get_tar_shard(tarfile["url"])
            # If this shard has an associated embeddings file, we pass it through. Otherwise we iterate until we do have one
            if webdataset_shard in embedding_shards:
                yield tarfile
        except Exception as exn:  # From wds implementation
            if handler(exn):
                continue
            else:
                break

skip_unassociated_shards = wds.filters.pipelinefilter(unassociated_shard_skipper)

def verify_keys(samples, handler=wds.handlers.reraise_exception):
    """
    Requires that both the image and embedding are present in the sample
    This is important to do as a user may forget they do not have embeddings in their webdataset and neglect to add them using the embedding_folder_url parameter.
    """
    for sample in samples:
        try:
            assert "jpg" in sample, f"Sample {sample['__key__']} missing image"
            assert "npy" in sample, f"Sample {sample['__key__']} missing embedding. Did you set embedding_folder_url?"
            yield sample
        except Exception as exn:  # From wds implementation
            if handler(exn):
                continue
            else:
                break

class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
    """
    A fluid interface wrapper for DataPipeline that returns image embedding pairs
    Reads embeddings as npy files from the webdataset if they exist. If embedding_folder_url is set, they will be inserted from the alternate source.
    """

    def __init__(
        self,
        urls,
        embedding_folder_url=None,
        index_width=None,
        img_preproc=None,
        extra_keys=[],
        handler=wds.handlers.reraise_exception,
        resample=False,
        shuffle_shards=True
    ):
        """
        Modeled directly off of the WebDataset constructor

        :param urls: A url pointing to the tar files of the webdataset formatted as /path/to/webdataset/{0000..9999}.tar
        :param embedding_folder_url: Required if webdataset does not contain embeddings. A url pointing to the npy files of the embeddings. Should have the same number of shards as the webdataset.
            Webdataset image keys should align with the index of the embedding. This means missing image indices must have a corresponding embedding of all zeros.
        :param index_width: The number of digits in the index. This is used to align the embedding index with the image index.
            For example, if a file in the webdataset shard 3 is named 0003039.jpg, we know the shard is 4 digits and the last 3 digits are the index, so index_width is 3.
        :param img_preproc: This function is run on the img before it is batched and returned. Useful for data augmentation or converting to torch tensor.
        :param handler: A webdataset handler.
        :param resample: If true, resample webdataset shards with replacement. You need to set your own epoch size if this is true since it will resample infinitely.
        :param shuffle_shards: If true, shuffle the shards before sampling. This cannot be true if resample is true.
        """
        super().__init__()
        keys = ["jpg", "npy"] + extra_keys
        self.key_map = {key: i for i, key in enumerate(keys)}
        self.resampling = resample
        self.img_preproc = img_preproc
        # If using s3, check that s3cmd (for piping the webdataset tars) and s3fs (for fsspec) are installed
        if (isinstance(urls, str) and "s3:" in urls) or (isinstance(urls, list) and any(["s3:" in url for url in urls])):
            # Then this has an s3 link for the webdataset and we need extra packages
            if shutil.which("s3cmd") is None:
                raise RuntimeError("s3cmd is required for s3 webdataset")
        if embedding_folder_url is not None and "s3:" in embedding_folder_url:
            # Then the embeddings are being loaded from s3 and fsspec requires s3fs
            try:
                import s3fs
            except ImportError:
                raise RuntimeError("s3fs is required to load embeddings from s3")
        # Add the shardList and randomize or resample if requested
        if resample:
            assert not shuffle_shards, "Cannot both resample and shuffle"
            self.append(wds.ResampledShards(urls))
        else:
            self.append(wds.SimpleShardList(urls))
            if shuffle_shards:
                self.append(wds.filters.shuffle(1000))

        if embedding_folder_url is not None:
            # There may be webdataset shards that do not have an embedding shard associated with them. If we do not skip these, they would cause issues.
            self.append(skip_unassociated_shards(embeddings_url=embedding_folder_url, handler=handler))

        self.append(wds.split_by_node)
        self.append(wds.split_by_worker)

        self.append(wds.tarfile_to_samples(handler=handler))
        self.append(wds.decode("pilrgb", handler=handler))
        if embedding_folder_url is not None:
            # Then we are loading embeddings from a remote source
            assert index_width is not None, "Reading embeddings separately requires index width length to be given"
            self.append(insert_embedding(embeddings_url=embedding_folder_url, index_width=index_width, handler=handler))
        self.append(verify_keys)
        # Apply preprocessing
        self.append(wds.map(self.preproc))
        self.append(wds.to_tuple(*keys))

    def preproc(self, sample):
        """Applies the preprocessing for images"""
        if self.img_preproc is not None:
            sample["jpg"] = self.img_preproc(sample["jpg"])
        return sample

def create_image_embedding_dataloader(
    tar_url,
    num_workers,
    batch_size,
    embeddings_url=None,
    index_width=None,
    shuffle_num = None,
    shuffle_shards = True,
    resample_shards = False,
    img_preproc=None,
    extra_keys=[],
    handler=wds.handlers.reraise_exception  # or wds.handlers.warn_and_continue
):
    """
    Convenience function to create an image embedding dataset and dataloader in one line

    :param tar_url: A url pointing to the tar files of the webdataset formatted as /path/to/webdataset/{0000..9999}.tar
    :param num_workers: The number of workers to use for the dataloader
    :param batch_size: The batch size to use for the dataloader
    :param embeddings_url: Required if webdataset does not contain embeddings. A url pointing to the npy files of the embeddings. Should have the same number of shards as the webdataset.
        Webdataset image keys should align with the index of the embedding. This means missing image indices must have a corresponding embedding of all zeros.
    :param index_width: The number of digits in the index. This is used to align the embedding index with the image index.
        For example, if a file in the webdataset shard 3 is named 0003039.jpg, we know the shard is 4 digits and the last 3 digits are the index, so index_width is 3.
    :param shuffle_num: If not None, shuffle the dataset with this size buffer after sampling.
    :param shuffle_shards: If true, shuffle the shards before sampling. This cannot be true if resample is true.
    :param resample_shards: If true, resample webdataset shards with replacement. You need to set your own epoch size if this is true since it will resample infinitely.
    :param handler: A webdataset handler.
    """
    ds = ImageEmbeddingDataset(
        tar_url,
        embeddings_url,
        index_width=index_width,
        shuffle_shards=shuffle_shards,
        resample=resample_shards,
        extra_keys=extra_keys,
        img_preproc=img_preproc,
        handler=handler
    )
    if shuffle_num is not None and shuffle_num > 0:
        ds.shuffle(shuffle_num)
    return wds.WebLoader(
        ds,
        num_workers=num_workers,
        batch_size=batch_size,
        prefetch_factor=2,  # This might be good to have high so the next npy file is prefetched
        pin_memory=True,
        shuffle=False
    )
180 dalle2_pytorch/dataloaders/embedding_wrapper.py (new file)
@@ -0,0 +1,180 @@
from torch.utils.data import IterableDataset
from torch import from_numpy
from clip import tokenize
from embedding_reader import EmbeddingReader


class PriorEmbeddingLoader(IterableDataset):
    def __init__(
        self,
        text_conditioned: bool,
        batch_size: int,
        start: int,
        stop: int,
        image_reader,
        text_reader: EmbeddingReader = None,
        device: str = "cpu",
    ) -> None:
        super().__init__()

        self.text_conditioned = text_conditioned

        if not self.text_conditioned:
            self.text_reader = text_reader

        self.image_reader = image_reader
        self.batch_size = batch_size
        self.start = start
        self.stop = stop
        self.device = device

    def __iter__(self):
        self.n = 0
        loader_args = dict(
            batch_size=self.batch_size,
            start=self.start,
            end=self.stop,
            show_progress=False,
        )
        if self.text_conditioned:
            self.loader = self.image_reader(**loader_args)
        else:
            self.loader = zip(
                self.image_reader(**loader_args), self.text_reader(**loader_args)
            )
        return self

    def __next__(self):
        try:
            return self.get_sample()
        except StopIteration:
            raise StopIteration

    def get_sample(self):
        """
        pre-process data from either reader into a common format
        """
        self.n += 1

        if self.text_conditioned:
            image_embedding, caption = next(self.loader)

            image_embedding = from_numpy(image_embedding).to(self.device)
            tokenized_caption = tokenize(
                caption["caption"].to_list(), truncate=True
            ).to(self.device)

            return image_embedding, tokenized_caption

        else:
            (image_embedding, _), (text_embedding, _) = next(self.loader)

            image_embedding = from_numpy(image_embedding).to(self.device)
            text_embedding = from_numpy(text_embedding).to(self.device)

            return image_embedding, text_embedding


def make_splits(
    text_conditioned: bool,
    batch_size: int,
    num_data_points: int,
    train_split: float,
    eval_split: float,
    device: str,
    img_url: str,
    meta_url: str = None,
    txt_url: str = None,
):
    assert img_url is not None, "Must supply some image embeddings"

    if text_conditioned:
        assert meta_url is not None, "Must supply metadata url if text-conditioning"
        image_reader = EmbeddingReader(
            embeddings_folder=img_url,
            file_format="parquet_npy",
            meta_columns=["caption"],
            metadata_folder=meta_url,
        )

        # compute split points
        if num_data_points > image_reader.count:
            print("Specified point count is larger than the number of points available...defaulting to max length of reader.")
            num_data_points = image_reader.count

        train_set_size = int(train_split * num_data_points)
        eval_set_size = int(eval_split * num_data_points)
        eval_stop = int(train_set_size + eval_set_size)

        train_loader = PriorEmbeddingLoader(
            text_conditioned=text_conditioned,
            image_reader=image_reader,
            batch_size=batch_size,
            start=0,
            stop=train_set_size,
            device=device,
        )
        eval_loader = PriorEmbeddingLoader(
            text_conditioned=text_conditioned,
            image_reader=image_reader,
            batch_size=batch_size,
            start=train_set_size,
            stop=eval_stop,
            device=device,
        )
        test_loader = PriorEmbeddingLoader(
            text_conditioned=text_conditioned,
            image_reader=image_reader,
            batch_size=batch_size,
            start=eval_stop,
            stop=int(num_data_points),
            device=device,
        )

    else:
        assert (
            txt_url is not None
        ), "Must supply text embedding url if not text-conditioning"

        image_reader = EmbeddingReader(img_url, file_format="npy")
        text_reader = EmbeddingReader(txt_url, file_format="npy")

        # compute split points
        if num_data_points > image_reader.count:
            print("Specified point count is larger than the number of points available...defaulting to max length of reader.")
            num_data_points = image_reader.count

        train_set_size = int(train_split * num_data_points)
        eval_set_size = int(eval_split * num_data_points)
        eval_stop = int(train_set_size + eval_set_size)

        train_loader = PriorEmbeddingLoader(
            text_conditioned=text_conditioned,
            image_reader=image_reader,
            text_reader=text_reader,
            batch_size=batch_size,
            start=0,
            stop=train_set_size,
            device=device,
        )
        eval_loader = PriorEmbeddingLoader(
            text_conditioned=text_conditioned,
            image_reader=image_reader,
            text_reader=text_reader,
            batch_size=batch_size,
            start=train_set_size,
            stop=eval_stop,
            device=device,
        )
        test_loader = PriorEmbeddingLoader(
            text_conditioned=text_conditioned,
            image_reader=image_reader,
            text_reader=text_reader,
            batch_size=batch_size,
            start=eval_stop,
            stop=int(num_data_points),
            device=device,
        )

    return train_loader, eval_loader, test_loader
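A usage sketch for `make_splits` (the embedding folders are placeholders; `text_conditioned=False` assumes separate image and text embedding folders, as in the non-text-conditioned branch above):

```python
from dalle2_pytorch.dataloaders import make_splits

train_loader, eval_loader, test_loader = make_splits(
    text_conditioned=False,
    batch_size=256,
    num_data_points=1_000_000,   # clamped to the reader count if too large
    train_split=0.9,
    eval_split=0.05,             # the remaining 5% becomes the test split
    device="cpu",
    img_url="path/to/img_emb/",  # placeholder embedding folders
    txt_url="path/to/text_emb/",
)

for image_embedding, text_embedding in train_loader:
    break  # each item is a pre-batched tensor pair already moved to `device`
```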
59 dalle2_pytorch/dataloaders/simple_image_only_dataloader.py (new file)
@@ -0,0 +1,59 @@
from pathlib import Path

import torch
from torch.utils import data
from torchvision import transforms, utils

from PIL import Image

# helper functions

def cycle(dl):
    while True:
        for data in dl:
            yield data

# dataset and dataloader

class Dataset(data.Dataset):
    def __init__(
        self,
        folder,
        image_size,
        exts = ['jpg', 'jpeg', 'png']
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]

        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.CenterCrop(image_size),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(path)
        return self.transform(img)

def get_images_dataloader(
    folder,
    *,
    batch_size,
    image_size,
    shuffle = True,
    cycle_dl = True,
    pin_memory = True
):
    ds = Dataset(folder, image_size)
    dl = data.DataLoader(ds, batch_size = batch_size, shuffle = shuffle, pin_memory = pin_memory)

    if cycle_dl:
        dl = cycle(dl)
    return dl
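A quick usage sketch (the folder path is a placeholder); note that with the default `cycle_dl = True` the returned loader is an infinite iterator, which suits sample-count-based training loops:

```python
from dalle2_pytorch.dataloaders.simple_image_only_dataloader import get_images_dataloader

dl = get_images_dataloader('./path/to/images', batch_size = 16, image_size = 256)
images = next(dl)  # (16, 3, 256, 256) for RGB inputs, since the loader was cycled into a generator
```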
32 dalle2_pytorch/optimizer.py (new file)
@@ -0,0 +1,32 @@
from torch.optim import AdamW, Adam

def separate_weight_decayable_params(params):
    no_wd_params = set([param for param in params if param.ndim < 2])
    wd_params = set(params) - no_wd_params
    return wd_params, no_wd_params

def get_optimizer(
    params,
    lr = 1e-4,
    wd = 1e-2,
    betas = (0.9, 0.999),
    eps = 1e-8,
    filter_by_requires_grad = False,
    group_wd_params = True,
    **kwargs
):
    if filter_by_requires_grad:
        params = list(filter(lambda t: t.requires_grad, params))

    if wd == 0:
        return Adam(params, lr = lr, betas = betas, eps = eps)

    if group_wd_params:
        wd_params, no_wd_params = separate_weight_decayable_params(params)

        params = [
            {'params': list(wd_params)},
            {'params': list(no_wd_params), 'weight_decay': 0},
        ]

    return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
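A usage sketch with a placeholder model. Passing a list (rather than the raw `model.parameters()` generator) matters here because `separate_weight_decayable_params` iterates over `params` twice; grouping puts biases and other one-dimensional parameters in a `weight_decay = 0` group so only matrices are decayed:

```python
import torch.nn as nn
from dalle2_pytorch.optimizer import get_optimizer

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))  # placeholder model
opt = get_optimizer(list(model.parameters()), lr = 1e-4, wd = 1e-2)    # AdamW with two param groups
```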
115 dalle2_pytorch/trackers.py (new file)
@@ -0,0 +1,115 @@
import os
from pathlib import Path
import importlib
from itertools import zip_longest

import torch
from torch import nn

# constants

DEFAULT_DATA_PATH = './.tracker-data'

# helper functions

def exists(val):
    return val is not None

def import_or_print_error(pkg_name, err_str = None):
    try:
        return importlib.import_module(pkg_name)
    except ModuleNotFoundError as e:
        if exists(err_str):
            print(err_str)
        exit()

# load state dict functions

def load_wandb_state_dict(run_path, file_path, **kwargs):
    wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function')
    file_reference = wandb.restore(file_path, run_path=run_path)
    return torch.load(file_reference.name)

def load_local_state_dict(file_path, **kwargs):
    return torch.load(file_path)

# base class

class BaseTracker(nn.Module):
    def __init__(self, data_path = DEFAULT_DATA_PATH):
        super().__init__()
        self.data_path = Path(data_path)
        self.data_path.mkdir(parents = True, exist_ok = True)

    def init(self, config, **kwargs):
        raise NotImplementedError

    def log(self, log, **kwargs):
        raise NotImplementedError

    def log_images(self, images, **kwargs):
        raise NotImplementedError

    def save_state_dict(self, state_dict, relative_path, **kwargs):
        raise NotImplementedError

    def recall_state_dict(self, recall_source, *args, **kwargs):
        """
        Loads a state dict from any source.
        Since a user may wish to load a model from a different source than their own tracker (i.e. tracking using wandb but recalling from disk),
        this should not be linked to any individual tracker.
        """
        # TODO: Pull this into a dict or something similar so that we can add more sources without having a massive switch statement
        if recall_source == 'wandb':
            return load_wandb_state_dict(*args, **kwargs)
        elif recall_source == 'local':
            return load_local_state_dict(*args, **kwargs)
        else:
            raise ValueError('`recall_source` must be one of `wandb` or `local`')

# basic stdout class

class ConsoleTracker(BaseTracker):
    def init(self, **config):
        print(config)

    def log(self, log, **kwargs):
        print(log)

    def log_images(self, images, **kwargs):  # noop for logging images
        pass

    def save_state_dict(self, state_dict, relative_path, **kwargs):
        torch.save(state_dict, str(self.data_path / relative_path))

# basic wandb class

class WandbTracker(BaseTracker):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb experiment tracker')
        os.environ["WANDB_SILENT"] = "true"

    def init(self, **config):
        self.wandb.init(**config)

    def log(self, log, verbose=False, **kwargs):
        if verbose:
            print(log)
        self.wandb.log(log, **kwargs)

    def log_images(self, images, captions=[], image_section="images", **kwargs):
        """
        Takes a tensor of images and a list of captions and logs them to wandb.
        """
        wandb_images = [self.wandb.Image(image, caption=caption) for image, caption in zip_longest(images, captions)]
        self.log({ image_section: wandb_images }, **kwargs)

    def save_state_dict(self, state_dict, relative_path, **kwargs):
        """
        Saves a state_dict to disk and uploads it
        """
        full_path = str(self.data_path / relative_path)
        torch.save(state_dict, full_path)
        self.wandb.save(full_path, base_path = str(self.data_path))  # Upload and keep relative to data_path
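A minimal sketch of the tracker interface defined above (metric values are placeholders):

```python
from dalle2_pytorch.trackers import ConsoleTracker

tracker = ConsoleTracker(data_path = './models')
tracker.init(project = 'example-run')                   # ConsoleTracker simply prints the config
tracker.log({ 'loss': 0.123, 'step': 1 })               # placeholder metrics
tracker.save_state_dict({ 'model': {} }, 'latest.pth')  # writes ./models/latest.pth
```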
259 dalle2_pytorch/train_configs.py (new file)
@@ -0,0 +1,259 @@
|
||||
import json
|
||||
from torchvision import transforms as T
|
||||
from pydantic import BaseModel, validator, root_validator
|
||||
from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
|
||||
|
||||
from x_clip import CLIP as XCLIP
|
||||
from coca_pytorch import CoCa
|
||||
|
||||
from dalle2_pytorch.dalle2_pytorch import (
|
||||
CoCaAdapter,
|
||||
OpenAIClipAdapter,
|
||||
Unet,
|
||||
Decoder,
|
||||
DiffusionPrior,
|
||||
DiffusionPriorNetwork,
|
||||
XClipAdapter,
|
||||
)
|
||||
|
||||
# helper functions
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
def ListOrTuple(inner_type):
|
||||
return Union[List[inner_type], Tuple[inner_type]]
|
||||
|
||||
# general pydantic classes
|
||||
|
||||
class TrainSplitConfig(BaseModel):
|
||||
train: float = 0.75
|
||||
val: float = 0.15
|
||||
test: float = 0.1
|
||||
|
||||
@root_validator
|
||||
def validate_all(cls, fields):
|
||||
actual_sum = sum([*fields.values()])
|
||||
if actual_sum != 1.:
|
||||
raise ValueError(f'{fields.keys()} must sum to 1.0. Found: {actual_sum}')
|
||||
return fields
|
||||
|
||||
class TrackerConfig(BaseModel):
|
||||
tracker_type: str = 'console' # Decoder currently supports console and wandb
|
||||
data_path: str = './models' # The path where files will be saved locally
|
||||
init_config: Dict[str, Any] = None
|
||||
wandb_entity: str = '' # Only needs to be set if tracker_type is wandb
|
||||
wandb_project: str = ''
|
||||
verbose: bool = False # Whether to print console logging for non-console trackers
|
||||
|
||||
# diffusion prior pydantic classes
|
||||
|
||||
class AdapterConfig(BaseModel):
|
||||
make: str = "openai"
|
||||
model: str = "ViT-L/14"
|
||||
base_model_kwargs: Dict[str, Any] = None
|
||||
|
||||
def create(self):
|
||||
if self.make == "openai":
|
||||
return OpenAIClipAdapter(self.model)
|
||||
elif self.make == "x-clip":
|
||||
return XClipAdapter(XCLIP(**self.base_model_kwargs))
|
||||
elif self.make == "coca":
|
||||
return CoCaAdapter(CoCa(**self.base_model_kwargs))
|
||||
else:
|
||||
raise AttributeError("No adapter with that name is available.")
|
||||
|
||||
class DiffusionPriorNetworkConfig(BaseModel):
|
||||
dim: int
|
||||
depth: int
|
||||
num_timesteps: int = None
|
||||
num_time_embeds: int = 1
|
||||
num_image_embeds: int = 1
|
||||
num_text_embeds: int = 1
|
||||
dim_head: int = 64
|
||||
heads: int = 8
|
||||
ff_mult: int = 4
|
||||
norm_out: bool = True
|
||||
attn_dropout: float = 0.
|
||||
ff_dropout: float = 0.
|
||||
final_proj: bool = True
|
||||
normformer: bool = False
|
||||
rotary_emb: bool = True
|
||||
|
||||
def create(self):
|
||||
kwargs = self.dict()
|
||||
return DiffusionPriorNetwork(**kwargs)
|
||||
|
||||
class DiffusionPriorConfig(BaseModel):
|
||||
clip: AdapterConfig
|
||||
net: DiffusionPriorNetworkConfig
|
||||
image_embed_dim: int
|
||||
image_size: int
|
||||
image_channels: int = 3
|
||||
timesteps: int = 1000
|
||||
cond_drop_prob: float = 0.
|
||||
loss_type: str = 'l2'
|
||||
predict_x_start: bool = True
|
||||
beta_schedule: str = 'cosine'
|
||||
condition_on_text_encodings: bool = True
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
def create(self):
|
||||
kwargs = self.dict()
|
||||
clip = AdapterConfig(**kwargs.pop('clip')).create()
|
||||
diffusion_prior_network = DiffusionPriorNetworkConfig(**kwargs.pop('net')).create()
|
||||
return DiffusionPrior(net = diffusion_prior_network, clip=clip, **kwargs)
|
||||
|
||||
class DiffusionPriorTrainConfig(BaseModel):
|
||||
epochs: int = 1
|
||||
lr: float = 1.1e-4
|
||||
wd: float = 6.02e-2
|
||||
max_grad_norm: float = 0.5
|
||||
use_ema: bool = True
|
||||
ema_beta: float = 0.99
|
||||
amp: bool = False
|
||||
save_every: int = 10000 # what steps to save on
|
||||
|
||||
class DiffusionPriorDataConfig(BaseModel):
|
||||
image_url: str # path to embeddings folder
|
||||
meta_url: str # path to metadata (captions) for images
|
||||
splits: TrainSplitConfig
|
||||
batch_size: int = 64
|
||||
|
||||
class DiffusionPriorLoadConfig(BaseModel):
|
||||
source: str = None
|
||||
resume: bool = False
|
||||
|
||||
class TrainDiffusionPriorConfig(BaseModel):
|
||||
prior: DiffusionPriorConfig
|
||||
data: DiffusionPriorDataConfig
|
||||
train: DiffusionPriorTrainConfig
|
||||
load: DiffusionPriorLoadConfig
|
||||
tracker: TrackerConfig
|
||||
|
||||
@classmethod
|
||||
def from_json_path(cls, json_path):
|
||||
with open(json_path) as f:
|
||||
config = json.load(f)
|
||||
return cls(**config)

# decoder pydantic classes

class UnetConfig(BaseModel):
    dim: int
    dim_mults: ListOrTuple(int)
    image_embed_dim: int = None
    cond_dim: int = None
    channels: int = 3
    attn_dim_head: int = 32
    attn_heads: int = 16

    class Config:
        extra = "allow"

class DecoderConfig(BaseModel):
    unets: ListOrTuple(UnetConfig)
    image_size: int = None
    image_sizes: ListOrTuple(int) = None
    channels: int = 3
    timesteps: int = 1000
    loss_type: str = 'l2'
    beta_schedule: str = 'cosine'
    learned_variance: bool = True
    image_cond_drop_prob: float = 0.1
    text_cond_drop_prob: float = 0.5

    def create(self):
        decoder_kwargs = self.dict()
        unet_configs = decoder_kwargs.pop('unets')
        unets = [Unet(**config) for config in unet_configs]
        return Decoder(unets, **decoder_kwargs)

    @validator('image_sizes')
    def check_image_sizes(cls, image_sizes, values):
        if exists(values.get('image_size')) ^ exists(image_sizes):
            return image_sizes
        raise ValueError('either image_size or image_sizes is required, but not both')

    class Config:
        extra = "allow"

class DecoderDataConfig(BaseModel):
    webdataset_base_url: str # path to a webdataset with jpg images
    embeddings_url: str      # path to .npy files with embeddings
    num_workers: int = 4
    batch_size: int = 64
    start_shard: int = 0
    end_shard: int = 9999999
    shard_width: int = 6
    index_width: int = 4
    splits: TrainSplitConfig
    shuffle_train: bool = True
    resample_train: bool = False
    preprocessing: Dict[str, Any] = {'ToTensor': True}

    @property
    def img_preproc(self):
        def _get_transformation(transformation_name, **kwargs):
            if transformation_name == "RandomResizedCrop":
                return T.RandomResizedCrop(**kwargs)
            elif transformation_name == "RandomHorizontalFlip":
                return T.RandomHorizontalFlip()
            elif transformation_name == "ToTensor":
                return T.ToTensor()
            else:
                # fail loudly on unknown names instead of silently handing None to T.Compose
                raise ValueError(f'unknown transformation name {transformation_name}')

        transforms = []
        for transform_name, transform_kwargs_or_bool in self.preprocessing.items():
            transform_kwargs = {} if not isinstance(transform_kwargs_or_bool, dict) else transform_kwargs_or_bool
            transforms.append(_get_transformation(transform_name, **transform_kwargs))
        return T.Compose(transforms)

class DecoderTrainConfig(BaseModel):
    epochs: int = 20
    lr: float = 1e-4
    wd: float = 0.01
    max_grad_norm: float = 0.5
    save_every_n_samples: int = 100000
    n_sample_images: int = 6       # The number of example images to produce when sampling the train and test dataset
    device: str = 'cuda:0'
    epoch_samples: int = None      # Limits the number of samples per epoch. None means no limit. Required if resample_train is true, since otherwise the number of samples per epoch is infinite.
    validation_samples: int = None # Same as above, but for validation.
    use_ema: bool = True
    ema_beta: float = 0.99
    amp: bool = False
    save_all: bool = False         # Whether to preserve all checkpoints
    save_latest: bool = True       # Whether to always save the latest checkpoint
    save_best: bool = True         # Whether to save the best checkpoint
    unet_training_mask: ListOrTuple(bool) = None # If None, use all unets

class DecoderEvaluateConfig(BaseModel):
    n_evaluation_samples: int = 1000
    FID: Dict[str, Any] = None
    IS: Dict[str, Any] = None
    KID: Dict[str, Any] = None
    LPIPS: Dict[str, Any] = None

class DecoderLoadConfig(BaseModel):
    source: str = None   # Supports file and wandb
    run_path: str = ''   # Used only if source is wandb
    file_path: str = ''  # The local filepath if source is file. If source is wandb, the relative path to the model file in wandb.
    resume: bool = False # If using wandb, whether to resume the run

class TrainDecoderConfig(BaseModel):
    decoder: DecoderConfig
    data: DecoderDataConfig
    train: DecoderTrainConfig
    evaluate: DecoderEvaluateConfig
    tracker: TrackerConfig
    load: DecoderLoadConfig

    @classmethod
    def from_json_path(cls, json_path):
        with open(json_path) as f:
            config = json.load(f)
        return cls(**config)
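
# Example (illustrative, not part of the file): loading the full decoder training
# config from disk. 'decoder_config.json' is a hypothetical path whose top-level
# keys mirror the fields of TrainDecoderConfig above.
#
#   config = TrainDecoderConfig.from_json_path('decoder_config.json')
#   decoder = config.decoder.create()
#   img_transform = config.data.img_preproc # torchvision pipeline built from `preprocessing`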

dalle2_pytorch/trainer.py (new file, 591 lines)
@@ -0,0 +1,591 @@
import time
import copy
from pathlib import Path
from math import ceil
from functools import partial, wraps
from collections.abc import Iterable

import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler

# DiffusionPriorNetwork and print_ribbon are used below but were missing from the original import list
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior, DiffusionPriorNetwork
from dalle2_pytorch.optimizer import get_optimizer
from dalle2_pytorch.utils import print_ribbon

import numpy as np

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def cast_tuple(val, length = 1):
    return val if isinstance(val, tuple) else ((val,) * length)

def pick_and_pop(keys, d):
    values = list(map(lambda key: d.pop(key), keys))
    return dict(zip(keys, values))

def group_dict_by_key(cond, d):
    return_val = [dict(), dict()]
    for key in d.keys():
        match = bool(cond(key))
        ind = int(not match)
        return_val[ind][key] = d[key]
    return (*return_val,)

def string_begins_with(prefix, str):
    return str.startswith(prefix)

def group_by_key_prefix(prefix, d):
    return group_dict_by_key(partial(string_begins_with, prefix), d)

def groupby_prefix_and_trim(prefix, d):
    kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
    return kwargs_without_prefix, kwargs
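
# Example (illustrative): groupby_prefix_and_trim splits keyword arguments by prefix
# and strips it, e.g.
#   groupby_prefix_and_trim('ema_', dict(ema_beta = 0.99, lr = 3e-4))
#   # -> ({'beta': 0.99}, {'lr': 3e-4})
# This is how the trainers below route 'ema_*' kwargs to the EMA wrapper.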

def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr

def get_pkg_version():
    from pkg_resources import get_distribution
    return get_distribution('dalle2_pytorch').version

# decorators

def cast_torch_tensor(fn):
    @wraps(fn)
    def inner(model, *args, **kwargs):
        device = kwargs.pop('_device', next(model.parameters()).device)
        cast_device = kwargs.pop('_cast_device', True)

        kwargs_keys = kwargs.keys()
        all_args = (*args, *kwargs.values())
        split_kwargs_index = len(all_args) - len(kwargs_keys)
        all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args))

        if cast_device:
            all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))

        args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
        kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))

        out = fn(model, *args, **kwargs)
        return out
    return inner

# gradient accumulation functions

def split_iterable(it, split_size):
    accum = []
    for ind in range(ceil(len(it) / split_size)):
        start_index = ind * split_size
        accum.append(it[start_index: (start_index + split_size)])
    return accum

def split(t, split_size = None):
    if not exists(split_size):
        return t

    if isinstance(t, torch.Tensor):
        return t.split(split_size, dim = 0)

    if isinstance(t, Iterable):
        return split_iterable(t, split_size)

    raise TypeError # was `return TypeError`, which silently handed back the exception class instead of raising

def find_first(cond, arr):
    for el in arr:
        if cond(el):
            return el
    return None

def split_args_and_kwargs(*args, split_size = None, **kwargs):
    all_args = (*args, *kwargs.values())
    len_all_args = len(all_args)
    first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args)
    assert exists(first_tensor)

    batch_size = len(first_tensor)
    split_size = default(split_size, batch_size)
    num_chunks = ceil(batch_size / split_size)

    dict_len = len(kwargs)
    dict_keys = kwargs.keys()
    split_kwargs_index = len_all_args - dict_len

    split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args]
    chunk_sizes = tuple(map(len, split_all_args[0]))

    for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
        chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:]
        chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
        chunk_size_frac = chunk_size / batch_size
        yield chunk_size_frac, (chunked_args, chunked_kwargs)
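
# Example (illustrative): chunking a batch of 10 into sub-batches of 4 for gradient
# accumulation. Each yielded fraction is chunk_size / batch_size, so losses weighted
# by it sum to the full-batch loss.
#
#   x = torch.randn(10, 512)
#   for frac, (args, kwargs) in split_args_and_kwargs(x, split_size = 4):
#       sub_batch, = args # shapes: (4, 512), (4, 512), (2, 512)
#       ...               # frac is 0.4, 0.4, 0.2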

# saving and loading functions

# for diffusion prior

def load_diffusion_model(dprior_path, device):
    dprior_path = Path(dprior_path)
    assert dprior_path.exists(), 'Dprior model file does not exist'
    loaded_obj = torch.load(str(dprior_path), map_location = 'cpu')

    # get hyperparameters of the loaded model
    dpn_config = loaded_obj['hparams']['diffusion_prior_network']
    dp_config = loaded_obj['hparams']['diffusion_prior']
    image_embed_dim = loaded_obj['image_embed_dim']['image_embed_dim']

    # create DiffusionPriorNetwork and DiffusionPrior with loaded hyperparameters

    # DiffusionPriorNetwork
    prior_network = DiffusionPriorNetwork(dim = image_embed_dim, **dpn_config).to(device)

    # DiffusionPrior with text embeddings and image embeddings pre-computed
    diffusion_prior = DiffusionPrior(net = prior_network, **dp_config, image_embed_dim = image_embed_dim).to(device)

    # load state dict from the saved model
    diffusion_prior.load_state_dict(loaded_obj['model'])

    return diffusion_prior, loaded_obj

def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim):
    # saving state dict
    print(print_ribbon('Saving checkpoint')) # print_ribbon returns the banner string, so it must be printed explicitly

    state_dict = dict(
        model = model.state_dict(),
        optimizer = optimizer.state_dict(),
        scaler = scaler.state_dict(),
        hparams = config,
        image_embed_dim = {"image_embed_dim": image_embed_dim}
    )
    torch.save(state_dict, save_path + '/' + str(time.time()) + '_saved_model.pth')

# exponential moving average wrapper

class EMA(nn.Module):
    def __init__(
        self,
        model,
        beta = 0.9999,
        update_after_step = 1000,
        update_every = 10,
    ):
        super().__init__()
        self.beta = beta
        self.online_model = model
        self.ema_model = copy.deepcopy(model)

        self.update_every = update_every
        self.update_after_step = update_after_step // update_every # only start EMA after this step number, starting at 0

        self.register_buffer('initted', torch.Tensor([False]))
        self.register_buffer('step', torch.tensor([0]))

    def restore_ema_model_device(self):
        device = self.initted.device
        self.ema_model.to(device)

    def copy_params_from_model_to_ema(self):
        self.ema_model.load_state_dict(self.online_model.state_dict()) # was `state_dict(...)`, which copies nothing

    def update(self):
        self.step += 1

        if (self.step % self.update_every) != 0:
            return

        if self.step <= self.update_after_step:
            self.copy_params_from_model_to_ema()
            return

        if not self.initted:
            self.copy_params_from_model_to_ema()
            self.initted.data.copy_(torch.Tensor([True]))

        self.update_moving_average(self.ema_model, self.online_model)

    def update_moving_average(self, ma_model, current_model):
        def calculate_ema(beta, old, new):
            if not exists(old):
                return new
            return old * beta + (1 - beta) * new

        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = calculate_ema(self.beta, old_weight, up_weight)

        for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
            new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer)
            ma_buffer.copy_(new_buffer_value)

    def __call__(self, *args, **kwargs):
        return self.ema_model(*args, **kwargs)
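
# Example (illustrative): maintaining an EMA copy of any nn.Module. For the first
# update_after_step optimizer steps the EMA weights simply track the online weights,
# after which they decay toward them with the given beta.
#
#   net = nn.Linear(10, 10)
#   ema = EMA(net, beta = 0.99, update_after_step = 100, update_every = 1)
#   # after each optimizer.step():
#   ema.update()
#   # evaluate with the smoothed weights:
#   out = ema(torch.randn(1, 10))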

# diffusion prior trainer

def prior_sample_in_chunks(fn):
    @wraps(fn)
    def inner(self, *args, max_batch_size = None, **kwargs):
        if not exists(max_batch_size):
            return fn(self, *args, **kwargs)

        outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]
        return torch.cat(outputs, dim = 0)
    return inner

class DiffusionPriorTrainer(nn.Module):
    def __init__(
        self,
        diffusion_prior,
        use_ema = True,
        lr = 3e-4,
        wd = 1e-2,
        eps = 1e-6,
        max_grad_norm = None,
        amp = False,
        group_wd_params = True,
        **kwargs
    ):
        super().__init__()
        assert isinstance(diffusion_prior, DiffusionPrior)
        ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)

        self.diffusion_prior = diffusion_prior

        # exponential moving average

        self.use_ema = use_ema
        if self.use_ema:
            self.ema_diffusion_prior = EMA(diffusion_prior, **ema_kwargs)

        # optimizer and mixed precision stuff

        self.amp = amp

        self.scaler = GradScaler(enabled = amp)

        self.optimizer = get_optimizer(
            diffusion_prior.parameters(),
            lr = lr,
            wd = wd,
            eps = eps,
            group_wd_params = group_wd_params,
            **kwargs
        )

        # gradient clipping if needed

        self.max_grad_norm = max_grad_norm

        self.register_buffer('step', torch.tensor([0]))

    def save(self, path, overwrite = True, **kwargs):
        path = Path(path)
        assert not (path.exists() and not overwrite)
        path.parent.mkdir(parents = True, exist_ok = True)

        save_obj = dict(
            scaler = self.scaler.state_dict(),
            optimizer = self.optimizer.state_dict(),
            model = self.diffusion_prior.state_dict(),
            version = get_pkg_version(),
            step = self.step.item(),
            **kwargs
        )

        if self.use_ema:
            save_obj = {**save_obj, 'ema': self.ema_diffusion_prior.state_dict()}

        torch.save(save_obj, str(path))

    def load(self, path, only_model = False, strict = True):
        path = Path(path)
        assert path.exists()

        loaded_obj = torch.load(str(path))

        if get_pkg_version() != loaded_obj['version']:
            print(f'loading saved diffusion prior at version {loaded_obj["version"]}, but current package version is {get_pkg_version()}')

        self.diffusion_prior.load_state_dict(loaded_obj['model'], strict = strict)
        self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])

        if only_model:
            return loaded_obj

        self.scaler.load_state_dict(loaded_obj['scaler'])
        self.optimizer.load_state_dict(loaded_obj['optimizer'])

        if self.use_ema:
            assert 'ema' in loaded_obj
            self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)

        return loaded_obj

    def update(self):
        if exists(self.max_grad_norm):
            self.scaler.unscale_(self.optimizer)
            nn.utils.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)

        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()

        if self.use_ema:
            self.ema_diffusion_prior.update()

        self.step += 1

    @torch.no_grad()
    @cast_torch_tensor
    @prior_sample_in_chunks
    def p_sample_loop(self, *args, **kwargs):
        return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)

    @torch.no_grad()
    @cast_torch_tensor
    @prior_sample_in_chunks
    def sample(self, *args, **kwargs):
        return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)

    @torch.no_grad()
    def sample_batch_size(self, *args, **kwargs):
        return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)

    @cast_torch_tensor
    def forward(
        self,
        *args,
        max_batch_size = None,
        **kwargs
    ):
        total_loss = 0.

        for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
            with autocast(enabled = self.amp):
                loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
                loss = loss * chunk_size_frac

            total_loss += loss.item()

            if self.training:
                self.scaler.scale(loss).backward()

        return total_loss
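
# Example (illustrative): one training step with gradient accumulation. The
# positional tensors are placeholders; the exact arguments are forwarded to
# DiffusionPrior.forward, whose signature is defined in dalle2_pytorch.py.
#
#   trainer = DiffusionPriorTrainer(diffusion_prior, lr = 3e-4, use_ema = True)
#   loss = trainer(text_embeds, image_embeds, max_batch_size = 32) # accumulates grads in chunks
#   trainer.update()                                               # optimizer step + EMA update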

# decoder trainer

def decoder_sample_in_chunks(fn):
    @wraps(fn)
    def inner(self, *args, max_batch_size = None, **kwargs):
        if not exists(max_batch_size):
            return fn(self, *args, **kwargs)

        if self.decoder.unconditional:
            batch_size = kwargs.get('batch_size')
            batch_sizes = num_to_groups(batch_size, max_batch_size)
            outputs = [fn(self, *args, **{**kwargs, 'batch_size': sub_batch_size}) for sub_batch_size in batch_sizes]
        else:
            outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]

        return torch.cat(outputs, dim = 0)
    return inner

class DecoderTrainer(nn.Module):
    def __init__(
        self,
        decoder,
        use_ema = True,
        lr = 1e-4,
        wd = 1e-2,
        eps = 1e-8,
        max_grad_norm = 0.5,
        amp = False,
        group_wd_params = True,
        **kwargs
    ):
        super().__init__()
        assert isinstance(decoder, Decoder)
        ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)

        self.decoder = decoder
        self.num_unets = len(self.decoder.unets)

        self.use_ema = use_ema
        self.ema_unets = nn.ModuleList([])

        self.amp = amp

        # be able to finely customize learning rate, weight decay
        # per unet

        lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))

        for ind, (unet, unet_lr, unet_wd, unet_eps) in enumerate(zip(self.decoder.unets, lr, wd, eps)):
            optimizer = get_optimizer(
                unet.parameters(),
                lr = unet_lr,
                wd = unet_wd,
                eps = unet_eps,
                group_wd_params = group_wd_params,
                **kwargs
            )

            setattr(self, f'optim{ind}', optimizer) # cannot use pytorch ModuleList for some reason with optimizers

            if self.use_ema:
                self.ema_unets.append(EMA(unet, **ema_kwargs))

            scaler = GradScaler(enabled = amp)
            setattr(self, f'scaler{ind}', scaler)

        # gradient clipping if needed

        self.max_grad_norm = max_grad_norm

        self.register_buffer('step', torch.tensor([0.]))

    def save(self, path, overwrite = True, **kwargs):
        path = Path(path)
        assert not (path.exists() and not overwrite)
        path.parent.mkdir(parents = True, exist_ok = True)

        save_obj = dict(
            model = self.decoder.state_dict(),
            version = get_pkg_version(),
            step = self.step.item(),
            **kwargs
        )

        for ind in range(0, self.num_unets):
            scaler_key = f'scaler{ind}'
            optimizer_key = f'optim{ind}' # was f'scaler{ind}', which clobbered the scaler entry and never saved the optimizer
            scaler = getattr(self, scaler_key)
            optimizer = getattr(self, optimizer_key)
            save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()}

        if self.use_ema:
            save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}

        torch.save(save_obj, str(path))

    def load(self, path, only_model = False, strict = True):
        path = Path(path)
        assert path.exists()

        loaded_obj = torch.load(str(path))

        if get_pkg_version() != loaded_obj['version']:
            print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {get_pkg_version()}')

        self.decoder.load_state_dict(loaded_obj['model'], strict = strict)
        self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])

        if only_model:
            return loaded_obj

        for ind in range(0, self.num_unets):
            scaler_key = f'scaler{ind}'
            optimizer_key = f'optim{ind}' # same fix as in save()
            scaler = getattr(self, scaler_key)
            optimizer = getattr(self, optimizer_key)

            scaler.load_state_dict(loaded_obj[scaler_key])
            optimizer.load_state_dict(loaded_obj[optimizer_key])

        if self.use_ema:
            assert 'ema' in loaded_obj
            self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)

        return loaded_obj

    @property
    def unets(self):
        return nn.ModuleList([ema.ema_model for ema in self.ema_unets])

    def scale(self, loss, *, unet_number):
        assert 1 <= unet_number <= self.num_unets
        index = unet_number - 1
        scaler = getattr(self, f'scaler{index}')
        return scaler.scale(loss)

    def update(self, unet_number = None):
        if self.num_unets == 1:
            unet_number = default(unet_number, 1)

        assert exists(unet_number) and 1 <= unet_number <= self.num_unets
        index = unet_number - 1
        unet = self.decoder.unets[index]

        optimizer = getattr(self, f'optim{index}')
        scaler = getattr(self, f'scaler{index}')

        if exists(self.max_grad_norm):
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)

        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        if self.use_ema:
            ema_unet = self.ema_unets[index]
            ema_unet.update()

        self.step += 1

    @torch.no_grad()
    @cast_torch_tensor
    @decoder_sample_in_chunks
    def sample(self, *args, **kwargs):
        if kwargs.pop('use_non_ema', False) or not self.use_ema:
            return self.decoder.sample(*args, **kwargs)

        trainable_unets = self.decoder.unets
        self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling

        output = self.decoder.sample(*args, **kwargs)

        self.decoder.unets = trainable_unets # restore original training unets

        # cast the ema_model unets back to the original device
        for ema in self.ema_unets:
            ema.restore_ema_model_device()

        return output

    @cast_torch_tensor
    def forward(
        self,
        *args,
        unet_number = None,
        max_batch_size = None,
        **kwargs
    ):
        if self.num_unets == 1:
            unet_number = default(unet_number, 1)

        total_loss = 0.

        for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
            with autocast(enabled = self.amp):
                loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
                loss = loss * chunk_size_frac

            total_loss += loss.item()

            if self.training:
                self.scale(loss, unet_number = unet_number).backward()

        return total_loss
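
# Example (illustrative): training a two-unet decoder one unet at a time. The
# arguments to trainer(...) are forwarded to Decoder.forward, so the placeholders
# below stand in for whatever that signature expects (images plus conditioning).
#
#   trainer = DecoderTrainer(decoder, lr = 1e-4)
#   for unet_number in (1, 2):
#       loss = trainer(images, image_embed = image_embeds, unet_number = unet_number, max_batch_size = 16)
#       trainer.update(unet_number = unet_number)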

dalle2_pytorch/utils.py (new file, 19 lines)
@@ -0,0 +1,19 @@
import time

# time helpers

class Timer:
    def __init__(self):
        self.reset()

    def reset(self):
        self.last_time = time.time()

    def elapsed(self):
        return time.time() - self.last_time

# print helpers

def print_ribbon(s, symbol = '=', repeat = 40):
    flank = symbol * repeat
    return f'{flank} {s} {flank}'
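
# Example (illustrative): note that print_ribbon returns the banner rather than
# printing it, so call sites must wrap it in print().
#
#   timer = Timer()
#   train_one_epoch() # hypothetical work
#   print(print_ribbon(f'epoch done in {timer.elapsed():.1f}s'))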

dalle2_pytorch/vqgan_vae.py (new file, 765 lines)
@@ -0,0 +1,765 @@
import copy
import math
from math import sqrt
from functools import partial, wraps

from vector_quantize_pytorch import VectorQuantize as VQ

import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.autograd import grad as torch_grad
import torchvision

from einops import rearrange, reduce, repeat
from einops_exts import rearrange_many
from einops.layers.torch import Rearrange

# constants

MList = nn.ModuleList

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

# decorators

def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

def remove_vgg(fn):
    @wraps(fn)
    def inner(self, *args, **kwargs):
        has_vgg = hasattr(self, 'vgg')
        if has_vgg:
            vgg = self.vgg
            delattr(self, 'vgg')

        out = fn(self, *args, **kwargs)

        if has_vgg:
            self.vgg = vgg

        return out
    return inner

# keyword argument helpers

def pick_and_pop(keys, d):
    values = list(map(lambda key: d.pop(key), keys))
    return dict(zip(keys, values))

def group_dict_by_key(cond, d):
    return_val = [dict(), dict()]
    for key in d.keys():
        match = bool(cond(key))
        ind = int(not match)
        return_val[ind][key] = d[key]
    return (*return_val,)

def string_begins_with(prefix, str):
    return str.startswith(prefix)

def group_by_key_prefix(prefix, d):
    return group_dict_by_key(partial(string_begins_with, prefix), d)

def groupby_prefix_and_trim(prefix, d):
    kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
    return kwargs_without_prefix, kwargs

# tensor helper functions

def log(t, eps = 1e-10):
    return torch.log(t + eps)

def gradient_penalty(images, output, weight = 10):
    batch_size = images.shape[0]
    gradients = torch_grad(
        outputs = output,
        inputs = images,
        grad_outputs = torch.ones(output.size(), device = images.device),
        create_graph = True,
        retain_graph = True,
        only_inputs = True
    )[0]

    gradients = rearrange(gradients, 'b ... -> b (...)')
    return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean()

def l2norm(t):
    return F.normalize(t, dim = -1)

def leaky_relu(p = 0.1):
    return nn.LeakyReLU(p) # was hardcoded to 0.1, silently ignoring the argument

def stable_softmax(t, dim = -1, alpha = 32 ** 2):
    t = t / alpha
    t = t - torch.amax(t, dim = dim, keepdim = True).detach()
    return (t * alpha).softmax(dim = dim)
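
# Note (illustrative): since alpha is a power of two, dividing before the detached
# max-subtraction and multiplying back afterwards is exact in floating point, so the
# result equals softmax(t); the point is to keep intermediate logits small enough
# for low-precision training. A quick sanity check:
#   t = torch.randn(2, 8) * 1e4
#   assert torch.allclose(stable_softmax(t), t.softmax(dim = -1), atol = 1e-6)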

def safe_div(numer, denom, eps = 1e-8):
    return numer / (denom + eps)

# gan losses

def hinge_discr_loss(fake, real):
    return (F.relu(1 + fake) + F.relu(1 - real)).mean()

def hinge_gen_loss(fake):
    return -fake.mean()

def bce_discr_loss(fake, real):
    return (-log(1 - torch.sigmoid(fake)) - log(torch.sigmoid(real))).mean()

def bce_gen_loss(fake):
    return -log(torch.sigmoid(fake)).mean()

def grad_layer_wrt_loss(loss, layer):
    return torch_grad(
        outputs = loss,
        inputs = layer,
        grad_outputs = torch.ones_like(loss),
        retain_graph = True
    )[0].detach()

# vqgan vae

class LayerNormChan(nn.Module):
    def __init__(
        self,
        dim,
        eps = 1e-5
    ):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))

    def forward(self, x):
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) / (var + self.eps).sqrt() * self.gamma

# discriminator

class Discriminator(nn.Module):
    def __init__(
        self,
        dims,
        channels = 3,
        groups = 16,
        init_kernel_size = 5
    ):
        super().__init__()
        dim_pairs = zip(dims[:-1], dims[1:])

        self.layers = MList([nn.Sequential(nn.Conv2d(channels, dims[0], init_kernel_size, padding = init_kernel_size // 2), leaky_relu())])

        for dim_in, dim_out in dim_pairs:
            self.layers.append(nn.Sequential(
                nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1),
                nn.GroupNorm(groups, dim_out),
                leaky_relu()
            ))

        dim = dims[-1]
        self.to_logits = nn.Sequential( # return 5 x 5, for PatchGAN-esque training
            nn.Conv2d(dim, dim, 1),
            leaky_relu(),
            nn.Conv2d(dim, 1, 4)
        )

    def forward(self, x):
        for net in self.layers:
            x = net(x)

        return self.to_logits(x)

# positional encoding

class ContinuousPositionBias(nn.Module):
    """ from https://arxiv.org/abs/2111.09883 """

    def __init__(self, *, dim, heads, layers = 2):
        super().__init__()
        self.net = MList([])
        self.net.append(nn.Sequential(nn.Linear(2, dim), leaky_relu()))

        for _ in range(layers - 1):
            self.net.append(nn.Sequential(nn.Linear(dim, dim), leaky_relu()))

        self.net.append(nn.Linear(dim, heads))
        self.register_buffer('rel_pos', None, persistent = False)

    def forward(self, x):
        n, device = x.shape[-1], x.device
        fmap_size = int(sqrt(n))

        if not exists(self.rel_pos):
            pos = torch.arange(fmap_size, device = device)
            grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
            grid = rearrange(grid, 'c i j -> (i j) c')
            rel_pos = rearrange(grid, 'i c -> i 1 c') - rearrange(grid, 'j c -> 1 j c')
            rel_pos = torch.sign(rel_pos) * torch.log(rel_pos.abs() + 1)
            self.register_buffer('rel_pos', rel_pos, persistent = False)

        rel_pos = self.rel_pos.float()

        for layer in self.net:
            rel_pos = layer(rel_pos)

        bias = rearrange(rel_pos, 'i j h -> h i j')
        return x + bias

# resnet encoder / decoder

class ResnetEncDec(nn.Module):
    def __init__(
        self,
        dim,
        *,
        channels = 3,
        layers = 4,
        layer_mults = None,
        num_resnet_blocks = 1,
        resnet_groups = 16,
        first_conv_kernel_size = 5,
        use_attn = True,
        attn_dim_head = 64,
        attn_heads = 8,
        attn_dropout = 0.,
    ):
        super().__init__()
        assert dim % resnet_groups == 0, f'dimension {dim} must be divisible by {resnet_groups} (groups for the groupnorm)'

        self.layers = layers

        self.encoders = MList([])
        self.decoders = MList([])

        layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
        assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'

        layer_dims = [dim * mult for mult in layer_mults]
        dims = (dim, *layer_dims)

        self.encoded_dim = dims[-1]

        dim_pairs = zip(dims[:-1], dims[1:])

        append = lambda arr, t: arr.append(t)
        prepend = lambda arr, t: arr.insert(0, t)

        if not isinstance(num_resnet_blocks, tuple):
            num_resnet_blocks = (*((0,) * (layers - 1)), num_resnet_blocks)

        if not isinstance(use_attn, tuple):
            use_attn = (*((False,) * (layers - 1)), use_attn)

        assert len(num_resnet_blocks) == layers, 'number of resnet blocks config must be equal to number of layers'
        assert len(use_attn) == layers

        for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_resnet_blocks, use_attn):
            append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
            prepend(self.decoders, nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu()))

            if layer_use_attn:
                prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))

            for _ in range(layer_num_resnet_blocks):
                append(self.encoders, ResBlock(dim_out, groups = resnet_groups))
                prepend(self.decoders, GLUResBlock(dim_out, groups = resnet_groups))

            if layer_use_attn:
                append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))

        prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
        append(self.decoders, nn.Conv2d(dim, channels, 1))

    def get_encoded_fmap_size(self, image_size):
        return image_size // (2 ** self.layers)

    @property
    def last_dec_layer(self):
        return self.decoders[-1].weight

    def encode(self, x):
        for enc in self.encoders:
            x = enc(x)
        return x

    def decode(self, x):
        for dec in self.decoders:
            x = dec(x)
        return x

class GLUResBlock(nn.Module):
    def __init__(self, chan, groups = 16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(chan, chan * 2, 3, padding = 1),
            nn.GLU(dim = 1),
            nn.GroupNorm(groups, chan),
            nn.Conv2d(chan, chan * 2, 3, padding = 1),
            nn.GLU(dim = 1),
            nn.GroupNorm(groups, chan),
            nn.Conv2d(chan, chan, 1)
        )

    def forward(self, x):
        return self.net(x) + x

class ResBlock(nn.Module):
    def __init__(self, chan, groups = 16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(chan, chan, 3, padding = 1),
            nn.GroupNorm(groups, chan),
            leaky_relu(),
            nn.Conv2d(chan, chan, 3, padding = 1),
            nn.GroupNorm(groups, chan),
            leaky_relu(),
            nn.Conv2d(chan, chan, 1)
        )

    def forward(self, x):
        return self.net(x) + x

# vqgan attention layer

class VQGanAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        dropout = 0.
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = heads * dim_head

        self.dropout = nn.Dropout(dropout)
        self.pre_norm = LayerNormChan(dim)

        self.cpb = ContinuousPositionBias(dim = dim // 4, heads = heads)
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)

    def forward(self, x):
        h = self.heads
        height, width, residual = *x.shape[-2:], x.clone()

        x = self.pre_norm(x)

        q, k, v = self.to_qkv(x).chunk(3, dim = 1)

        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = h), (q, k, v))

        sim = einsum('b h c i, b h c j -> b h i j', q, k) * self.scale

        sim = self.cpb(sim)

        attn = stable_softmax(sim, dim = -1)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h c j -> b h c i', attn, v)
        out = rearrange(out, 'b h c (x y) -> b (h c) x y', x = height, y = width)
        out = self.to_out(out)

        return out + residual

# ViT encoder / decoder

class RearrangeImage(nn.Module):
    def forward(self, x):
        n = x.shape[1]
        w = h = int(sqrt(n))
        return rearrange(x, 'b (h w) ... -> b h w ...', h = h, w = w)

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        heads = 8,
        dim_head = 32
    ):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x):
        h = self.heads

        x = self.norm(x)

        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = h)

        q = q * self.scale
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
        attn = sim.softmax(dim = -1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

def FeedForward(dim, mult = 4):
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, dim * mult, bias = False),
        nn.GELU(),
        nn.Linear(dim * mult, dim, bias = False)
    )

class Transformer(nn.Module):
    def __init__(
        self,
        dim,
        *,
        layers,
        dim_head = 32,
        heads = 8,
        ff_mult = 4
    ):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(layers):
            self.layers.append(nn.ModuleList([
                Attention(dim = dim, dim_head = dim_head, heads = heads),
                FeedForward(dim = dim, mult = ff_mult)
            ]))

        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return self.norm(x)

class ViTEncDec(nn.Module):
    def __init__(
        self,
        dim,
        channels = 3,
        layers = 4,
        patch_size = 8,
        dim_head = 32,
        heads = 8,
        ff_mult = 4
    ):
        super().__init__()
        self.encoded_dim = dim
        self.patch_size = patch_size

        input_dim = channels * (patch_size ** 2)

        self.encoder = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            nn.Linear(input_dim, dim),
            Transformer(
                dim = dim,
                dim_head = dim_head,
                heads = heads,
                ff_mult = ff_mult,
                layers = layers
            ),
            RearrangeImage(),
            Rearrange('b h w c -> b c h w')
        )

        self.decoder = nn.Sequential(
            Rearrange('b c h w -> b (h w) c'),
            Transformer(
                dim = dim,
                dim_head = dim_head,
                heads = heads,
                ff_mult = ff_mult,
                layers = layers
            ),
            nn.Sequential(
                nn.Linear(dim, dim * 4, bias = False),
                nn.Tanh(),
                nn.Linear(dim * 4, input_dim, bias = False),
            ),
            RearrangeImage(),
            Rearrange('b h w (p1 p2 c) -> b c (h p1) (w p2)', p1 = patch_size, p2 = patch_size)
        )

    def get_encoded_fmap_size(self, image_size):
        return image_size // self.patch_size

    @property
    def last_dec_layer(self):
        return self.decoder[-3][-1].weight

    def encode(self, x):
        return self.encoder(x)

    def decode(self, x):
        return self.decoder(x)

# main vqgan-vae classes

class NullVQGanVAE(nn.Module):
    def __init__(
        self,
        *,
        channels
    ):
        super().__init__()
        self.encoded_dim = channels
        self.layers = 0

    def get_encoded_fmap_size(self, size):
        return size

    def copy_for_eval(self):
        return self

    def encode(self, x):
        return x

    def decode(self, x):
        return x

class VQGanVAE(nn.Module):
    def __init__(
        self,
        *,
        dim,
        image_size,
        channels = 3,
        layers = 4,
        l2_recon_loss = False,
        use_hinge_loss = True,
        vgg = None,
        vq_codebook_dim = 256,
        vq_codebook_size = 512,
        vq_decay = 0.8,
        vq_commitment_weight = 1.,
        vq_kmeans_init = True,
        vq_use_cosine_sim = True,
        use_vgg_and_gan = True,
        vae_type = 'resnet',
        discr_layers = 4,
        **kwargs
    ):
        super().__init__()
        vq_kwargs, kwargs = groupby_prefix_and_trim('vq_', kwargs)
        encdec_kwargs, kwargs = groupby_prefix_and_trim('encdec_', kwargs)

        self.image_size = image_size
        self.channels = channels
        self.codebook_size = vq_codebook_size

        if vae_type == 'resnet':
            enc_dec_klass = ResnetEncDec
        elif vae_type == 'vit':
            enc_dec_klass = ViTEncDec
        else:
            raise ValueError(f'{vae_type} not valid')

        self.enc_dec = enc_dec_klass(
            dim = dim,
            channels = channels,
            layers = layers,
            **encdec_kwargs
        )

        self.vq = VQ(
            dim = self.enc_dec.encoded_dim,
            codebook_dim = vq_codebook_dim,
            codebook_size = vq_codebook_size,
            decay = vq_decay,
            commitment_weight = vq_commitment_weight,
            accept_image_fmap = True,
            kmeans_init = vq_kmeans_init,
            use_cosine_sim = vq_use_cosine_sim,
            **vq_kwargs
        )

        # reconstruction loss

        self.recon_loss_fn = F.mse_loss if l2_recon_loss else F.l1_loss

        # turn off GAN and perceptual loss if grayscale

        self.vgg = None
        self.discr = None
        self.use_vgg_and_gan = use_vgg_and_gan

        if not use_vgg_and_gan:
            return

        # perceptual loss

        if exists(vgg):
            self.vgg = vgg
        else:
            self.vgg = torchvision.models.vgg16(pretrained = True)
            self.vgg.classifier = nn.Sequential(*self.vgg.classifier[:-2])

        # gan related losses

        layer_mults = list(map(lambda t: 2 ** t, range(discr_layers)))
        layer_dims = [dim * mult for mult in layer_mults]
        dims = (dim, *layer_dims)

        self.discr = Discriminator(dims = dims, channels = channels)

        self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
        self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss

    @property
    def encoded_dim(self):
        return self.enc_dec.encoded_dim

    def get_encoded_fmap_size(self, image_size):
        return self.enc_dec.get_encoded_fmap_size(image_size)

    def copy_for_eval(self):
        device = next(self.parameters()).device
        vae_copy = copy.deepcopy(self.cpu())

        if vae_copy.use_vgg_and_gan:
            del vae_copy.discr
            del vae_copy.vgg

        vae_copy.eval()
        return vae_copy.to(device)

    @remove_vgg
    def state_dict(self, *args, **kwargs):
        return super().state_dict(*args, **kwargs)

    @remove_vgg
    def load_state_dict(self, *args, **kwargs):
        return super().load_state_dict(*args, **kwargs)

    @property
    def codebook(self):
        return self.vq.codebook

    def encode(self, fmap):
        fmap = self.enc_dec.encode(fmap)
        return fmap

    def decode(self, fmap, return_indices_and_loss = False):
        fmap, indices, commit_loss = self.vq(fmap)

        fmap = self.enc_dec.decode(fmap)

        if not return_indices_and_loss:
            return fmap

        return fmap, indices, commit_loss

    def forward(
        self,
        img,
        return_loss = False,
        return_discr_loss = False,
        return_recons = False,
        add_gradient_penalty = True
    ):
        batch, channels, height, width, device = *img.shape, img.device
        assert height == self.image_size and width == self.image_size, f'height and width of input image must be equal to {self.image_size}' # was missing the f-string prefix
        assert channels == self.channels, 'number of channels on image or sketch is not equal to the channels set on this VQGanVAE'

        fmap = self.encode(img)

        fmap, indices, commit_loss = self.decode(fmap, return_indices_and_loss = True)

        if not return_loss and not return_discr_loss:
            return fmap

        assert return_loss ^ return_discr_loss, 'you should either return autoencoder loss or discriminator loss, but not both'

        # whether to return discriminator loss

        if return_discr_loss:
            assert exists(self.discr), 'discriminator must exist to train it'

            fmap.detach_()
            img.requires_grad_()

            fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))

            discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits)

            loss = discr_loss

            if add_gradient_penalty:
                gp = gradient_penalty(img, img_discr_logits)
                loss = discr_loss + gp # previously `loss` was only assigned inside this branch, a NameError when the penalty was off

            if return_recons:
                return loss, fmap

            return loss

        # reconstruction loss

        recon_loss = self.recon_loss_fn(fmap, img)

        # early return if training on grayscale

        if not self.use_vgg_and_gan:
            if return_recons:
                return recon_loss, fmap

            return recon_loss

        # perceptual loss

        img_vgg_input = img
        fmap_vgg_input = fmap

        if img.shape[1] == 1:
            # handle grayscale for vgg
            img_vgg_input, fmap_vgg_input = map(lambda t: repeat(t, 'b 1 ... -> b c ...', c = 3), (img_vgg_input, fmap_vgg_input))

        img_vgg_feats = self.vgg(img_vgg_input)
        recon_vgg_feats = self.vgg(fmap_vgg_input)
        perceptual_loss = F.mse_loss(img_vgg_feats, recon_vgg_feats)

        # generator loss

        gen_loss = self.gen_loss(self.discr(fmap))

        # calculate adaptive weight

        last_dec_layer = self.enc_dec.last_dec_layer

        norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2)
        norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)

        adaptive_weight = safe_div(norm_grad_wrt_perceptual_loss, norm_grad_wrt_gen_loss)
        adaptive_weight.clamp_(max = 1e4)

        # combine losses

        loss = recon_loss + perceptual_loss + commit_loss + adaptive_weight * gen_loss

        if return_recons:
            return loss, fmap

        return loss
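
# Example (illustrative): the three ways forward() is used during VQGAN-VAE training.
# The sizes are placeholder assumptions, and instantiating VQGanVAE with the defaults
# downloads pretrained VGG16 for the perceptual loss.
#
#   vae = VQGanVAE(dim = 64, image_size = 256)
#   img = torch.randn(1, 3, 256, 256)
#   recon = vae(img)                                # reconstruction only
#   loss = vae(img, return_loss = True)             # autoencoder (generator) loss
#   discr_loss = vae(img, return_discr_loss = True) # discriminator loss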

dalle2_pytorch/vqgan_vae_trainer.py (new file, 277 lines)
@@ -0,0 +1,277 @@
from math import sqrt
import copy
from random import choice
from pathlib import Path
from shutil import rmtree
from PIL import Image

import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader, random_split

import torchvision.transforms as T
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid, save_image

from einops import rearrange

from dalle2_pytorch.trainer import EMA # EMA lives in the new trainer.py above; the original import pointed at the old train module
from dalle2_pytorch.vqgan_vae import VQGanVAE
from dalle2_pytorch.optimizer import get_optimizer

# helpers

def exists(val):
    return val is not None

def noop(*args, **kwargs):
    pass

def cycle(dl):
    while True:
        for data in dl:
            yield data

def cast_tuple(t):
    return t if isinstance(t, (tuple, list)) else (t,)

def yes_or_no(question):
    answer = input(f'{question} (y/n) ')
    return answer.lower() in ('yes', 'y')

def accum_log(log, new_logs):
    for key, new_value in new_logs.items():
        old_value = log.get(key, 0.)
        log[key] = old_value + new_value
    return log
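
# Example (illustrative): accum_log sums values key-wise across gradient-accumulation
# micro-batches:
#   logs = accum_log({}, {'loss': 0.5})
#   logs = accum_log(logs, {'loss': 0.25}) # -> {'loss': 0.75}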

# classes

class ImageDataset(Dataset):
    def __init__(
        self,
        folder,
        image_size,
        exts = ['jpg', 'jpeg', 'png']
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]

        print(f'{len(self.paths)} training samples found at {folder}')

        self.transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize(image_size),
            T.RandomHorizontalFlip(),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(path)
        return self.transform(img)

# main trainer class

class VQGanVAETrainer(nn.Module):
    def __init__(
        self,
        vae,
        *,
        num_train_steps,
        lr,
        batch_size,
        folder,
        grad_accum_every,
        wd = 0.,
        save_results_every = 100,
        save_model_every = 1000,
        results_folder = './results',
        valid_frac = 0.05,
        random_split_seed = 42,
        ema_beta = 0.995,
        ema_update_after_step = 2000,
        ema_update_every = 10,
        apply_grad_penalty_every = 4,
        amp = False
    ):
        super().__init__()
        assert isinstance(vae, VQGanVAE), 'vae must be instance of VQGanVAE'
        image_size = vae.image_size

        self.vae = vae
        self.ema_vae = EMA(vae, update_after_step = ema_update_after_step, update_every = ema_update_every)

        self.register_buffer('steps', torch.Tensor([0]))

        self.num_train_steps = num_train_steps
        self.batch_size = batch_size
        self.grad_accum_every = grad_accum_every

        all_parameters = set(vae.parameters())
        discr_parameters = set(vae.discr.parameters())
        vae_parameters = all_parameters - discr_parameters

        self.optim = get_optimizer(vae_parameters, lr = lr, wd = wd)
        self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd)

        self.amp = amp
        self.scaler = GradScaler(enabled = amp)
        self.discr_scaler = GradScaler(enabled = amp)

        # create dataset

        self.ds = ImageDataset(folder, image_size = image_size)

        # split for validation

        if valid_frac > 0:
            train_size = int((1 - valid_frac) * len(self.ds))
            valid_size = len(self.ds) - train_size
            self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
            print(f'training with dataset of {len(self.ds)} samples and validating with randomly split {len(self.valid_ds)} samples')
        else:
            self.valid_ds = self.ds
            print(f'training with shared training and valid dataset of {len(self.ds)} samples')

        # dataloader

        self.dl = cycle(DataLoader(
            self.ds,
            batch_size = batch_size,
            shuffle = True
        ))

        self.valid_dl = cycle(DataLoader(
            self.valid_ds,
            batch_size = batch_size,
            shuffle = True
        ))

        self.save_model_every = save_model_every
        self.save_results_every = save_results_every

        self.apply_grad_penalty_every = apply_grad_penalty_every

        self.results_folder = Path(results_folder)

        if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'):
            rmtree(str(self.results_folder))

        self.results_folder.mkdir(parents = True, exist_ok = True)

    def train_step(self):
        device = next(self.vae.parameters()).device
        steps = int(self.steps.item())
        apply_grad_penalty = not (steps % self.apply_grad_penalty_every)

        self.vae.train()

        # logs

        logs = {}

        # update vae (generator)

        for _ in range(self.grad_accum_every):
            img = next(self.dl)
            img = img.to(device)

            with autocast(enabled = self.amp):
                loss = self.vae(img, return_loss = True)

            self.scaler.scale(loss / self.grad_accum_every).backward()

            accum_log(logs, {'loss': loss.item() / self.grad_accum_every})

        self.scaler.step(self.optim)
        self.scaler.update()
        self.optim.zero_grad()

        # update discriminator

        if exists(self.vae.discr):
            for _ in range(self.grad_accum_every):
                img = next(self.dl)
                img = img.to(device)

                with autocast(enabled = self.amp):
                    # the gradient penalty belongs to the discriminator pass; the original code handed
                    # `apply_grad_penalty` to the generator pass, a keyword forward() does not accept
                    loss = self.vae(img, return_discr_loss = True, add_gradient_penalty = apply_grad_penalty)

                self.discr_scaler.scale(loss / self.grad_accum_every).backward()

                accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every})

            self.discr_scaler.step(self.discr_optim)
            self.discr_scaler.update()
            self.discr_optim.zero_grad()

            # log

            print(f"{steps}: vae loss: {logs['loss']} - discr loss: {logs['discr_loss']}")

        # update exponential moving averaged generator

        self.ema_vae.update()

        # sample results every so often

        if not (steps % self.save_results_every):
            for model, filename in ((self.ema_vae.ema_model, f'{steps}.ema'), (self.vae, str(steps))):
                model.eval()

                imgs = next(self.dl)
                imgs = imgs.to(device)

                recons = model(imgs)
                nrows = int(sqrt(self.batch_size))

                imgs_and_recons = torch.stack((imgs, recons), dim = 0)
                imgs_and_recons = rearrange(imgs_and_recons, 'r b ... -> (b r) ...')

                imgs_and_recons = imgs_and_recons.detach().cpu().float().clamp(0., 1.)
                grid = make_grid(imgs_and_recons, nrow = 2, normalize = True, value_range = (0, 1))

                logs['reconstructions'] = grid

                save_image(grid, str(self.results_folder / f'{filename}.png'))

            print(f'{steps}: saving to {str(self.results_folder)}')

        # save model every so often

        if not (steps % self.save_model_every):
            state_dict = self.vae.state_dict()
            model_path = str(self.results_folder / f'vae.{steps}.pt')
            torch.save(state_dict, model_path)

            ema_state_dict = self.ema_vae.state_dict()
            model_path = str(self.results_folder / f'vae.{steps}.ema.pt')
            torch.save(ema_state_dict, model_path)

            print(f'{steps}: saving model to {str(self.results_folder)}')

        self.steps += 1
        return logs

    def train(self, log_fn = noop):
        device = next(self.vae.parameters()).device

        while self.steps < self.num_train_steps:
            logs = self.train_step()
            log_fn(logs)

        print('training complete')
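
# Example (illustrative): wiring the trainer to a folder of images. The paths and
# hyperparameters below are placeholders; see __init__ above for the full set of options.
#
#   vae = VQGanVAE(dim = 64, image_size = 256)
#   trainer = VQGanVAETrainer(
#       vae,
#       folder = './images',
#       num_train_steps = 50000,
#       lr = 3e-4,
#       batch_size = 8,
#       grad_accum_every = 4
#   )
#   trainer.train()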

samples/oxford.png (new binary file, 985 KiB, not shown)
setup.py (24 changes)

@@ -7,13 +7,15 @@ setup(
   entry_points={
     'console_scripts': [
       'dalle2_pytorch = dalle2_pytorch.cli:main',
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.2',
+  version = '0.5.4',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
   author_email = 'lucidrains@gmail.com',
   long_description_content_type = 'text/markdown',
   url = 'https://github.com/lucidrains/dalle2-pytorch',
   keywords = [
@@ -22,12 +24,26 @@ setup(
   ],
   install_requires=[
     'click',
+    'clip-anytorch',
+    'coca-pytorch>=0.0.5',
     'einops>=0.4',
-    'einops-exts',
+    'einops-exts>=0.0.3',
+    'embedding-reader',
     'kornia>=0.5.4',
+    'numpy',
     'pillow',
+    'pydantic',
+    'resize-right>=0.0.2',
+    'rotary-embedding-torch',
     'torch>=1.10',
     'torchvision',
-    'x-clip>=0.4.1',
-    'youtokentome'
+    'tqdm',
+    'vector-quantize-pytorch',
+    'x-clip>=0.4.4',
+    'youtokentome',
+    'webdataset>=0.2.5',
+    'fsspec>=2022.1.0',
+    'torchmetrics[image]>=0.8.0'
   ],
   classifiers=[
     'Development Status :: 4 - Beta',
442  train_decoder.py  (new file)

@@ -0,0 +1,442 @@
from dalle2_pytorch import Unet, Decoder
from dalle2_pytorch.trainer import DecoderTrainer
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
from dalle2_pytorch.train_configs import TrainDecoderConfig
from dalle2_pytorch.utils import Timer, print_ribbon

import torchvision
import torch
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.kid import KernelInceptionDistance
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
import webdataset as wds
import click

# constants

TRAIN_CALC_LOSS_EVERY_ITERS = 10
VALID_CALC_LOSS_EVERY_ITERS = 10

# helper functions

def exists(val):
    return val is not None

# main functions
def create_dataloaders(
    available_shards,
    webdataset_base_url,
    embeddings_url,
    shard_width=6,
    num_workers=4,
    batch_size=32,
    n_sample_images=6,
    shuffle_train=True,
    resample_train=False,
    img_preproc = None,
    index_width=4,
    train_prop = 0.75,
    val_prop = 0.15,
    test_prop = 0.10,
    **kwargs
):
    """
    Randomly splits the available shards into train, val, and test sets and returns a dataloader for each
    """
    assert train_prop + test_prop + val_prop == 1
    num_train = round(train_prop*len(available_shards))
    num_test = round(test_prop*len(available_shards))
    num_val = len(available_shards) - num_train - num_test
    assert num_train + num_test + num_val == len(available_shards), f"{num_train} + {num_test} + {num_val} = {num_train + num_test + num_val} != {len(available_shards)}"
    train_split, test_split, val_split = torch.utils.data.random_split(available_shards, [num_train, num_test, num_val], generator=torch.Generator().manual_seed(0))

    # The shard number in the webdataset file names has a fixed width. We zero pad the shard numbers so they correspond to a filename.
    train_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in train_split]
    test_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in test_split]
    val_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in val_split]

    create_dataloader = lambda tar_urls, shuffle=False, resample=False, with_text=False, for_sampling=False: create_image_embedding_dataloader(
        tar_url=tar_urls,
        num_workers=num_workers,
        batch_size=batch_size if not for_sampling else n_sample_images,
        embeddings_url=embeddings_url,
        index_width=index_width,
        shuffle_num = None,
        extra_keys= ["txt"] if with_text else [],
        shuffle_shards = shuffle,
        resample_shards = resample,
        img_preproc=img_preproc,
        handler=wds.handlers.warn_and_continue
    )

    train_dataloader = create_dataloader(train_urls, shuffle=shuffle_train, resample=resample_train)
    train_sampling_dataloader = create_dataloader(train_urls, shuffle=False, for_sampling=True)
    val_dataloader = create_dataloader(val_urls, shuffle=False, with_text=True)
    test_dataloader = create_dataloader(test_urls, shuffle=False, with_text=True)
    test_sampling_dataloader = create_dataloader(test_urls, shuffle=False, for_sampling=True)
    return {
        "train": train_dataloader,
        "train_sampling": train_sampling_dataloader,
        "val": val_dataloader,
        "test": test_dataloader,
        "test_sampling": test_sampling_dataloader
    }
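A hedged usage sketch of the function above; the shard range and URLs are hypothetical placeholders, and the `{}` in `webdataset_base_url` is filled with the zero-padded shard number exactly as the function does:

```python
# hypothetical locations, for illustration only
dataloaders = create_dataloaders(
    available_shards = list(range(100)),                   # shards 000000 .. 000099
    webdataset_base_url = "s3://my-bucket/images/{}.tar",  # hypothetical
    embeddings_url = "s3://my-bucket/embeddings/",         # hypothetical
    shard_width = 6,
    batch_size = 32,
    num_workers = 4,
)
train_dl = dataloaders["train"]  # plus "val", "test", and the *_sampling loaders
```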
def get_dataset_keys(dataloader):
    """
    It is sometimes necessary to get the keys the dataloader is returning. Since the dataset is buried in the dataloader, we need to unwrap it to recover them.
    """
    # If the dataloader is actually a WebLoader, we need to extract the real dataloader
    if isinstance(dataloader, wds.WebLoader):
        dataloader = dataloader.pipeline[0]
    return dataloader.dataset.key_map

def get_example_data(dataloader, device, n=5):
    """
    Samples the dataloader and returns a zipped list of examples
    """
    images = []
    embeddings = []
    captions = []
    dataset_keys = get_dataset_keys(dataloader)
    has_caption = "txt" in dataset_keys
    for data in dataloader:
        if has_caption:
            img, emb, txt = data
        else:
            img, emb = data
            txt = [""] * emb.shape[0]
        img = img.to(device=device, dtype=torch.float)
        emb = emb.to(device=device, dtype=torch.float)
        images.extend(list(img))
        embeddings.extend(list(emb))
        captions.extend(list(txt))
        if len(images) >= n:
            break
    print("Generated {} examples".format(len(images)))
    return list(zip(images[:n], embeddings[:n], captions[:n]))
def generate_samples(trainer, example_data, text_prepend=""):
    """
    Takes example data and generates images from the embeddings
    Returns three lists: real images, generated images, and captions
    """
    real_images, embeddings, txts = zip(*example_data)
    embeddings_tensor = torch.stack(embeddings)
    samples = trainer.sample(embeddings_tensor)
    generated_images = list(samples)
    captions = [text_prepend + txt for txt in txts]
    return real_images, generated_images, captions

def generate_grid_samples(trainer, examples, text_prepend=""):
    """
    Generates samples and uses torchvision to put them in a side by side grid for easy viewing
    """
    real_images, generated_images, captions = generate_samples(trainer, examples, text_prepend)
    grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
    return grid_images, captions
def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
    """
    Computes evaluation metrics for the decoder
    """
    metrics = {}
    # Prepare the data
    examples = get_example_data(dataloader, device, n_evaluation_samples)
    real_images, generated_images, captions = generate_samples(trainer, examples)
    real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
    generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
    # Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
    int_real_images = real_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
    int_generated_images = generated_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
    if exists(FID):
        fid = FrechetInceptionDistance(**FID)
        fid.to(device=device)
        fid.update(int_real_images, real=True)
        fid.update(int_generated_images, real=False)
        metrics["FID"] = fid.compute().item()
    if exists(IS):
        inception = InceptionScore(**IS)
        inception.to(device=device)
        inception.update(int_real_images)
        is_mean, is_std = inception.compute()
        metrics["IS_mean"] = is_mean.item()
        metrics["IS_std"] = is_std.item()
    if exists(KID):
        kernel_inception = KernelInceptionDistance(**KID)
        kernel_inception.to(device=device)
        kernel_inception.update(int_real_images, real=True)
        kernel_inception.update(int_generated_images, real=False)
        kid_mean, kid_std = kernel_inception.compute()
        metrics["KID_mean"] = kid_mean.item()
        metrics["KID_std"] = kid_std.item()
    if exists(LPIPS):
        # Convert from [0, 1] to [-1, 1]
        renorm_real_images = real_images.mul(2).sub(1)
        renorm_generated_images = generated_images.mul(2).sub(1)
        lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS)
        lpips.to(device=device)
        lpips.update(renorm_real_images, renorm_generated_images)
        metrics["LPIPS"] = lpips.compute().item()
    return metrics
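The metric arguments are kwargs dicts forwarded straight to the corresponding torchmetrics constructors; leaving one as `None` (the default) skips that metric. A sketch, assuming a small FID feature size to keep evaluation cheap:

```python
metrics = evaluate_trainer(
    trainer,
    dataloaders["val"],
    device,
    n_evaluation_samples = 1000,
    FID = dict(feature = 64),  # forwarded to FrechetInceptionDistance
    IS = dict(),               # InceptionScore with defaults
)
# e.g. {'FID': ..., 'IS_mean': ..., 'IS_std': ...}
```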
def save_trainer(tracker, trainer, epoch, step, validation_losses, relative_paths):
    """
    Logs the model with an appropriate method depending on the tracker
    """
    if isinstance(relative_paths, str):
        relative_paths = [relative_paths]
    trainer_state_dict = {}
    trainer_state_dict["trainer"] = trainer.state_dict()
    trainer_state_dict['epoch'] = epoch
    trainer_state_dict['step'] = step
    trainer_state_dict['validation_losses'] = validation_losses
    for relative_path in relative_paths:
        tracker.save_state_dict(trainer_state_dict, relative_path)

def recall_trainer(tracker, trainer, recall_source=None, **load_config):
    """
    Loads the model with an appropriate method depending on the tracker
    """
    print(print_ribbon(f"Loading model from {recall_source}"))
    state_dict = tracker.recall_state_dict(recall_source, **load_config)
    trainer.load_state_dict(state_dict["trainer"])
    print("Model loaded")
    return state_dict["epoch"], state_dict["step"], state_dict["validation_losses"]
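Note that `save_trainer` and `recall_trainer` are symmetric: the dict written out (keys `trainer`, `epoch`, `step`, `validation_losses`) is exactly what `recall_trainer` unpacks, so a resumed run restores the optimizer state together with the bookkeeping counters it returns.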
def train(
    dataloaders,
    decoder,
    tracker,
    inference_device,
    load_config=None,
    evaluate_config=None,
    epoch_samples = None,  # If the training dataset is resampling, we have to manually stop an epoch
    validation_samples = None,
    epochs = 20,
    n_sample_images = 5,
    save_every_n_samples = 100000,
    save_all=False,
    save_latest=True,
    save_best=True,
    unet_training_mask=None,
    **kwargs
):
    """
    Trains a decoder on a dataset.
    """
    trainer = DecoderTrainer(  # TODO: Change the get_optimizer function so that it can take arbitrary named args so we can just put **kwargs as an argument here
        decoder,
        **kwargs
    )
    # Set up starting model and parameters based on a recalled state dict
    start_step = 0
    start_epoch = 0
    validation_losses = []

    if exists(load_config) and exists(load_config.source):
        start_epoch, start_step, validation_losses = recall_trainer(tracker, trainer, recall_source=load_config.source, **load_config)
    trainer.to(device=inference_device)

    if not exists(unet_training_mask):
        # Then the unet mask should be true for all unets in the decoder
        unet_training_mask = [True] * trainer.num_unets
    assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"

    print(print_ribbon("Generating Example Data", repeat=40))
    print("This can take a while to load the shard lists...")
    train_example_data = get_example_data(dataloaders["train_sampling"], inference_device, n_sample_images)
    test_example_data = get_example_data(dataloaders["test_sampling"], inference_device, n_sample_images)

    send_to_device = lambda arr: [x.to(device=inference_device, dtype=torch.float) for x in arr]
    step = start_step
    for epoch in range(start_epoch, epochs):
        print(print_ribbon(f"Starting epoch {epoch}", repeat=40))

        timer = Timer()

        sample = 0
        last_sample = 0
        last_snapshot = 0

        losses = []

        for i, (img, emb) in enumerate(dataloaders["train"]):
            step += 1
            sample += img.shape[0]
            img, emb = send_to_device((img, emb))

            trainer.train()
            for unet in range(1, trainer.num_unets+1):
                # Check if this is a unet we are training
                if not unet_training_mask[unet-1]:  # Unet index is the unet number - 1
                    continue

                loss = trainer.forward(img, image_embed=emb, unet_number=unet)
                trainer.update(unet_number=unet)
                losses.append(loss)

            samples_per_sec = (sample - last_sample) / timer.elapsed()

            timer.reset()
            last_sample = sample

            if i % TRAIN_CALC_LOSS_EVERY_ITERS == 0:
                average_loss = sum(losses) / len(losses)
                log_data = {
                    "Training loss": average_loss,
                    "Epoch": epoch,
                    "Sample": sample,
                    "Step": i,
                    "Samples per second": samples_per_sec
                }
                tracker.log(log_data, step=step, verbose=True)
                losses = []

            if last_snapshot + save_every_n_samples < sample:  # This will miss by some amount every time, but it's not a big deal... I hope
                last_snapshot = sample
                # We need to know where the model should be saved
                save_paths = []
                if save_latest:
                    save_paths.append("latest.pth")
                if save_all:
                    save_paths.append(f"checkpoints/epoch_{epoch}_step_{step}.pth")

                save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths)

                if exists(n_sample_images) and n_sample_images > 0:
                    trainer.eval()
                    train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
                    tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step)

            if exists(epoch_samples) and sample >= epoch_samples:
                break
        trainer.eval()
        print(print_ribbon(f"Starting Validation {epoch}", repeat=40))
        with torch.no_grad():
            sample = 0
            average_loss = 0
            timer = Timer()
            for i, (img, emb, txt) in enumerate(dataloaders["val"]):
                sample += img.shape[0]
                img, emb = send_to_device((img, emb))

                for unet in range(1, len(decoder.unets)+1):
                    loss = trainer.forward(img.float(), image_embed=emb.float(), unet_number=unet)
                    average_loss += loss

                if i % VALID_CALC_LOSS_EVERY_ITERS == 0:
                    print(f"Epoch {epoch}/{epochs} - {sample / timer.elapsed():.2f} samples/sec")
                    print(f"Loss: {average_loss / (i+1)}")
                    print("")

                if exists(validation_samples) and sample >= validation_samples:
                    break

            average_loss /= i+1
            log_data = {
                "Validation loss": average_loss
            }
            tracker.log(log_data, step=step, verbose=True)

        # Compute evaluation metrics
        if exists(evaluate_config):
            print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
            evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config)
            tracker.log(evaluation, step=step, verbose=True)

        # Generate sample images
        print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
        test_images, test_captions = generate_grid_samples(trainer, test_example_data, "Test: ")
        train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
        tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step)
        tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step)

        print(print_ribbon(f"Starting Saving {epoch}", repeat=40))
        # Get the same paths
        save_paths = []
        if save_latest:
            save_paths.append("latest.pth")
        if save_best and (len(validation_losses) == 0 or average_loss < min(validation_losses)):
            save_paths.append("best.pth")
        validation_losses.append(average_loss)
        save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths)
def create_tracker(config, tracker_type=None, data_path=None, **kwargs):
    """
    Creates a tracker of the specified type and initializes special features based on the full config
    """
    tracker_config = config.tracker
    init_config = {}

    if exists(tracker_config.init_config):
        init_config["config"] = tracker_config.init_config

    if tracker_type == "console":
        tracker = ConsoleTracker(**init_config)
    elif tracker_type == "wandb":
        # We need to initialize the resume state here
        load_config = config.load
        if load_config.source == "wandb" and load_config.resume:
            # Then we are resuming the run load_config["run_path"]
            run_id = load_config.run_path.split("/")[-1]
            init_config["id"] = run_id
            init_config["resume"] = "must"

        init_config["entity"] = tracker_config.wandb_entity
        init_config["project"] = tracker_config.wandb_project
        tracker = WandbTracker(data_path)
        tracker.init(**init_config)
    else:
        raise ValueError(f"Tracker type {tracker_type} not supported by decoder trainer")
    return tracker
def initialize_training(config):
    # Set up the training device
    if "cuda" in config.train.device:
        assert torch.cuda.is_available(), "CUDA is not available"
    device = torch.device(config.train.device)
    if device.type == "cuda":  # set_device raises on a CPU device, so guard it
        torch.cuda.set_device(device)
    all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))

    dataloaders = create_dataloaders(
        available_shards=all_shards,
        img_preproc = config.data.img_preproc,
        train_prop = config.data.splits.train,
        val_prop = config.data.splits.val,
        test_prop = config.data.splits.test,
        n_sample_images=config.train.n_sample_images,
        **config.data.dict()
    )

    decoder = config.decoder.create().to(device = device)
    num_parameters = sum(p.numel() for p in decoder.parameters())
    print(print_ribbon("Loaded Config", repeat=40))
    print(f"Number of parameters: {num_parameters}")

    tracker = create_tracker(config, **config.tracker.dict())

    train(dataloaders, decoder,
        tracker=tracker,
        inference_device=device,
        load_config=config.load,
        evaluate_config=config.evaluate,
        **config.train.dict(),
    )

# Create a simple click command line interface to load the config and start the training
@click.command()
@click.option("--config_file", default="./train_decoder_config.json", help="Path to config file")
def main(config_file):
    print("Recalling config from {}".format(config_file))
    config = TrainDecoderConfig.from_json_path(config_file)
    initialize_training(config)


if __name__ == "__main__":
    main()
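With a config in place, the script is launched through its click entry point, e.g. `python train_decoder.py --config_file ./train_decoder_config.json` (which is also the default path if the flag is omitted).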
365  train_diffusion_prior.py  (new file)

@@ -0,0 +1,365 @@
from pathlib import Path
import click
import math
import numpy as np

import torch
import clip
from torch import nn

from dalle2_pytorch.dataloaders import make_splits
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model

from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
from dalle2_pytorch.utils import Timer, print_ribbon

from embedding_reader import EmbeddingReader

from tqdm import tqdm

# constants

REPORT_METRICS_EVERY = 250  # for cosine similarity and other metric reporting during training

tracker = WandbTracker()

# helper functions

def exists(val):
    return val is not None

# functions
def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation"):
    model.eval()

    with torch.no_grad():
        total_loss = 0.
        total_samples = 0.

        for image_embeddings, text_data in tqdm(dataloader):

            batches = image_embeddings.shape[0]

            input_args = dict(image_embed=image_embeddings)
            if text_conditioned:
                input_args = dict(**input_args, text = text_data)
            else:
                input_args = dict(**input_args, text_embed=text_data)

            loss = model(**input_args)

            total_loss += loss * batches
            total_samples += batches

        avg_loss = (total_loss / total_samples)

        tracker.log({f'{phase} {loss_type}': avg_loss})
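Because each batch's loss is weighted by its size before dividing by the total sample count, a short final batch does not skew the reported average.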
def report_cosine_sims(diffusion_prior, dataloader, text_conditioned):
    diffusion_prior.eval()

    cos = nn.CosineSimilarity(dim=1, eps=1e-6)

    for test_image_embeddings, text_data in tqdm(dataloader):

        # if text conditioned, we produce an embedding from the tokenized text
        if text_conditioned:
            text_embedding, text_encodings, text_mask = diffusion_prior.clip.embed_text(text_data)
            text_cond = dict(text_embed=text_embedding, text_encodings=text_encodings, mask=text_mask)
        else:
            text_embedding = text_data
            text_cond = dict(text_embed=text_embedding)

        # make a copy of the text embeddings for shuffling
        text_embed_shuffled = text_embedding.clone()

        # roll the text to simulate "unrelated" captions
        rolled_idx = torch.roll(torch.arange(text_embedding.shape[0]), 1)
        text_embed_shuffled = text_embed_shuffled[rolled_idx]
        text_embed_shuffled = text_embed_shuffled / text_embed_shuffled.norm(dim=1, keepdim=True)

        if text_conditioned:
            text_encodings_shuffled = text_encodings[rolled_idx]
            text_mask_shuffled = text_mask[rolled_idx]
        else:
            text_encodings_shuffled = None
            text_mask_shuffled = None

        text_cond_shuffled = dict(text_embed=text_embed_shuffled, text_encodings=text_encodings_shuffled, mask=text_mask_shuffled)

        # prepare the text embedding
        text_embed = text_embedding / text_embedding.norm(dim=1, keepdim=True)

        # prepare image embeddings
        test_image_embeddings = test_image_embeddings / test_image_embeddings.norm(dim=1, keepdim=True)

        # predict on the unshuffled text embeddings
        predicted_image_embeddings = diffusion_prior.p_sample_loop(test_image_embeddings.shape, text_cond)
        predicted_image_embeddings = predicted_image_embeddings / predicted_image_embeddings.norm(dim=1, keepdim=True)

        # predict on the shuffled embeddings
        predicted_unrelated_embeddings = diffusion_prior.p_sample_loop(test_image_embeddings.shape, text_cond_shuffled)
        predicted_unrelated_embeddings = predicted_unrelated_embeddings / predicted_unrelated_embeddings.norm(dim=1, keepdim=True)

        # calculate similarities
        original_similarity = cos(text_embed, test_image_embeddings).cpu().numpy()
        predicted_similarity = cos(text_embed, predicted_image_embeddings).cpu().numpy()
        unrelated_similarity = cos(text_embed, predicted_unrelated_embeddings).cpu().numpy()
        predicted_img_similarity = cos(test_image_embeddings, predicted_image_embeddings).cpu().numpy()
        tracker.log({
            "CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity),
            "CosineSimilarity(text_embed,predicted_image_embed)": np.mean(predicted_similarity),
            "CosineSimilarity(orig_image_embed,predicted_image_embed)": np.mean(predicted_img_similarity),
            "CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(unrelated_similarity),
            "Cosine similarity difference": np.mean(predicted_similarity - original_similarity)
        })
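The "unrelated captions" baseline relies on `torch.roll` to pair every image with a different sample's text. A tiny standalone illustration:

```python
import torch

idx = torch.roll(torch.arange(4), 1)  # tensor([3, 0, 1, 2])
# indexing a batch of text embeddings with idx shifts each caption to the
# next image, so every (image, text) pair in the batch becomes mismatched
```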
@click.command()
@click.option("--wandb-entity", default="laion")
@click.option("--wandb-project", default="diffusion-prior")
@click.option("--wandb-dataset", default="LAION-5B")
@click.option("--wandb-arch", default="DiffusionPrior")
@click.option("--image-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
@click.option("--text-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
@click.option("--meta-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/laion2B-en-metadata/")
@click.option("--learning-rate", default=1.1e-4)
@click.option("--weight-decay", default=6.02e-2)
@click.option("--dropout", default=5e-2)
@click.option("--max-grad-norm", default=0.5)
@click.option("--num-data-points", default=250e6)
@click.option("--batch-size", default=320)
@click.option("--num-epochs", default=5)
@click.option("--image-embed-dim", default=768)
@click.option("--train-percent", default=0.9)
@click.option("--val-percent", default=1e-7)
@click.option("--test-percent", default=0.0999999)
@click.option("--dpn-depth", default=12)
@click.option("--dpn-dim-head", default=64)
@click.option("--dpn-heads", default=12)
@click.option("--dp-condition-on-text-encodings", default=True)
@click.option("--dp-timesteps", default=1000)
@click.option("--dp-normformer", default=True)
@click.option("--dp-cond-drop-prob", default=0.1)
@click.option("--dp-loss-type", default="l2")
@click.option("--clip", default="ViT-L/14")
@click.option("--amp", default=False)
@click.option("--save-interval", default=120)
@click.option("--save-path", default="./diffusion_prior_checkpoints")
@click.option("--pretrained-model-path", default=None)
@click.option("--gpu-device", default=0)
def train(
    wandb_entity,
    wandb_project,
    wandb_dataset,
    wandb_arch,
    image_embed_url,
    text_embed_url,
    meta_url,
    learning_rate,
    weight_decay,
    dropout,
    max_grad_norm,
    num_data_points,
    batch_size,
    num_epochs,
    image_embed_dim,
    train_percent,
    val_percent,
    test_percent,
    dpn_depth,
    dpn_dim_head,
    dpn_heads,
    dp_condition_on_text_encodings,
    dp_timesteps,
    dp_normformer,
    dp_cond_drop_prob,
    dp_loss_type,
    clip,
    amp,
    save_interval,
    save_path,
    pretrained_model_path,
    gpu_device
):
    config = {
        "learning_rate": learning_rate,
        "architecture": wandb_arch,
        "dataset": wandb_dataset,
        "weight_decay": weight_decay,
        "max_gradient_clipping_norm": max_grad_norm,
        "batch_size": batch_size,
        "epochs": num_epochs,
        "diffusion_prior_network": {
            "depth": dpn_depth,
            "dim_head": dpn_dim_head,
            "heads": dpn_heads,
            "normformer": dp_normformer
        },
        "diffusion_prior": {
            "condition_on_text_encodings": dp_condition_on_text_encodings,
            "timesteps": dp_timesteps,
            "cond_drop_prob": dp_cond_drop_prob,
            "loss_type": dp_loss_type,
            "clip": clip
        }
    }

    # Check if DPRIOR_PATH (a saved model path) exists

    DPRIOR_PATH = pretrained_model_path
    RESUME = exists(DPRIOR_PATH)

    if not RESUME:
        tracker.init(
            entity = wandb_entity,
            project = wandb_project,
            config = config
        )

    # Obtain the utilized device.

    has_cuda = torch.cuda.is_available()
    if has_cuda:
        device = torch.device(f"cuda:{gpu_device}")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")  # fall back to CPU so `device` is always defined below
    # Training loop
    # diffusion prior network

    prior_network = DiffusionPriorNetwork(
        dim = image_embed_dim,
        depth = dpn_depth,
        dim_head = dpn_dim_head,
        heads = dpn_heads,
        attn_dropout = dropout,
        ff_dropout = dropout,
        normformer = dp_normformer
    )

    # Load clip model if text-conditioning
    if dp_condition_on_text_encodings:
        clip_adapter = OpenAIClipAdapter(clip)
    else:
        clip_adapter = None

    # diffusion prior with text embeddings and image embeddings pre-computed

    diffusion_prior = DiffusionPrior(
        net = prior_network,
        clip = clip_adapter,
        image_embed_dim = image_embed_dim,
        timesteps = dp_timesteps,
        cond_drop_prob = dp_cond_drop_prob,
        loss_type = dp_loss_type,
        condition_on_text_encodings = dp_condition_on_text_encodings
    )

    # Load pre-trained model from DPRIOR_PATH

    if RESUME:
        diffusion_prior, loaded_obj = load_diffusion_model(DPRIOR_PATH, device)
        tracker.init(entity = wandb_entity, project = wandb_project, config = config)

    # diffusion prior trainer

    trainer = DiffusionPriorTrainer(
        diffusion_prior = diffusion_prior,
        lr = learning_rate,
        wd = weight_decay,
        max_grad_norm = max_grad_norm,
        amp = amp,
    ).to(device)

    # load optimizer and scaler

    if RESUME:
        trainer.optimizer.load_state_dict(loaded_obj['optimizer'])
        trainer.scaler.load_state_dict(loaded_obj['scaler'])

    # Create save_path if it doesn't exist

    Path(save_path).mkdir(exist_ok = True, parents = True)

    # Utilize wrapper to abstract away loader logic
    print_ribbon("Downloading Embeddings")
    loader_args = dict(
        text_conditioned=dp_condition_on_text_encodings,
        batch_size=batch_size,
        num_data_points=num_data_points,
        train_split=train_percent,
        eval_split=val_percent,
        device=device,
        img_url=image_embed_url
    )

    if dp_condition_on_text_encodings:
        loader_args = dict(**loader_args, meta_url=meta_url)
    else:
        loader_args = dict(**loader_args, txt_url=text_embed_url)

    train_loader, eval_loader, test_loader = make_splits(**loader_args)

    ### Training code ###

    step = 1
    timer = Timer()
    epochs = num_epochs
    for _ in range(epochs):

        for image, text in tqdm(train_loader):

            diffusion_prior.train()

            input_args = dict(image_embed=image)
            if dp_condition_on_text_encodings:
                input_args = dict(**input_args, text = text)
            else:
                input_args = dict(**input_args, text_embed=text)

            loss = trainer(**input_args)

            # Samples per second

            samples_per_sec = batch_size * step / timer.elapsed()

            # Save checkpoint every save_interval minutes
            if int(timer.elapsed()) >= 60 * save_interval:
                timer.reset()

                save_diffusion_model(
                    save_path,
                    diffusion_prior,
                    trainer.optimizer,
                    trainer.scaler,
                    config,
                    image_embed_dim)

            # Log to wandb
            tracker.log({"Training loss": loss,
                         "Steps": step,
                         "Samples per second": samples_per_sec})
            # Log cosineSim(text_embed,predicted_image_embed) - cosineSim(text_embed,image_embed)
            # Use NUM_TEST_EMBEDDINGS samples from the test set each time
            # Get embeddings from the most recently saved model
            if step % REPORT_METRICS_EVERY == 0:
                report_cosine_sims(diffusion_prior, eval_loader, dp_condition_on_text_encodings)
                ### Evaluate model (validation run) ###
                eval_model(diffusion_prior, eval_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Validation")

            step += 1
            trainer.update()

    ### Test run ###
    eval_model(diffusion_prior, test_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Test")


if __name__ == "__main__":
    train()
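As with the decoder script, training is started through the click entry point; every hyperparameter above is exposed as a flag, e.g. `python train_diffusion_prior.py --batch-size 320 --num-epochs 5`.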