mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 19:44:26 +01:00
Compare commits
161 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9340d33d5f | ||
|
|
e0524a6aff | ||
|
|
c85e0d5c35 | ||
|
|
db0642c4cd | ||
|
|
bb86ab2404 | ||
|
|
ae056dd67c | ||
|
|
033d6b0ce8 | ||
|
|
c7ea8748db | ||
|
|
13382885d9 | ||
|
|
c3d4a7ffe4 | ||
|
|
164d9be444 | ||
|
|
5562ec6be2 | ||
|
|
89ff04cfe2 | ||
|
|
f4016f6302 | ||
|
|
1212f7058d | ||
|
|
dab106d4e5 | ||
|
|
bb151ca6b1 | ||
|
|
4a59dea4cf | ||
|
|
ecf9e8027d | ||
|
|
36c5079bd7 | ||
|
|
4a4c7ac9e6 | ||
|
|
fad7481479 | ||
|
|
123658d082 | ||
|
|
11d4e11f10 | ||
|
|
99778e12de | ||
|
|
0f0011caf0 | ||
|
|
7b7a62044a | ||
|
|
156fe5ed9f | ||
|
|
5ec34bebe1 | ||
|
|
8eaacf1ac1 | ||
|
|
e66c7b0249 | ||
|
|
f7cd4a0992 | ||
|
|
68e7d2f241 | ||
|
|
74f222596a | ||
|
|
aa6772dcff | ||
|
|
71d0c4edae | ||
|
|
f7eee09d8b | ||
|
|
89de5af63e | ||
|
|
4ec6d0ba81 | ||
|
|
aee92dba4a | ||
|
|
b0cd5f24b6 | ||
|
|
b494ed81d4 | ||
|
|
ff3474f05c | ||
|
|
d5293f19f1 | ||
|
|
e697183849 | ||
|
|
591d37e266 | ||
|
|
d1f02e8f49 | ||
|
|
9faab59b23 | ||
|
|
5d27029e98 | ||
|
|
3115fa17b3 | ||
|
|
124d8577c8 | ||
|
|
2db0c9794c | ||
|
|
2277b47ffd | ||
|
|
28b58e568c | ||
|
|
924455d97d | ||
|
|
6021945fc8 | ||
|
|
6f76652d11 | ||
|
|
3dda2570ed | ||
|
|
2f3c02dba8 | ||
|
|
908088cfea | ||
|
|
8dc8a3de0d | ||
|
|
35f89556ba | ||
|
|
2b55f753b9 | ||
|
|
fc8fce38fb | ||
|
|
a1bfb03ba4 | ||
|
|
b1e7b5f6bb | ||
|
|
10b905b445 | ||
|
|
9b322ea634 | ||
|
|
ba64ea45cc | ||
|
|
64f7be1926 | ||
|
|
db805e73e1 | ||
|
|
cb07b37970 | ||
|
|
a774bfefe2 | ||
|
|
2ae57f0cf5 | ||
|
|
e46eaec817 | ||
|
|
8647cb5e76 | ||
|
|
53c189e46a | ||
|
|
dde51fd362 | ||
|
|
2eac7996fa | ||
|
|
4010aec033 | ||
|
|
c87b84a259 | ||
|
|
8b05468653 | ||
|
|
830afd3c15 | ||
|
|
8f93729d19 | ||
|
|
cd5f2c1de4 | ||
|
|
85ed77d512 | ||
|
|
fd53fa17db | ||
|
|
3676ef4d49 | ||
|
|
28e944f328 | ||
|
|
14e63a3f67 | ||
|
|
09e9eaa5a6 | ||
|
|
e6d752cf4a | ||
|
|
ad20a14a4d | ||
|
|
0be1e0d64c | ||
|
|
98df1ba51e | ||
|
|
878b555ef7 | ||
|
|
63029f7388 | ||
|
|
c76a964fd6 | ||
|
|
79fabc4341 | ||
|
|
f7ef4bde38 | ||
|
|
93ba019069 | ||
|
|
8518684ae9 | ||
|
|
1d5dc08810 | ||
|
|
d8d8b6caf1 | ||
|
|
15acc03bd4 | ||
|
|
896f19786d | ||
|
|
aec5575d09 | ||
|
|
9773f10d6c | ||
|
|
a6bf8ddef6 | ||
|
|
86e692d24f | ||
|
|
97b751209f | ||
|
|
74103fd8d6 | ||
|
|
1992d25cad | ||
|
|
5b619c2fd5 | ||
|
|
9359ad2e91 | ||
|
|
9ff228188b | ||
|
|
2d9963d30e | ||
|
|
58d9b422f3 | ||
|
|
44b319cb57 | ||
|
|
c30f380689 | ||
|
|
e4e884bb8b | ||
|
|
803ad9c17d | ||
|
|
a88dd6a9c0 | ||
|
|
72c16b496e | ||
|
|
81d83dd7f2 | ||
|
|
fa66f7e1e9 | ||
|
|
aa8d135245 | ||
|
|
70282de23b | ||
|
|
83f761847e | ||
|
|
11469dc0c6 | ||
|
|
2d25c89f35 | ||
|
|
3fe96c208a | ||
|
|
0fc6c9cdf3 | ||
|
|
7ee0ecc388 | ||
|
|
1924c7cc3d | ||
|
|
f7df3caaf3 | ||
|
|
fc954ee788 | ||
|
|
c1db2753f5 | ||
|
|
ad87bfe28f | ||
|
|
76c767b1ce | ||
|
|
d991b8c39c | ||
|
|
902693e271 | ||
|
|
35cd63982d | ||
|
|
53ce6dfdf6 | ||
|
|
ad8d7a368b | ||
|
|
b8cf1e5c20 | ||
|
|
94aaa08d97 | ||
|
|
8b9bbec7d1 | ||
|
|
1bb9fc9829 | ||
|
|
5e421bd5bb | ||
|
|
67fcab1122 | ||
|
|
5bfbccda22 | ||
|
|
989275ff59 | ||
|
|
56408f4a40 | ||
|
|
d1a697ac23 | ||
|
|
ebe01749ed | ||
|
|
63195cc2cb | ||
|
|
a2ef69af66 | ||
|
|
5fff22834e | ||
|
|
a9421f49ec | ||
|
|
77fa34eae9 |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -1,3 +1,6 @@
|
||||
# default experiment tracker data
|
||||
.tracker-data/
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
|
||||
399
README.md
399
README.md
@@ -10,10 +10,20 @@ The main novelty seems to be an extra layer of indirection with the prior networ
|
||||
|
||||
This model is SOTA for text-to-image for now.
|
||||
|
||||
Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication
|
||||
Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication with the <a href="https://laion.ai/">LAION</a> community | <a href="https://www.youtube.com/watch?v=AIOE1l1W0Tw">Yannic Interview</a>
|
||||
|
||||
There was enough interest for a <a href="https://github.com/lucidrains/dalle2-jax">Jax version</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
|
||||
|
||||
## Status
|
||||
|
||||
- A research group has used the code in this repository to train a functional diffusion prior for their CLIP generations. Will share their work once they release their preprint. This, and <a href="https://github.com/crowsonkb">Katherine's</a> own experiments, validate OpenAI's finding that the extra prior increases variety of generations.
|
||||
|
||||
- Decoder is now verified working for unconditional generation on my experimental setup for Oxford flowers. 2 researchers have also confirmed Decoder is working for them.
|
||||
|
||||
<img src="./samples/oxford.png" width="600px" />
|
||||
|
||||
*ongoing at 21k steps*
|
||||
|
||||
## Install
|
||||
|
||||
```bash
|
||||
@@ -508,7 +518,7 @@ To use a pretrained OpenAI CLIP, simply import `OpenAIClipAdapter` and pass it i
|
||||
import torch
|
||||
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, OpenAIClipAdapter
|
||||
|
||||
# openai pretrained clip - defaults to ViT/B-32
|
||||
# openai pretrained clip - defaults to ViT-B/32
|
||||
|
||||
clip = OpenAIClipAdapter()
|
||||
|
||||
@@ -706,9 +716,309 @@ mock_image_embed = torch.randn(1, 512).cuda()
|
||||
images = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)
|
||||
```
|
||||
|
||||
## Training wrapper (wip)
|
||||
## Training wrapper
|
||||
|
||||
Offer training wrappers
|
||||
### Decoder Training
|
||||
|
||||
Training the `Decoder` may be confusing, as one needs to keep track of an optimizer for each of the `Unet`(s) separately. Each `Unet` will also need its own corresponding exponential moving average. The `DecoderTrainer` hopes to make this simple, as shown below
|
||||
|
||||
```python
|
||||
import torch
|
||||
from dalle2_pytorch import DALLE2, Unet, Decoder, CLIP, DecoderTrainer
|
||||
|
||||
clip = CLIP(
|
||||
dim_text = 512,
|
||||
dim_image = 512,
|
||||
dim_latent = 512,
|
||||
num_text_tokens = 49408,
|
||||
text_enc_depth = 6,
|
||||
text_seq_len = 256,
|
||||
text_heads = 8,
|
||||
visual_enc_depth = 6,
|
||||
visual_image_size = 256,
|
||||
visual_patch_size = 32,
|
||||
visual_heads = 8
|
||||
).cuda()
|
||||
|
||||
# mock data
|
||||
|
||||
text = torch.randint(0, 49408, (32, 256)).cuda()
|
||||
images = torch.randn(32, 3, 256, 256).cuda()
|
||||
|
||||
# decoder (with unet)
|
||||
|
||||
unet1 = Unet(
|
||||
dim = 128,
|
||||
image_embed_dim = 512,
|
||||
text_embed_dim = 512,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
dim_mults=(1, 2, 4, 8)
|
||||
).cuda()
|
||||
|
||||
unet2 = Unet(
|
||||
dim = 16,
|
||||
image_embed_dim = 512,
|
||||
text_embed_dim = 512,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
dim_mults = (1, 2, 4, 8, 16),
|
||||
cond_on_text_encodings = True
|
||||
).cuda()
|
||||
|
||||
decoder = Decoder(
|
||||
unet = (unet1, unet2),
|
||||
image_sizes = (128, 256),
|
||||
clip = clip,
|
||||
timesteps = 1000,
|
||||
condition_on_text_encodings = True
|
||||
).cuda()
|
||||
|
||||
decoder_trainer = DecoderTrainer(
|
||||
decoder,
|
||||
lr = 3e-4,
|
||||
wd = 1e-2,
|
||||
ema_beta = 0.99,
|
||||
ema_update_after_step = 1000,
|
||||
ema_update_every = 10,
|
||||
)
|
||||
|
||||
for unet_number in (1, 2):
|
||||
loss = decoder_trainer(
|
||||
images,
|
||||
text = text,
|
||||
unet_number = unet_number, # which unet to train on
|
||||
max_batch_size = 4 # gradient accumulation - this sets the maximum batch size in which to do forward and backwards pass - for this example 32 / 4 == 8 times
|
||||
)
|
||||
|
||||
decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
|
||||
|
||||
# after much training
|
||||
# you can sample from the exponentially moving averaged unets as so
|
||||
|
||||
mock_image_embed = torch.randn(4, 512).cuda()
|
||||
images = decoder_trainer.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
|
||||
```
|
||||
|
||||
### Diffusion Prior Training
|
||||
|
||||
Similarly, one can use the `DiffusionPriorTrainer` to automatically instantiate and keep track of an exponential moving averaged prior.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, DiffusionPriorTrainer, Unet, Decoder, CLIP
|
||||
|
||||
clip = CLIP(
|
||||
dim_text = 512,
|
||||
dim_image = 512,
|
||||
dim_latent = 512,
|
||||
num_text_tokens = 49408,
|
||||
text_enc_depth = 6,
|
||||
text_seq_len = 256,
|
||||
text_heads = 8,
|
||||
visual_enc_depth = 6,
|
||||
visual_image_size = 256,
|
||||
visual_patch_size = 32,
|
||||
visual_heads = 8
|
||||
).cuda()
|
||||
|
||||
# mock data
|
||||
|
||||
text = torch.randint(0, 49408, (512, 256)).cuda()
|
||||
images = torch.randn(512, 3, 256, 256).cuda()
|
||||
|
||||
# prior networks (with transformer)
|
||||
|
||||
prior_network = DiffusionPriorNetwork(
|
||||
dim = 512,
|
||||
depth = 6,
|
||||
dim_head = 64,
|
||||
heads = 8
|
||||
).cuda()
|
||||
|
||||
diffusion_prior = DiffusionPrior(
|
||||
net = prior_network,
|
||||
clip = clip,
|
||||
timesteps = 100,
|
||||
cond_drop_prob = 0.2
|
||||
).cuda()
|
||||
|
||||
diffusion_prior_trainer = DiffusionPriorTrainer(
|
||||
diffusion_prior,
|
||||
lr = 3e-4,
|
||||
wd = 1e-2,
|
||||
ema_beta = 0.99,
|
||||
ema_update_after_step = 1000,
|
||||
ema_update_every = 10,
|
||||
)
|
||||
|
||||
loss = diffusion_prior_trainer(text, images, max_batch_size = 4)
|
||||
diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior
|
||||
|
||||
# after much of the above three lines in a loop
|
||||
# you can sample from the exponential moving average of the diffusion prior identically to how you do so for DiffusionPrior
|
||||
|
||||
image_embeds = diffusion_prior_trainer.sample(text, max_batch_size = 4) # (512, 512) - exponential moving averaged image embeddings
|
||||
```
|
||||
|
||||
## Bonus
|
||||
|
||||
### Unconditional Training
|
||||
|
||||
The repository also contains the means to train unconditional DDPM model, or even cascading DDPMs. You simply have to set `unconditional = True` in the `Decoder`
|
||||
|
||||
ex.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from dalle2_pytorch import Unet, Decoder, DecoderTrainer
|
||||
|
||||
# unet for the cascading ddpm
|
||||
|
||||
unet1 = Unet(
|
||||
dim = 128,
|
||||
dim_mults=(1, 2, 4, 8)
|
||||
).cuda()
|
||||
|
||||
unet2 = Unet(
|
||||
dim = 32,
|
||||
dim_mults = (1, 2, 4, 8, 16)
|
||||
).cuda()
|
||||
|
||||
# decoder, which contains the unets
|
||||
|
||||
decoder = Decoder(
|
||||
unet = (unet1, unet2),
|
||||
image_sizes = (256, 512), # first unet up to 256px, then second to 512px
|
||||
timesteps = 1000,
|
||||
unconditional = True
|
||||
).cuda()
|
||||
|
||||
# decoder trainer
|
||||
|
||||
decoder_trainer = DecoderTrainer(decoder)
|
||||
|
||||
# images (get a lot of this)
|
||||
|
||||
images = torch.randn(1, 3, 512, 512).cuda()
|
||||
|
||||
# feed images into decoder
|
||||
|
||||
for i in (1, 2):
|
||||
loss = decoder_trainer(images, unet_number = i)
|
||||
decoder_trainer.update(unet_number = i)
|
||||
|
||||
# do the above for many many many many images
|
||||
# then it will learn to generate images
|
||||
|
||||
images = decoder_trainer.sample(batch_size = 36, max_batch_size = 4) # (36, 3, 512, 512)
|
||||
```
|
||||
|
||||
## Dataloaders
|
||||
|
||||
### Decoder Dataloaders
|
||||
|
||||
In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.
|
||||
|
||||
#### Decoder: Image Embedding Dataset
|
||||
|
||||
When training the decoder (and up samplers if training together) in isolation, you will need to load images and corresponding image embeddings. This dataset can read two similar types of datasets. First, it can read a [webdataset](https://github.com/webdataset/webdataset) that contains `.jpg` and `.npy` files in the `.tar`s that contain the images and associated image embeddings respectively. Alternatively, you can also specify a source for the embeddings outside of the webdataset. In this case, the path to the embeddings should contain `.npy` files with the same shard numbers as the webdataset and there should be a correspondence between the filename of the `.jpg` and the index of the embedding in the `.npy`. So, for example, `0001.tar` from the webdataset with image `00010509.jpg` (the first 4 digits are the shard number and the last 4 are the index) in it should be paralleled by a `img_emb_0001.npy` which contains a NumPy array with the embedding at index 509.
|
||||
|
||||
Generating a dataset of this type:
|
||||
1. Use [img2dataset](https://github.com/rom1504/img2dataset) to generate a webdataset.
|
||||
2. Use [clip-retrieval](https://github.com/rom1504/clip-retrieval) to convert the images to embeddings.
|
||||
3. Use [embedding-dataset-reordering](https://github.com/Veldrovive/embedding-dataset-reordering) to reorder the embeddings into the expected format.
|
||||
|
||||
Usage:
|
||||
|
||||
```python
|
||||
from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embedding_dataloader
|
||||
|
||||
# Create a dataloader directly.
|
||||
dataloader = create_image_embedding_dataloader(
|
||||
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses braket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
|
||||
embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
|
||||
num_workers=4,
|
||||
batch_size=32,
|
||||
shard_width=4, # If a file in the webdataset shard 3 is named 0003039.jpg, we know the shard width is 4 and the last three digits are the index
|
||||
shuffle_num=200, # Does a shuffle of the data with a buffer size of 200
|
||||
shuffle_shards=True, # Shuffle the order the shards are read in
|
||||
resample_shards=False, # Sample shards with replacement. If true, an epoch will be infinite unless stopped manually
|
||||
)
|
||||
for img, emb in dataloader:
|
||||
print(img.shape) # torch.Size([32, 3, 256, 256])
|
||||
print(emb.shape) # torch.Size([32, 512])
|
||||
# Train decoder only as shown above
|
||||
|
||||
# Or create a dataset without a loader so you can configure it manually
|
||||
dataset = ImageEmbeddingDataset(
|
||||
urls="/path/or/url/to/webdataset/{0000..9999}.tar",
|
||||
embedding_folder_url="path/or/url/to/embeddings/folder",
|
||||
shard_width=4,
|
||||
shuffle_shards=True,
|
||||
resample=False
|
||||
)
|
||||
```
|
||||
|
||||
### Scripts (wip)
|
||||
|
||||
#### `train_diffusion_prior.py`
|
||||
|
||||
This script allows training the DiffusionPrior on pre-computed text and image embeddings. The working example below elucidates this process.
|
||||
Please note that the script internally passes text_embed and image_embed to the DiffusionPrior, unlike the example below.
|
||||
|
||||
#### Usage
|
||||
|
||||
```bash
|
||||
$ python train_diffusion_prior.py
|
||||
```
|
||||
|
||||
The most significant parameters for the script are as follows:
|
||||
|
||||
- `image-embed-url`, default = `"https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/"`
|
||||
|
||||
- `text-embed-url`, default = `"https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/"`
|
||||
|
||||
- `image-embed-dim`, default = `768` - 768 corresponds to the ViT iL/14 embedding size,change it to what your chosen ViT generates
|
||||
|
||||
- `learning-rate`, default = `1.1e-4`
|
||||
|
||||
- `weight-decay`, default = `6.02e-2`
|
||||
|
||||
- `max-grad-norm`, default = `0.5`
|
||||
|
||||
- `batch-size`, default = `10 ** 4`
|
||||
|
||||
- `num-epochs`, default = `5`
|
||||
|
||||
- `clip`, default = `None` # Signals the prior to use pre-computed embeddings
|
||||
|
||||
#### Loading and Saving the DiffusionPrior model
|
||||
|
||||
Two methods are provided, load_diffusion_model and save_diffusion_model, the names being self-explanatory.
|
||||
|
||||
```python
|
||||
from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model
|
||||
```
|
||||
|
||||
##### Loading
|
||||
|
||||
load_diffusion_model(dprior_path, device)
|
||||
dprior_path : path to saved model(.pth)
|
||||
device : the cuda device you're running on
|
||||
|
||||
##### Saving
|
||||
|
||||
save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim)
|
||||
save_path : path to save at
|
||||
model : object of Diffusion_Prior
|
||||
optimizer : optimizer object - see train_diffusion_prior.py for how to create one.
|
||||
e.g: optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
|
||||
scaler : a GradScaler object.
|
||||
e.g: scaler = GradScaler(enabled=amp)
|
||||
config : config object created in train_diffusion_prior.py - see file for example.
|
||||
image_embed_dim - the dimension of the image_embedding
|
||||
e.g: 768
|
||||
|
||||
## CLI (wip)
|
||||
|
||||
@@ -741,14 +1051,35 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
- [x] use inheritance just this once for sharing logic between decoder and prior network ddpms
|
||||
- [x] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
|
||||
- [x] abstract interface for CLIP adapter class, so other CLIPs can be brought in
|
||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
|
||||
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
|
||||
- [x] take care of mixed precision as well as gradient accumulation within decoder trainer
|
||||
- [x] just take care of the training for the decoder in a wrapper class, as each unet in the cascade will need its own optimizer
|
||||
- [x] bring in tools to train vqgan-vae
|
||||
- [x] add convnext backbone for vqgan-vae (in addition to vit [vit-vqgan] + resnet)
|
||||
- [x] make sure DDPMs can be run with traditional resnet blocks (but leave convnext as an option for experimentation)
|
||||
- [x] make sure for the latter unets in the cascade, one can train on crops for learning super resolution (constrain the unet to be only convolutions in that case, or allow conv-like attention with rel pos bias)
|
||||
- [x] offer setting in diffusion prior to split time and image embeddings into multiple tokens, configurable, for more surface area during attention
|
||||
- [x] make sure resnet hyperparameters can be configurable across unet depth (groups and expansion factor)
|
||||
- [x] pull logic for training diffusion prior into a class DiffusionPriorTrainer, for eventual script based + CLI based training
|
||||
- [x] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images
|
||||
- [x] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
|
||||
- [x] cross embed layers for downsampling, as an option
|
||||
- [x] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
|
||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
|
||||
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
|
||||
- [ ] just take care of the training for the decoder in a wrapper class, as each unet in the cascade will need its own optimizer
|
||||
- [ ] train on a toy task, offer in colab
|
||||
- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
|
||||
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
|
||||
- [ ] bring in tools to train vqgan-vae
|
||||
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
|
||||
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
|
||||
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
|
||||
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
|
||||
- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
|
||||
- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
|
||||
- [ ] decoder needs one day worth of refactor for tech debt
|
||||
- [ ] allow for unet to be able to condition non-cross attention style as well
|
||||
- [ ] for all model classes with hyperparameters that changes the network architecture, make it requirement that they must expose a config property, and write a simple function that asserts that it restores the object correctly
|
||||
- [ ] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
|
||||
- [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89
|
||||
|
||||
## Citations
|
||||
|
||||
@@ -779,10 +1110,12 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@inproceedings{Liu2022ACF,
|
||||
title = {A ConvNet for the 2020https://arxiv.org/abs/2112.11435s},
|
||||
author = {Zhuang Liu and Hanzi Mao and Chaozheng Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie},
|
||||
year = {2022}
|
||||
@article{shen2019efficient,
|
||||
author = {Zhuoran Shen and Mingyuan Zhang and Haiyu Zhao and Shuai Yi and Hongsheng Li},
|
||||
title = {Efficient Attention: Attention with Linear Complexities},
|
||||
journal = {CoRR},
|
||||
year = {2018},
|
||||
url = {http://arxiv.org/abs/1812.01243},
|
||||
}
|
||||
```
|
||||
|
||||
@@ -804,4 +1137,44 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
}
|
||||
```
|
||||
|
||||
*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's <a href="https://arxiv.org/abs/2011.13456">paper</a>
|
||||
```bibtex
|
||||
@article{Shleifer2021NormFormerIT,
|
||||
title = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
|
||||
author = {Sam Shleifer and Jason Weston and Myle Ott},
|
||||
journal = {ArXiv},
|
||||
year = {2021},
|
||||
volume = {abs/2110.09456}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{Yu2022CoCaCC,
|
||||
title = {CoCa: Contrastive Captioners are Image-Text Foundation Models},
|
||||
author = {Jiahui Yu and Zirui Wang and Vijay Vasudevan and Legg Yeung and Mojtaba Seyedhosseini and Yonghui Wu},
|
||||
journal = {ArXiv},
|
||||
year = {2022},
|
||||
volume = {abs/2205.01917}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{wang2021crossformer,
|
||||
title = {CrossFormer: A Versatile Vision Transformer Hinging on Cross-scale Attention},
|
||||
author = {Wenxiao Wang and Lu Yao and Long Chen and Binbin Lin and Deng Cai and Xiaofei He and Wei Liu},
|
||||
year = {2021},
|
||||
eprint = {2108.00154},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{ho2021cascaded,
|
||||
title = {Cascaded Diffusion Models for High Fidelity Image Generation},
|
||||
author = {Ho, Jonathan and Saharia, Chitwan and Chan, William and Fleet, David J and Norouzi, Mohammad and Salimans, Tim},
|
||||
journal = {arXiv preprint arXiv:2106.15282},
|
||||
year = {2021}
|
||||
}
|
||||
```
|
||||
|
||||
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
|
||||
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
|
||||
from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
|
||||
|
||||
from dalle2_pytorch.vqgan_vae import VQGanVAE
|
||||
from x_clip import CLIP
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import click
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from functools import reduce
|
||||
from pathlib import Path
|
||||
|
||||
from dalle2_pytorch import DALLE2, Decoder, DiffusionPrior
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
41
dalle2_pytorch/dataloaders/README.md
Normal file
41
dalle2_pytorch/dataloaders/README.md
Normal file
@@ -0,0 +1,41 @@
|
||||
## Dataloaders
|
||||
In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.
|
||||
|
||||
### Decoder: Image Embedding Dataset
|
||||
When training the decoder (and up samplers if training together) in isolation, you will need to load images and corresponding image embeddings. This dataset can read two similar types of datasets. First, it can read a [webdataset](https://github.com/webdataset/webdataset) that contains `.jpg` and `.npy` files in the `.tar`s that contain the images and associated image embeddings respectively. Alternatively, you can also specify a source for the embeddings outside of the webdataset. In this case, the path to the embeddings should contain `.npy` files with the same shard numbers as the webdataset and there should be a correspondence between the filename of the `.jpg` and the index of the embedding in the `.npy`. So, for example, `0001.tar` from the webdataset with image `00010509.jpg` (the first 4 digits are the shard number and the last 4 are the index) in it should be paralleled by a `img_emb_0001.npy` which contains a NumPy array with the embedding at index 509.
|
||||
|
||||
Generating a dataset of this type:
|
||||
1. Use [img2dataset](https://github.com/rom1504/img2dataset) to generate a webdataset.
|
||||
2. Use [clip-retrieval](https://github.com/rom1504/clip-retrieval) to convert the images to embeddings.
|
||||
3. Use [embedding-dataset-reordering](https://github.com/Veldrovive/embedding-dataset-reordering) to reorder the embeddings into the expected format.
|
||||
|
||||
Usage:
|
||||
```python
|
||||
from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embedding_dataloader
|
||||
|
||||
# Create a dataloader directly.
|
||||
dataloader = create_image_embedding_dataloader(
|
||||
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses braket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
|
||||
embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
|
||||
num_workers=4,
|
||||
batch_size=32,
|
||||
shard_width=4, # If a file in the webdataset shard 3 is named 0003039.jpg, we know the shard width is 4 and the last three digits are the index
|
||||
shuffle_num=200, # Does a shuffle of the data with a buffer size of 200
|
||||
shuffle_shards=True, # Shuffle the order the shards are read in
|
||||
resample_shards=False, # Sample shards with replacement. If true, an epoch will be infinite unless stopped manually
|
||||
)
|
||||
for img, emb in dataloader:
|
||||
print(img.shape) # torch.Size([32, 3, 256, 256])
|
||||
print(emb.shape) # torch.Size([32, 512])
|
||||
# Train decoder only as shown above
|
||||
|
||||
# Or create a dataset without a loader so you can configure it manually
|
||||
dataset = ImageEmbeddingDataset(
|
||||
urls="/path/or/url/to/webdataset/{0000..9999}.tar",
|
||||
embedding_folder_url="path/or/url/to/embeddings/folder",
|
||||
shard_width=4,
|
||||
shuffle_shards=True,
|
||||
resample=False
|
||||
)
|
||||
```
|
||||
|
||||
2
dalle2_pytorch/dataloaders/__init__.py
Normal file
2
dalle2_pytorch/dataloaders/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from dalle2_pytorch.dataloaders.decoder_loader import ImageEmbeddingDataset, create_image_embedding_dataloader
|
||||
from dalle2_pytorch.dataloaders.embedding_wrapper import make_splits
|
||||
234
dalle2_pytorch/dataloaders/decoder_loader.py
Normal file
234
dalle2_pytorch/dataloaders/decoder_loader.py
Normal file
@@ -0,0 +1,234 @@
|
||||
import os
|
||||
import webdataset as wds
|
||||
import torch
|
||||
import numpy as np
|
||||
import fsspec
|
||||
import shutil
|
||||
|
||||
def get_shard(filename):
|
||||
"""
|
||||
Filenames with shards in them have a consistent structure that we can take advantage of
|
||||
Standard structure: path/to/file/prefix_string_00001.ext
|
||||
"""
|
||||
try:
|
||||
return filename.split("_")[-1].split(".")[0]
|
||||
except ValueError:
|
||||
raise RuntimeError(f"Could not find shard for filename {filename}")
|
||||
|
||||
def get_example_file(fs, path, file_format):
|
||||
"""
|
||||
Given a file system and a file extension, return the example file
|
||||
"""
|
||||
return fs.glob(os.path.join(path, f"*.{file_format}"))[0]
|
||||
|
||||
def embedding_inserter(samples, embeddings_url, index_width, handler=wds.handlers.reraise_exception):
|
||||
"""Given a datum of {"__key__": str, "__url__": str, ...} adds the cooresponding embedding and yields"""
|
||||
previous_tar_url = None
|
||||
current_embeddings = None
|
||||
# Get a reference to an abstract file system where the embeddings are stored
|
||||
embeddings_fs, embeddings_path = fsspec.core.url_to_fs(embeddings_url)
|
||||
example_embedding_file = get_example_file(embeddings_fs, embeddings_path, "npy")
|
||||
example_embedding_shard = get_shard(example_embedding_file)
|
||||
emb_shard_width = len(example_embedding_shard)
|
||||
# Easier to get the basename without the shard once than search through for the correct file every time
|
||||
embedding_file_basename = '_'.join(example_embedding_file.split("_")[:-1]) + "_"
|
||||
|
||||
def load_corresponding_embeds(tar_url):
|
||||
"""Finds and reads the npy files that contains embeddings for the given webdataset tar"""
|
||||
shard = int(tar_url.split("/")[-1].split(".")[0])
|
||||
embedding_url = embedding_file_basename + str(shard).zfill(emb_shard_width) + '.npy'
|
||||
with embeddings_fs.open(embedding_url) as f:
|
||||
data = np.load(f)
|
||||
return torch.from_numpy(data)
|
||||
|
||||
for sample in samples:
|
||||
try:
|
||||
tar_url = sample["__url__"]
|
||||
key = sample["__key__"]
|
||||
if tar_url != previous_tar_url:
|
||||
# If the tar changed, we need to download new embeddings
|
||||
# This means if we shuffle before inserting it will load many more files than we expect and be very inefficient.
|
||||
previous_tar_url = tar_url
|
||||
current_embeddings = load_corresponding_embeds(tar_url)
|
||||
|
||||
embedding_index = int(key[-index_width:])
|
||||
embedding = current_embeddings[embedding_index]
|
||||
# We need to check if this sample is nonzero. If it is, this embedding is not valid and we should continue to the next loop
|
||||
if torch.count_nonzero(embedding) == 0:
|
||||
raise RuntimeError(f"Webdataset had a sample, but no embedding was found. ImgShard: {key[:-index_width]} - Index: {key[-index_width:]}")
|
||||
sample["npy"] = embedding
|
||||
yield sample
|
||||
except Exception as exn: # From wds implementation
|
||||
if handler(exn):
|
||||
continue
|
||||
else:
|
||||
break
|
||||
insert_embedding = wds.filters.pipelinefilter(embedding_inserter)
|
||||
|
||||
def unassociated_shard_skipper(tarfiles, embeddings_url, handler=wds.handlers.reraise_exception):
|
||||
"""Finds if the is a corresponding embedding for the tarfile at { url: [URL] }"""
|
||||
embeddings_fs, embeddings_path = fsspec.core.url_to_fs(embeddings_url)
|
||||
embedding_files = embeddings_fs.ls(embeddings_path)
|
||||
get_embedding_shard = lambda embedding_file: int(embedding_file.split("_")[-1].split(".")[0])
|
||||
embedding_shards = set([get_embedding_shard(filename) for filename in embedding_files]) # Sets have O(1) check for member
|
||||
|
||||
get_tar_shard = lambda tar_file: int(tar_file.split("/")[-1].split(".")[0])
|
||||
for tarfile in tarfiles:
|
||||
try:
|
||||
webdataset_shard = get_tar_shard(tarfile["url"])
|
||||
# If this shard has an associated embeddings file, we pass it through. Otherwise we iterate until we do have one
|
||||
if webdataset_shard in embedding_shards:
|
||||
yield tarfile
|
||||
except Exception as exn: # From wds implementation
|
||||
if handler(exn):
|
||||
continue
|
||||
else:
|
||||
break
|
||||
|
||||
skip_unassociated_shards = wds.filters.pipelinefilter(unassociated_shard_skipper)
|
||||
|
||||
def verify_keys(samples, handler=wds.handlers.reraise_exception):
|
||||
"""
|
||||
Requires that both the image and embedding are present in the sample
|
||||
This is important to do as a user may forget they do not have embeddings in their webdataset and neglect to add them using the embedding_folder_url parameter.
|
||||
"""
|
||||
for sample in samples:
|
||||
try:
|
||||
assert "jpg" in sample, f"Sample {sample['__key__']} missing image"
|
||||
assert "npy" in sample, f"Sample {sample['__key__']} missing embedding. Did you set embedding_folder_url?"
|
||||
yield sample
|
||||
except Exception as exn: # From wds implementation
|
||||
if handler(exn):
|
||||
continue
|
||||
else:
|
||||
break
|
||||
|
||||
class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
|
||||
"""
|
||||
A fluid interface wrapper for DataPipline that returns image embedding pairs
|
||||
Reads embeddings as npy files from the webdataset if they exist. If embedding_folder_url is set, they will be inserted in from the alternate source.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
urls,
|
||||
embedding_folder_url=None,
|
||||
index_width=None,
|
||||
img_preproc=None,
|
||||
extra_keys=[],
|
||||
handler=wds.handlers.reraise_exception,
|
||||
resample=False,
|
||||
shuffle_shards=True
|
||||
):
|
||||
"""
|
||||
Modeled directly off of the WebDataset constructor
|
||||
|
||||
:param urls: A url pointing to the tar files of the webdataset formatted as /path/to/webdataset/{0000..9999}.tar
|
||||
:param embedding_folder_url: Required if webdataset does not contain embeddings. A url pointing to the npy files of the embeddings. Should have the same number of shards as the webdataset.
|
||||
Webdataset image keys should align with the index of the embedding. This means missing image indices must have a corresponding embedding of all zeros.
|
||||
:param index_width: The number of digits in the index. This is used to align the embedding index with the image index.
|
||||
For example, if a file in the webdataset shard 3 is named 0003039.jpg, we know the shard is 4 digits and the last 3 digits are the index_width.
|
||||
:param img_preproc: This function is run on the img before it is batched and returned. Useful for data augmentation or converting to torch tensor.
|
||||
:param handler: A webdataset handler.
|
||||
:param resample: If true, resample webdataset shards with replacement. You need to set your own epoch size if this is true since it will resample infinitely.
|
||||
:param shuffle_shards: If true, shuffle the shards before resampling. This cannot be true if resample is true.
|
||||
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
keys = ["jpg", "npy"] + extra_keys
|
||||
self.key_map = {key: i for i, key in enumerate(keys)}
|
||||
self.resampling = resample
|
||||
self.img_preproc = img_preproc
|
||||
# If s3, check if s3fs is installed and s3cmd is installed and check if the data is piped instead of straight up
|
||||
if (isinstance(urls, str) and "s3:" in urls) or (isinstance(urls, list) and any(["s3:" in url for url in urls])):
|
||||
# Then this has an s3 link for the webdataset and we need extra packages
|
||||
if shutil.which("s3cmd") is None:
|
||||
raise RuntimeError("s3cmd is required for s3 webdataset")
|
||||
if "s3:" in embedding_folder_url:
|
||||
# Then the embeddings are being loaded from s3 and fsspec requires s3fs
|
||||
try:
|
||||
import s3fs
|
||||
except ImportError:
|
||||
raise RuntimeError("s3fs is required to load embeddings from s3")
|
||||
# Add the shardList and randomize or resample if requested
|
||||
if resample:
|
||||
assert not shuffle_shards, "Cannot both resample and shuffle"
|
||||
self.append(wds.ResampledShards(urls))
|
||||
else:
|
||||
self.append(wds.SimpleShardList(urls))
|
||||
if shuffle_shards:
|
||||
self.append(wds.filters.shuffle(1000))
|
||||
|
||||
if embedding_folder_url is not None:
|
||||
# There may be webdataset shards that do not have a embedding shard associated with it. If we do not skip these, they would cause issues.
|
||||
self.append(skip_unassociated_shards(embeddings_url=embedding_folder_url, handler=handler))
|
||||
|
||||
self.append(wds.split_by_node)
|
||||
self.append(wds.split_by_worker)
|
||||
|
||||
self.append(wds.tarfile_to_samples(handler=handler))
|
||||
self.append(wds.decode("pilrgb", handler=handler))
|
||||
if embedding_folder_url is not None:
|
||||
# Then we are loading embeddings for a remote source
|
||||
assert index_width is not None, "Reading embeddings separately requires index width length to be given"
|
||||
self.append(insert_embedding(embeddings_url=embedding_folder_url, index_width=index_width, handler=handler))
|
||||
self.append(verify_keys)
|
||||
# Apply preprocessing
|
||||
self.append(wds.map(self.preproc))
|
||||
self.append(wds.to_tuple(*keys))
|
||||
|
||||
def preproc(self, sample):
|
||||
"""Applies the preprocessing for images"""
|
||||
if self.img_preproc is not None:
|
||||
sample["jpg"] = self.img_preproc(sample["jpg"])
|
||||
return sample
|
||||
|
||||
def create_image_embedding_dataloader(
|
||||
tar_url,
|
||||
num_workers,
|
||||
batch_size,
|
||||
embeddings_url=None,
|
||||
index_width=None,
|
||||
shuffle_num = None,
|
||||
shuffle_shards = True,
|
||||
resample_shards = False,
|
||||
img_preproc=None,
|
||||
extra_keys=[],
|
||||
handler=wds.handlers.reraise_exception#warn_and_continue
|
||||
):
|
||||
"""
|
||||
Convenience function to create an image embedding dataseta and dataloader in one line
|
||||
|
||||
:param tar_url: A url pointing to the tar files of the webdataset formatted as /path/to/webdataset/{0000..9999}.tar
|
||||
:param num_workers: The number of workers to use for the dataloader
|
||||
:param batch_size: The batch size to use for the dataloader
|
||||
:param embeddings_url: Required if webdataset does not contain embeddings. A url pointing to the npy files of the embeddings. Should have the same number of shards as the webdataset.
|
||||
Webdataset image keys should align with the index of the embedding. This means missing image indices must have a corresponding embedding of all zeros.
|
||||
:param index_width: The number of digits in the index. This is used to align the embedding index with the image index.
|
||||
For example, if a file in the webdataset shard 3 is named 0003039.jpg, we know the shard is 4 digits and the last 3 digits are the index_width.
|
||||
:param shuffle_num: If not None, shuffle the dataset with this size buffer after sampling.
|
||||
:param shuffle_shards: If true, shuffle the shards before sampling. This cannot be true if resample is true.
|
||||
:param resample_shards: If true, resample webdataset shards with replacement. You need to set your own epoch size if this is true since it will resample infinitely.
|
||||
:param handler: A webdataset handler.
|
||||
"""
|
||||
ds = ImageEmbeddingDataset(
|
||||
tar_url,
|
||||
embeddings_url,
|
||||
index_width=index_width,
|
||||
shuffle_shards=shuffle_shards,
|
||||
resample=resample_shards,
|
||||
extra_keys=extra_keys,
|
||||
img_preproc=img_preproc,
|
||||
handler=handler
|
||||
)
|
||||
if shuffle_num is not None and shuffle_num > 0:
|
||||
ds.shuffle(1000)
|
||||
return wds.WebLoader(
|
||||
ds,
|
||||
num_workers=num_workers,
|
||||
batch_size=batch_size,
|
||||
prefetch_factor=2, # This might be good to have high so the next npy file is prefetched
|
||||
pin_memory=True,
|
||||
shuffle=False
|
||||
)
|
||||
180
dalle2_pytorch/dataloaders/embedding_wrapper.py
Normal file
180
dalle2_pytorch/dataloaders/embedding_wrapper.py
Normal file
@@ -0,0 +1,180 @@
|
||||
from torch.utils.data import IterableDataset
|
||||
from torch import from_numpy
|
||||
from clip import tokenize
|
||||
from embedding_reader import EmbeddingReader
|
||||
|
||||
|
||||
class PriorEmbeddingLoader(IterableDataset):
|
||||
def __init__(
|
||||
self,
|
||||
text_conditioned: bool,
|
||||
batch_size: int,
|
||||
start: int,
|
||||
stop: int,
|
||||
image_reader,
|
||||
text_reader: EmbeddingReader = None,
|
||||
device: str = "cpu",
|
||||
) -> None:
|
||||
super(PriorEmbeddingLoader).__init__()
|
||||
|
||||
self.text_conditioned = text_conditioned
|
||||
|
||||
if not self.text_conditioned:
|
||||
self.text_reader = text_reader
|
||||
|
||||
self.image_reader = image_reader
|
||||
self.batch_size = batch_size
|
||||
self.start = start
|
||||
self.stop = stop
|
||||
self.device = device
|
||||
|
||||
def __iter__(self):
|
||||
self.n = 0
|
||||
loader_args = dict(
|
||||
batch_size=self.batch_size,
|
||||
start=self.start,
|
||||
end=self.stop,
|
||||
show_progress=False,
|
||||
)
|
||||
if self.text_conditioned:
|
||||
self.loader = self.image_reader(**loader_args)
|
||||
else:
|
||||
self.loader = zip(
|
||||
self.image_reader(**loader_args), self.text_reader(**loader_args)
|
||||
)
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
try:
|
||||
return self.get_sample()
|
||||
except StopIteration:
|
||||
raise StopIteration
|
||||
|
||||
def get_sample(self):
|
||||
"""
|
||||
pre-proocess data from either reader into a common format
|
||||
"""
|
||||
self.n += 1
|
||||
|
||||
if self.text_conditioned:
|
||||
image_embedding, caption = next(self.loader)
|
||||
|
||||
image_embedding = from_numpy(image_embedding).to(self.device)
|
||||
tokenized_caption = tokenize(
|
||||
caption["caption"].to_list(), truncate=True
|
||||
).to(self.device)
|
||||
|
||||
return image_embedding, tokenized_caption
|
||||
|
||||
else:
|
||||
(image_embedding, _), (text_embedding, _) = next(self.loader)
|
||||
|
||||
image_embedding = from_numpy(image_embedding).to(self.device)
|
||||
text_embedding = from_numpy(text_embedding).to(self.device)
|
||||
|
||||
return image_embedding, text_embedding
|
||||
|
||||
|
||||
def make_splits(
|
||||
text_conditioned: bool,
|
||||
batch_size: int,
|
||||
num_data_points: int,
|
||||
train_split: float,
|
||||
eval_split: float,
|
||||
device: str,
|
||||
img_url: str,
|
||||
meta_url: str = None,
|
||||
txt_url: str = None,
|
||||
):
|
||||
|
||||
assert img_url is not None, "Must supply some image embeddings"
|
||||
|
||||
if text_conditioned:
|
||||
assert meta_url is not None, "Must supply metadata url if text-conditioning"
|
||||
image_reader = EmbeddingReader(
|
||||
embeddings_folder=img_url,
|
||||
file_format="parquet_npy",
|
||||
meta_columns=["caption"],
|
||||
metadata_folder=meta_url,
|
||||
)
|
||||
|
||||
# compute split points
|
||||
if num_data_points > image_reader.count:
|
||||
print("Specified point count is larger than the number of points available...defaulting to max length of reader.")
|
||||
num_data_points = image_reader.count
|
||||
|
||||
train_set_size = int(train_split * num_data_points)
|
||||
eval_set_size = int(eval_split * num_data_points)
|
||||
eval_stop = int(train_set_size + eval_set_size)
|
||||
|
||||
train_loader = PriorEmbeddingLoader(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
batch_size=batch_size,
|
||||
start=0,
|
||||
stop=train_set_size,
|
||||
device=device,
|
||||
)
|
||||
eval_loader = PriorEmbeddingLoader(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
batch_size=batch_size,
|
||||
start=train_set_size,
|
||||
stop=eval_stop,
|
||||
device=device,
|
||||
)
|
||||
test_loader = PriorEmbeddingLoader(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
batch_size=batch_size,
|
||||
start=eval_stop,
|
||||
stop=int(num_data_points),
|
||||
device=device,
|
||||
)
|
||||
|
||||
else:
|
||||
assert (
|
||||
txt_url is not None
|
||||
), "Must supply text embedding url if not text-conditioning"
|
||||
|
||||
image_reader = EmbeddingReader(img_url, file_format="npy")
|
||||
text_reader = EmbeddingReader(txt_url, file_format="npy")
|
||||
|
||||
# compute split points
|
||||
if num_data_points > image_reader.count:
|
||||
print("Specified point count is larger than the number of points available...defaulting to max length of reader.")
|
||||
num_data_points = image_reader.count
|
||||
|
||||
train_set_size = int(train_split * num_data_points)
|
||||
eval_set_size = int(eval_split * num_data_points)
|
||||
eval_stop = int(train_set_size + eval_set_size)
|
||||
|
||||
train_loader = PriorEmbeddingLoader(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
text_reader=text_reader,
|
||||
batch_size=batch_size,
|
||||
start=0,
|
||||
stop=train_set_size,
|
||||
device=device,
|
||||
)
|
||||
eval_loader = PriorEmbeddingLoader(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
text_reader=text_reader,
|
||||
batch_size=batch_size,
|
||||
start=train_set_size,
|
||||
stop=eval_stop,
|
||||
device=device,
|
||||
)
|
||||
test_loader = PriorEmbeddingLoader(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
text_reader=text_reader,
|
||||
batch_size=batch_size,
|
||||
start=eval_stop,
|
||||
stop=int(num_data_points),
|
||||
device=device,
|
||||
)
|
||||
|
||||
return train_loader, eval_loader, test_loader
|
||||
59
dalle2_pytorch/dataloaders/simple_image_only_dataloader.py
Normal file
59
dalle2_pytorch/dataloaders/simple_image_only_dataloader.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch.utils import data
|
||||
from torchvision import transforms, utils
|
||||
|
||||
from PIL import Image
|
||||
|
||||
# helpers functions
|
||||
|
||||
def cycle(dl):
|
||||
while True:
|
||||
for data in dl:
|
||||
yield data
|
||||
|
||||
# dataset and dataloader
|
||||
|
||||
class Dataset(data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
folder,
|
||||
image_size,
|
||||
exts = ['jpg', 'jpeg', 'png']
|
||||
):
|
||||
super().__init__()
|
||||
self.folder = folder
|
||||
self.image_size = image_size
|
||||
self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]
|
||||
|
||||
self.transform = transforms.Compose([
|
||||
transforms.Resize(image_size),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.CenterCrop(image_size),
|
||||
transforms.ToTensor()
|
||||
])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.paths)
|
||||
|
||||
def __getitem__(self, index):
|
||||
path = self.paths[index]
|
||||
img = Image.open(path)
|
||||
return self.transform(img)
|
||||
|
||||
def get_images_dataloader(
|
||||
folder,
|
||||
*,
|
||||
batch_size,
|
||||
image_size,
|
||||
shuffle = True,
|
||||
cycle_dl = True,
|
||||
pin_memory = True
|
||||
):
|
||||
ds = Dataset(folder, image_size)
|
||||
dl = data.DataLoader(ds, batch_size = batch_size, shuffle = shuffle, pin_memory = pin_memory)
|
||||
|
||||
if cycle_dl:
|
||||
dl = cycle(dl)
|
||||
return dl
|
||||
@@ -7,16 +7,17 @@ def separate_weight_decayable_params(params):
|
||||
|
||||
def get_optimizer(
|
||||
params,
|
||||
lr = 3e-4,
|
||||
lr = 1e-4,
|
||||
wd = 1e-2,
|
||||
betas = (0.9, 0.999),
|
||||
eps = 1e-8,
|
||||
filter_by_requires_grad = False
|
||||
):
|
||||
if filter_by_requires_grad:
|
||||
params = list(filter(lambda t: t.requires_grad, params))
|
||||
|
||||
if wd == 0:
|
||||
return Adam(params, lr = lr, betas = betas)
|
||||
return Adam(params, lr = lr, betas = betas, eps = eps)
|
||||
|
||||
params = set(params)
|
||||
wd_params, no_wd_params = separate_weight_decayable_params(params)
|
||||
@@ -26,4 +27,4 @@ def get_optimizer(
|
||||
{'params': list(no_wd_params), 'weight_decay': 0},
|
||||
]
|
||||
|
||||
return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas)
|
||||
return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas, eps = eps)
|
||||
|
||||
117
dalle2_pytorch/trackers.py
Normal file
117
dalle2_pytorch/trackers.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from enum import Enum
|
||||
import importlib
|
||||
from itertools import zip_longest
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
# constants
|
||||
|
||||
DEFAULT_DATA_PATH = './.tracker-data'
|
||||
|
||||
# helper functions
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def import_or_print_error(pkg_name, err_str = None):
|
||||
try:
|
||||
return importlib.import_module(pkg_name)
|
||||
except ModuleNotFoundError as e:
|
||||
if exists(err_str):
|
||||
print(err_str)
|
||||
exit()
|
||||
|
||||
# load state dict functions
|
||||
|
||||
def load_wandb_state_dict(run_path, file_path, **kwargs):
|
||||
wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function')
|
||||
file_reference = wandb.restore(file_path, run_path=run_path)
|
||||
return torch.load(file_reference.name)
|
||||
|
||||
def load_local_state_dict(file_path, **kwargs):
|
||||
return torch.load(file_path)
|
||||
|
||||
# base class
|
||||
|
||||
class BaseTracker(nn.Module):
|
||||
def __init__(self, data_path = DEFAULT_DATA_PATH):
|
||||
super().__init__()
|
||||
assert data_path is not None, "Tracker must have a data_path to save local content"
|
||||
self.data_path = Path(data_path)
|
||||
self.data_path.mkdir(parents = True, exist_ok = True)
|
||||
|
||||
def init(self, config, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def log(self, log, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def log_images(self, images, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def save_state_dict(self, state_dict, relative_path, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def recall_state_dict(self, recall_source, *args, **kwargs):
|
||||
"""
|
||||
Loads a state dict from any source.
|
||||
Since a user may wish to load a model from a different source than their own tracker (i.e. tracking using wandb but recalling from disk),
|
||||
this should not be linked to any individual tracker.
|
||||
"""
|
||||
# TODO: Pull this into a dict or something similar so that we can add more sources without having a massive switch statement
|
||||
if recall_source == 'wandb':
|
||||
return load_wandb_state_dict(*args, **kwargs)
|
||||
elif recall_source == 'local':
|
||||
return load_local_state_dict(*args, **kwargs)
|
||||
else:
|
||||
raise ValueError('`recall_source` must be one of `wandb` or `local`')
|
||||
|
||||
|
||||
# basic stdout class
|
||||
|
||||
class ConsoleTracker(BaseTracker):
|
||||
def init(self, **config):
|
||||
print(config)
|
||||
|
||||
def log(self, log, **kwargs):
|
||||
print(log)
|
||||
|
||||
def log_images(self, images, **kwargs): # noop for logging images
|
||||
pass
|
||||
|
||||
def save_state_dict(self, state_dict, relative_path, **kwargs):
|
||||
torch.save(state_dict, str(self.data_path / relative_path))
|
||||
|
||||
# basic wandb class
|
||||
|
||||
class WandbTracker(BaseTracker):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb experiment tracker')
|
||||
os.environ["WANDB_SILENT"] = "true"
|
||||
|
||||
def init(self, **config):
|
||||
self.wandb.init(**config)
|
||||
|
||||
def log(self, log, verbose=False, **kwargs):
|
||||
if verbose:
|
||||
print(log)
|
||||
self.wandb.log(log, **kwargs)
|
||||
|
||||
def log_images(self, images, captions=[], image_section="images", **kwargs):
|
||||
"""
|
||||
Takes a tensor of images and a list of captions and logs them to wandb.
|
||||
"""
|
||||
wandb_images = [self.wandb.Image(image, caption=caption) for image, caption in zip_longest(images, captions)]
|
||||
self.wandb.log({ image_section: wandb_images }, **kwargs)
|
||||
|
||||
def save_state_dict(self, state_dict, relative_path, **kwargs):
|
||||
"""
|
||||
Saves a state_dict to disk and uploads it
|
||||
"""
|
||||
full_path = str(self.data_path / relative_path)
|
||||
torch.save(state_dict, full_path)
|
||||
self.wandb.save(full_path, base_path = str(self.data_path)) # Upload and keep relative to data_path
|
||||
@@ -1,53 +0,0 @@
|
||||
import copy
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
# exponential moving average wrapper
|
||||
|
||||
class EMA(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
beta = 0.99,
|
||||
ema_update_after_step = 1000,
|
||||
ema_update_every = 10,
|
||||
):
|
||||
super().__init__()
|
||||
self.beta = beta
|
||||
self.online_model = model
|
||||
self.ema_model = copy.deepcopy(model)
|
||||
|
||||
self.ema_update_after_step = ema_update_after_step # only start EMA after this step number, starting at 0
|
||||
self.ema_update_every = ema_update_every
|
||||
|
||||
self.register_buffer('initted', torch.Tensor([False]))
|
||||
self.register_buffer('step', torch.tensor([0.]))
|
||||
|
||||
def update(self):
|
||||
self.step += 1
|
||||
|
||||
if self.step <= self.ema_update_after_step or (self.step % self.ema_update_every) != 0:
|
||||
return
|
||||
|
||||
if not self.initted:
|
||||
self.ema_model.state_dict(self.online_model.state_dict())
|
||||
self.initted.data.copy_(torch.Tensor([True]))
|
||||
|
||||
self.update_moving_average(self.ema_model, self.online_model)
|
||||
|
||||
def update_moving_average(self, ma_model, current_model):
|
||||
def calculate_ema(beta, old, new):
|
||||
if not exists(old):
|
||||
return new
|
||||
return old * beta + (1 - beta) * new
|
||||
|
||||
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
|
||||
old_weight, up_weight = ma_params.data, current_params.data
|
||||
ma_params.data = calculate_ema(self.beta, old_weight, up_weight)
|
||||
|
||||
for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
|
||||
new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer)
|
||||
ma_buffer.copy_(new_buffer_value)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.ema_model(*args, **kwargs)
|
||||
491
dalle2_pytorch/trainer.py
Normal file
491
dalle2_pytorch/trainer.py
Normal file
@@ -0,0 +1,491 @@
|
||||
import time
|
||||
import copy
|
||||
from math import ceil
|
||||
from functools import partial, wraps
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
|
||||
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
|
||||
from dalle2_pytorch.optimizer import get_optimizer
|
||||
|
||||
import numpy as np
|
||||
|
||||
# helper functions
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
def cast_tuple(val, length = 1):
|
||||
return val if isinstance(val, tuple) else ((val,) * length)
|
||||
|
||||
def pick_and_pop(keys, d):
|
||||
values = list(map(lambda key: d.pop(key), keys))
|
||||
return dict(zip(keys, values))
|
||||
|
||||
def group_dict_by_key(cond, d):
|
||||
return_val = [dict(),dict()]
|
||||
for key in d.keys():
|
||||
match = bool(cond(key))
|
||||
ind = int(not match)
|
||||
return_val[ind][key] = d[key]
|
||||
return (*return_val,)
|
||||
|
||||
def string_begins_with(prefix, str):
|
||||
return str.startswith(prefix)
|
||||
|
||||
def group_by_key_prefix(prefix, d):
|
||||
return group_dict_by_key(partial(string_begins_with, prefix), d)
|
||||
|
||||
def groupby_prefix_and_trim(prefix, d):
|
||||
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
|
||||
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
|
||||
return kwargs_without_prefix, kwargs
|
||||
|
||||
def num_to_groups(num, divisor):
|
||||
groups = num // divisor
|
||||
remainder = num % divisor
|
||||
arr = [divisor] * groups
|
||||
if remainder > 0:
|
||||
arr.append(remainder)
|
||||
return arr
|
||||
|
||||
# decorators
|
||||
|
||||
def cast_torch_tensor(fn):
|
||||
@wraps(fn)
|
||||
def inner(model, *args, **kwargs):
|
||||
device = kwargs.pop('_device', next(model.parameters()).device)
|
||||
cast_device = kwargs.pop('_cast_device', True)
|
||||
|
||||
kwargs_keys = kwargs.keys()
|
||||
all_args = (*args, *kwargs.values())
|
||||
split_kwargs_index = len(all_args) - len(kwargs_keys)
|
||||
all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args))
|
||||
|
||||
if cast_device:
|
||||
all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
|
||||
|
||||
args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
|
||||
kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))
|
||||
|
||||
out = fn(model, *args, **kwargs)
|
||||
return out
|
||||
return inner
|
||||
|
||||
# gradient accumulation functions
|
||||
|
||||
def split_iterable(it, split_size):
|
||||
accum = []
|
||||
for ind in range(ceil(len(it) / split_size)):
|
||||
start_index = ind * split_size
|
||||
accum.append(it[start_index: (start_index + split_size)])
|
||||
return accum
|
||||
|
||||
def split(t, split_size = None):
|
||||
if not exists(split_size):
|
||||
return t
|
||||
|
||||
if isinstance(t, torch.Tensor):
|
||||
return t.split(split_size, dim = 0)
|
||||
|
||||
if isinstance(t, Iterable):
|
||||
return split_iterable(t, split_size)
|
||||
|
||||
return TypeError
|
||||
|
||||
def find_first(cond, arr):
|
||||
for el in arr:
|
||||
if cond(el):
|
||||
return el
|
||||
return None
|
||||
|
||||
def split_args_and_kwargs(*args, split_size = None, **kwargs):
|
||||
all_args = (*args, *kwargs.values())
|
||||
len_all_args = len(all_args)
|
||||
first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args)
|
||||
assert exists(first_tensor)
|
||||
|
||||
batch_size = len(first_tensor)
|
||||
split_size = default(split_size, batch_size)
|
||||
num_chunks = ceil(batch_size / split_size)
|
||||
|
||||
dict_len = len(kwargs)
|
||||
dict_keys = kwargs.keys()
|
||||
split_kwargs_index = len_all_args - dict_len
|
||||
|
||||
split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args]
|
||||
chunk_sizes = tuple(map(len, split_all_args[0]))
|
||||
|
||||
for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
|
||||
chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:]
|
||||
chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
|
||||
chunk_size_frac = chunk_size / batch_size
|
||||
yield chunk_size_frac, (chunked_args, chunked_kwargs)
|
||||
|
||||
# print helpers
|
||||
|
||||
def print_ribbon(s, symbol = '=', repeat = 40):
|
||||
flank = symbol * repeat
|
||||
return f'{flank} {s} {flank}'
|
||||
|
||||
# saving and loading functions
|
||||
|
||||
# for diffusion prior
|
||||
|
||||
def load_diffusion_model(dprior_path, device):
|
||||
dprior_path = Path(dprior_path)
|
||||
assert dprior_path.exists(), 'Dprior model file does not exist'
|
||||
loaded_obj = torch.load(str(dprior_path), map_location='cpu')
|
||||
|
||||
# Get hyperparameters of loaded model
|
||||
dpn_config = loaded_obj['hparams']['diffusion_prior_network']
|
||||
dp_config = loaded_obj['hparams']['diffusion_prior']
|
||||
image_embed_dim = loaded_obj['image_embed_dim']['image_embed_dim']
|
||||
|
||||
# Create DiffusionPriorNetwork and DiffusionPrior with loaded hyperparameters
|
||||
|
||||
# DiffusionPriorNetwork
|
||||
prior_network = DiffusionPriorNetwork( dim = image_embed_dim, **dpn_config).to(device)
|
||||
|
||||
# DiffusionPrior with text embeddings and image embeddings pre-computed
|
||||
diffusion_prior = DiffusionPrior(net = prior_network, **dp_config, image_embed_dim = image_embed_dim).to(device)
|
||||
|
||||
# Load state dict from saved model
|
||||
diffusion_prior.load_state_dict(loaded_obj['model'])
|
||||
|
||||
return diffusion_prior, loaded_obj
|
||||
|
||||
def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim):
|
||||
# Saving State Dict
|
||||
print_ribbon('Saving checkpoint')
|
||||
|
||||
state_dict = dict(model=model.state_dict(),
|
||||
optimizer=optimizer.state_dict(),
|
||||
scaler=scaler.state_dict(),
|
||||
hparams = config,
|
||||
image_embed_dim = {"image_embed_dim":image_embed_dim})
|
||||
torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
|
||||
|
||||
# exponential moving average wrapper
|
||||
|
||||
class EMA(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
beta = 0.9999,
|
||||
update_after_step = 1000,
|
||||
update_every = 10,
|
||||
):
|
||||
super().__init__()
|
||||
self.beta = beta
|
||||
self.online_model = model
|
||||
self.ema_model = copy.deepcopy(model)
|
||||
|
||||
self.update_every = update_every
|
||||
self.update_after_step = update_after_step // update_every # only start EMA after this step number, starting at 0
|
||||
|
||||
self.register_buffer('initted', torch.Tensor([False]))
|
||||
self.register_buffer('step', torch.tensor([0.]))
|
||||
|
||||
def restore_ema_model_device(self):
|
||||
device = self.initted.device
|
||||
self.ema_model.to(device)
|
||||
|
||||
def copy_params_from_model_to_ema(self):
|
||||
self.ema_model.state_dict(self.online_model.state_dict())
|
||||
|
||||
def update(self):
|
||||
self.step += 1
|
||||
|
||||
if (self.step % self.update_every) != 0:
|
||||
return
|
||||
|
||||
if self.step <= self.update_after_step:
|
||||
self.copy_params_from_model_to_ema()
|
||||
return
|
||||
|
||||
if not self.initted:
|
||||
self.copy_params_from_model_to_ema()
|
||||
self.initted.data.copy_(torch.Tensor([True]))
|
||||
|
||||
self.update_moving_average(self.ema_model, self.online_model)
|
||||
|
||||
def update_moving_average(self, ma_model, current_model):
|
||||
def calculate_ema(beta, old, new):
|
||||
if not exists(old):
|
||||
return new
|
||||
return old * beta + (1 - beta) * new
|
||||
|
||||
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
|
||||
old_weight, up_weight = ma_params.data, current_params.data
|
||||
ma_params.data = calculate_ema(self.beta, old_weight, up_weight)
|
||||
|
||||
for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
|
||||
new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer)
|
||||
ma_buffer.copy_(new_buffer_value)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.ema_model(*args, **kwargs)
|
||||
|
||||
# diffusion prior trainer
|
||||
|
||||
def prior_sample_in_chunks(fn):
|
||||
@wraps(fn)
|
||||
def inner(self, *args, max_batch_size = None, **kwargs):
|
||||
if not exists(max_batch_size):
|
||||
return fn(self, *args, **kwargs)
|
||||
|
||||
outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]
|
||||
return torch.cat(outputs, dim = 0)
|
||||
return inner
|
||||
|
||||
class DiffusionPriorTrainer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
diffusion_prior,
|
||||
use_ema = True,
|
||||
lr = 3e-4,
|
||||
wd = 1e-2,
|
||||
eps = 1e-6,
|
||||
max_grad_norm = None,
|
||||
amp = False,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(diffusion_prior, DiffusionPrior)
|
||||
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
|
||||
|
||||
self.diffusion_prior = diffusion_prior
|
||||
|
||||
# exponential moving average
|
||||
|
||||
self.use_ema = use_ema
|
||||
if self.use_ema:
|
||||
self.ema_diffusion_prior = EMA(diffusion_prior, **ema_kwargs)
|
||||
|
||||
# optimizer and mixed precision stuff
|
||||
|
||||
self.amp = amp
|
||||
|
||||
self.scaler = GradScaler(enabled = amp)
|
||||
|
||||
self.optimizer = get_optimizer(
|
||||
diffusion_prior.parameters(),
|
||||
lr = lr,
|
||||
wd = wd,
|
||||
eps = eps,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# gradient clipping if needed
|
||||
|
||||
self.max_grad_norm = max_grad_norm
|
||||
|
||||
self.register_buffer('step', torch.tensor([0.]))
|
||||
|
||||
def update(self):
|
||||
if exists(self.max_grad_norm):
|
||||
self.scaler.unscale_(self.optimizer)
|
||||
nn.utils.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
|
||||
|
||||
self.scaler.step(self.optimizer)
|
||||
self.scaler.update()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
if self.use_ema:
|
||||
self.ema_diffusion_prior.update()
|
||||
|
||||
self.step += 1
|
||||
|
||||
@torch.no_grad()
|
||||
@cast_torch_tensor
|
||||
@prior_sample_in_chunks
|
||||
def p_sample_loop(self, *args, **kwargs):
|
||||
return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
@cast_torch_tensor
|
||||
@prior_sample_in_chunks
|
||||
def sample(self, *args, **kwargs):
|
||||
return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_batch_size(self, *args, **kwargs):
|
||||
return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
|
||||
|
||||
@cast_torch_tensor
|
||||
def forward(
|
||||
self,
|
||||
*args,
|
||||
max_batch_size = None,
|
||||
**kwargs
|
||||
):
|
||||
total_loss = 0.
|
||||
|
||||
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
|
||||
with autocast(enabled = self.amp):
|
||||
loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
|
||||
loss = loss * chunk_size_frac
|
||||
|
||||
total_loss += loss.item()
|
||||
|
||||
if self.training:
|
||||
self.scaler.scale(loss).backward()
|
||||
|
||||
return total_loss
|
||||
|
||||
# decoder trainer
|
||||
|
||||
def decoder_sample_in_chunks(fn):
|
||||
@wraps(fn)
|
||||
def inner(self, *args, max_batch_size = None, **kwargs):
|
||||
if not exists(max_batch_size):
|
||||
return fn(self, *args, **kwargs)
|
||||
|
||||
if self.decoder.unconditional:
|
||||
batch_size = kwargs.get('batch_size')
|
||||
batch_sizes = num_to_groups(batch_size, max_batch_size)
|
||||
outputs = [fn(self, *args, **{**kwargs, 'batch_size': sub_batch_size}) for sub_batch_size in batch_sizes]
|
||||
else:
|
||||
outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]
|
||||
|
||||
return torch.cat(outputs, dim = 0)
|
||||
return inner
|
||||
|
||||
class DecoderTrainer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
decoder,
|
||||
use_ema = True,
|
||||
lr = 1e-4,
|
||||
wd = 1e-2,
|
||||
eps = 1e-8,
|
||||
max_grad_norm = 0.5,
|
||||
amp = False,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(decoder, Decoder)
|
||||
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
|
||||
|
||||
self.decoder = decoder
|
||||
self.num_unets = len(self.decoder.unets)
|
||||
|
||||
self.use_ema = use_ema
|
||||
self.ema_unets = nn.ModuleList([])
|
||||
|
||||
self.amp = amp
|
||||
|
||||
# be able to finely customize learning rate, weight decay
|
||||
# per unet
|
||||
|
||||
lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))
|
||||
|
||||
for ind, (unet, unet_lr, unet_wd, unet_eps) in enumerate(zip(self.decoder.unets, lr, wd, eps)):
|
||||
optimizer = get_optimizer(
|
||||
unet.parameters(),
|
||||
lr = unet_lr,
|
||||
wd = unet_wd,
|
||||
eps = unet_eps,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
setattr(self, f'optim{ind}', optimizer) # cannot use pytorch ModuleList for some reason with optimizers
|
||||
|
||||
if self.use_ema:
|
||||
self.ema_unets.append(EMA(unet, **ema_kwargs))
|
||||
|
||||
scaler = GradScaler(enabled = amp)
|
||||
setattr(self, f'scaler{ind}', scaler)
|
||||
|
||||
# gradient clipping if needed
|
||||
|
||||
self.max_grad_norm = max_grad_norm
|
||||
|
||||
self.register_buffer('step', torch.tensor([0.]))
|
||||
|
||||
@property
|
||||
def unets(self):
|
||||
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
|
||||
|
||||
def scale(self, loss, *, unet_number):
|
||||
assert 1 <= unet_number <= self.num_unets
|
||||
index = unet_number - 1
|
||||
scaler = getattr(self, f'scaler{index}')
|
||||
return scaler.scale(loss)
|
||||
|
||||
def update(self, unet_number = None):
|
||||
if self.num_unets == 1:
|
||||
unet_number = default(unet_number, 1)
|
||||
|
||||
assert exists(unet_number) and 1 <= unet_number <= self.num_unets
|
||||
index = unet_number - 1
|
||||
unet = self.decoder.unets[index]
|
||||
|
||||
optimizer = getattr(self, f'optim{index}')
|
||||
scaler = getattr(self, f'scaler{index}')
|
||||
|
||||
if exists(self.max_grad_norm):
|
||||
scaler.unscale_(optimizer)
|
||||
nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
|
||||
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if self.use_ema:
|
||||
ema_unet = self.ema_unets[index]
|
||||
ema_unet.update()
|
||||
|
||||
self.step += 1
|
||||
|
||||
@torch.no_grad()
|
||||
@cast_torch_tensor
|
||||
@decoder_sample_in_chunks
|
||||
def sample(self, *args, **kwargs):
|
||||
if kwargs.pop('use_non_ema', False) or not self.use_ema:
|
||||
return self.decoder.sample(*args, **kwargs)
|
||||
|
||||
trainable_unets = self.decoder.unets
|
||||
self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
|
||||
|
||||
output = self.decoder.sample(*args, **kwargs)
|
||||
|
||||
self.decoder.unets = trainable_unets # restore original training unets
|
||||
|
||||
# cast the ema_model unets back to original device
|
||||
for ema in self.ema_unets:
|
||||
ema.restore_ema_model_device()
|
||||
|
||||
return output
|
||||
|
||||
@cast_torch_tensor
|
||||
def forward(
|
||||
self,
|
||||
*args,
|
||||
unet_number = None,
|
||||
max_batch_size = None,
|
||||
**kwargs
|
||||
):
|
||||
if self.num_unets == 1:
|
||||
unet_number = default(unet_number, 1)
|
||||
|
||||
total_loss = 0.
|
||||
|
||||
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
|
||||
with autocast(enabled = self.amp):
|
||||
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
|
||||
loss = loss * chunk_size_frac
|
||||
|
||||
total_loss += loss.item()
|
||||
|
||||
if self.training:
|
||||
self.scale(loss, unet_number = unet_number).backward()
|
||||
|
||||
return total_loss
|
||||
@@ -285,6 +285,10 @@ class ResnetEncDec(nn.Module):
|
||||
def get_encoded_fmap_size(self, image_size):
|
||||
return image_size // (2 ** self.layers)
|
||||
|
||||
@property
|
||||
def last_dec_layer(self):
|
||||
return self.decoders[-1].weight
|
||||
|
||||
def encode(self, x):
|
||||
for enc in self.encoders:
|
||||
x = enc(x)
|
||||
@@ -504,6 +508,10 @@ class ViTEncDec(nn.Module):
|
||||
def get_encoded_fmap_size(self, image_size):
|
||||
return image_size // self.patch_size
|
||||
|
||||
@property
|
||||
def last_dec_layer(self):
|
||||
return self.decoder[-3][-1].weight
|
||||
|
||||
def encode(self, x):
|
||||
return self.encoder(x)
|
||||
|
||||
@@ -739,7 +747,7 @@ class VQGanVAE(nn.Module):
|
||||
|
||||
# calculate adaptive weight
|
||||
|
||||
last_dec_layer = self.decoders[-1].weight
|
||||
last_dec_layer = self.enc_dec.last_dec_layer
|
||||
|
||||
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2)
|
||||
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)
|
||||
|
||||
277
dalle2_pytorch/vqgan_vae_trainer.py
Normal file
277
dalle2_pytorch/vqgan_vae_trainer.py
Normal file
@@ -0,0 +1,277 @@
|
||||
from math import sqrt
|
||||
import copy
|
||||
from random import choice
|
||||
from pathlib import Path
|
||||
from shutil import rmtree
|
||||
from PIL import Image
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
from torch.utils.data import Dataset, DataLoader, random_split
|
||||
|
||||
import torchvision.transforms as T
|
||||
from torchvision.datasets import ImageFolder
|
||||
from torchvision.utils import make_grid, save_image
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from dalle2_pytorch.train import EMA
|
||||
from dalle2_pytorch.vqgan_vae import VQGanVAE
|
||||
from dalle2_pytorch.optimizer import get_optimizer
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def noop(*args, **kwargs):
|
||||
pass
|
||||
|
||||
def cycle(dl):
|
||||
while True:
|
||||
for data in dl:
|
||||
yield data
|
||||
|
||||
def cast_tuple(t):
|
||||
return t if isinstance(t, (tuple, list)) else (t,)
|
||||
|
||||
def yes_or_no(question):
|
||||
answer = input(f'{question} (y/n) ')
|
||||
return answer.lower() in ('yes', 'y')
|
||||
|
||||
def accum_log(log, new_logs):
|
||||
for key, new_value in new_logs.items():
|
||||
old_value = log.get(key, 0.)
|
||||
log[key] = old_value + new_value
|
||||
return log
|
||||
|
||||
# classes
|
||||
|
||||
class ImageDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
folder,
|
||||
image_size,
|
||||
exts = ['jpg', 'jpeg', 'png']
|
||||
):
|
||||
super().__init__()
|
||||
self.folder = folder
|
||||
self.image_size = image_size
|
||||
self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]
|
||||
|
||||
print(f'{len(self.paths)} training samples found at {folder}')
|
||||
|
||||
self.transform = T.Compose([
|
||||
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
|
||||
T.Resize(image_size),
|
||||
T.RandomHorizontalFlip(),
|
||||
T.CenterCrop(image_size),
|
||||
T.ToTensor()
|
||||
])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.paths)
|
||||
|
||||
def __getitem__(self, index):
|
||||
path = self.paths[index]
|
||||
img = Image.open(path)
|
||||
return self.transform(img)
|
||||
|
||||
# main trainer class
|
||||
|
||||
class VQGanVAETrainer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
vae,
|
||||
*,
|
||||
num_train_steps,
|
||||
lr,
|
||||
batch_size,
|
||||
folder,
|
||||
grad_accum_every,
|
||||
wd = 0.,
|
||||
save_results_every = 100,
|
||||
save_model_every = 1000,
|
||||
results_folder = './results',
|
||||
valid_frac = 0.05,
|
||||
random_split_seed = 42,
|
||||
ema_beta = 0.995,
|
||||
ema_update_after_step = 2000,
|
||||
ema_update_every = 10,
|
||||
apply_grad_penalty_every = 4,
|
||||
amp = False
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(vae, VQGanVAE), 'vae must be instance of VQGanVAE'
|
||||
image_size = vae.image_size
|
||||
|
||||
self.vae = vae
|
||||
self.ema_vae = EMA(vae, update_after_step = ema_update_after_step, update_every = ema_update_every)
|
||||
|
||||
self.register_buffer('steps', torch.Tensor([0]))
|
||||
|
||||
self.num_train_steps = num_train_steps
|
||||
self.batch_size = batch_size
|
||||
self.grad_accum_every = grad_accum_every
|
||||
|
||||
all_parameters = set(vae.parameters())
|
||||
discr_parameters = set(vae.discr.parameters())
|
||||
vae_parameters = all_parameters - discr_parameters
|
||||
|
||||
self.optim = get_optimizer(vae_parameters, lr = lr, wd = wd)
|
||||
self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd)
|
||||
|
||||
self.amp = amp
|
||||
self.scaler = GradScaler(enabled = amp)
|
||||
self.discr_scaler = GradScaler(enabled = amp)
|
||||
|
||||
# create dataset
|
||||
|
||||
self.ds = ImageDataset(folder, image_size = image_size)
|
||||
|
||||
# split for validation
|
||||
|
||||
if valid_frac > 0:
|
||||
train_size = int((1 - valid_frac) * len(self.ds))
|
||||
valid_size = len(self.ds) - train_size
|
||||
self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
|
||||
print(f'training with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples')
|
||||
else:
|
||||
self.valid_ds = self.ds
|
||||
print(f'training with shared training and valid dataset of {len(self.ds)} samples')
|
||||
|
||||
# dataloader
|
||||
|
||||
self.dl = cycle(DataLoader(
|
||||
self.ds,
|
||||
batch_size = batch_size,
|
||||
shuffle = True
|
||||
))
|
||||
|
||||
self.valid_dl = cycle(DataLoader(
|
||||
self.valid_ds,
|
||||
batch_size = batch_size,
|
||||
shuffle = True
|
||||
))
|
||||
|
||||
self.save_model_every = save_model_every
|
||||
self.save_results_every = save_results_every
|
||||
|
||||
self.apply_grad_penalty_every = apply_grad_penalty_every
|
||||
|
||||
self.results_folder = Path(results_folder)
|
||||
|
||||
if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'):
|
||||
rmtree(str(self.results_folder))
|
||||
|
||||
self.results_folder.mkdir(parents = True, exist_ok = True)
|
||||
|
||||
def train_step(self):
|
||||
device = next(self.vae.parameters()).device
|
||||
steps = int(self.steps.item())
|
||||
apply_grad_penalty = not (steps % self.apply_grad_penalty_every)
|
||||
|
||||
self.vae.train()
|
||||
|
||||
# logs
|
||||
|
||||
logs = {}
|
||||
|
||||
# update vae (generator)
|
||||
|
||||
for _ in range(self.grad_accum_every):
|
||||
img = next(self.dl)
|
||||
img = img.to(device)
|
||||
|
||||
with autocast(enabled = self.amp):
|
||||
loss = self.vae(
|
||||
img,
|
||||
return_loss = True,
|
||||
apply_grad_penalty = apply_grad_penalty
|
||||
)
|
||||
|
||||
|
||||
self.scaler.scale(loss / self.grad_accum_every).backward()
|
||||
|
||||
accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
|
||||
|
||||
self.scaler.step(self.optim)
|
||||
self.scaler.update()
|
||||
self.optim.zero_grad()
|
||||
|
||||
# update discriminator
|
||||
|
||||
if exists(self.vae.discr):
|
||||
discr_loss = 0
|
||||
for _ in range(self.grad_accum_every):
|
||||
img = next(self.dl)
|
||||
img = img.to(device)
|
||||
|
||||
with autocast(enabled = self.amp):
|
||||
loss = self.vae(img, return_discr_loss = True)
|
||||
|
||||
self.discr_scaler.scale(loss / self.grad_accum_every).backward()
|
||||
|
||||
accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every})
|
||||
|
||||
self.discr_scaler.step(self.discr_optim)
|
||||
self.discr_scaler.update()
|
||||
self.discr_optim.zero_grad()
|
||||
|
||||
# log
|
||||
|
||||
print(f"{steps}: vae loss: {logs['loss']} - discr loss: {logs['discr_loss']}")
|
||||
|
||||
# update exponential moving averaged generator
|
||||
|
||||
self.ema_vae.update()
|
||||
|
||||
# sample results every so often
|
||||
|
||||
if not (steps % self.save_results_every):
|
||||
for model, filename in ((self.ema_vae.ema_model, f'{steps}.ema'), (self.vae, str(steps))):
|
||||
model.eval()
|
||||
|
||||
imgs = next(self.dl)
|
||||
imgs = imgs.to(device)
|
||||
|
||||
recons = model(imgs)
|
||||
nrows = int(sqrt(self.batch_size))
|
||||
|
||||
imgs_and_recons = torch.stack((imgs, recons), dim = 0)
|
||||
imgs_and_recons = rearrange(imgs_and_recons, 'r b ... -> (b r) ...')
|
||||
|
||||
imgs_and_recons = imgs_and_recons.detach().cpu().float().clamp(0., 1.)
|
||||
grid = make_grid(imgs_and_recons, nrow = 2, normalize = True, value_range = (0, 1))
|
||||
|
||||
logs['reconstructions'] = grid
|
||||
|
||||
save_image(grid, str(self.results_folder / f'{filename}.png'))
|
||||
|
||||
print(f'{steps}: saving to {str(self.results_folder)}')
|
||||
|
||||
# save model every so often
|
||||
|
||||
if not (steps % self.save_model_every):
|
||||
state_dict = self.vae.state_dict()
|
||||
model_path = str(self.results_folder / f'vae.{steps}.pt')
|
||||
torch.save(state_dict, model_path)
|
||||
|
||||
ema_state_dict = self.ema_vae.state_dict()
|
||||
model_path = str(self.results_folder / f'vae.{steps}.ema.pt')
|
||||
torch.save(ema_state_dict, model_path)
|
||||
|
||||
print(f'{steps}: saving model to {str(self.results_folder)}')
|
||||
|
||||
self.steps += 1
|
||||
return logs
|
||||
|
||||
def train(self, log_fn = noop):
|
||||
device = next(self.vae.parameters()).device
|
||||
|
||||
while self.steps < self.num_train_steps:
|
||||
logs = self.train_step()
|
||||
log_fn(logs)
|
||||
|
||||
print('training complete')
|
||||
BIN
samples/oxford.png
Normal file
BIN
samples/oxford.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 985 KiB |
14
setup.py
14
setup.py
@@ -10,11 +10,12 @@ setup(
|
||||
'dream = dalle2_pytorch.cli:dream'
|
||||
],
|
||||
},
|
||||
version = '0.0.75',
|
||||
version = '0.3.4',
|
||||
license='MIT',
|
||||
description = 'DALL-E 2',
|
||||
author = 'Phil Wang',
|
||||
author_email = 'lucidrains@gmail.com',
|
||||
long_description_content_type = 'text/markdown',
|
||||
url = 'https://github.com/lucidrains/dalle2-pytorch',
|
||||
keywords = [
|
||||
'artificial intelligence',
|
||||
@@ -24,16 +25,23 @@ setup(
|
||||
install_requires=[
|
||||
'click',
|
||||
'clip-anytorch',
|
||||
'coca-pytorch>=0.0.5',
|
||||
'einops>=0.4',
|
||||
'einops-exts>=0.0.3',
|
||||
'embedding-reader',
|
||||
'kornia>=0.5.4',
|
||||
'numpy',
|
||||
'pillow',
|
||||
'resize-right>=0.0.2',
|
||||
'rotary-embedding-torch',
|
||||
'torch>=1.10',
|
||||
'torchvision',
|
||||
'tqdm',
|
||||
'vector-quantize-pytorch',
|
||||
'x-clip>=0.5.1',
|
||||
'youtokentome'
|
||||
'x-clip>=0.4.4',
|
||||
'youtokentome',
|
||||
'webdataset>=0.2.5',
|
||||
'fsspec>=2022.1.0'
|
||||
],
|
||||
classifiers=[
|
||||
'Development Status :: 4 - Beta',
|
||||
|
||||
375
train_diffusion_prior.py
Normal file
375
train_diffusion_prior.py
Normal file
@@ -0,0 +1,375 @@
|
||||
from pathlib import Path
|
||||
import click
|
||||
import math
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import clip
|
||||
from torch import nn
|
||||
|
||||
from dalle2_pytorch.dataloaders import make_splits
|
||||
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
|
||||
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon
|
||||
|
||||
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
|
||||
|
||||
from embedding_reader import EmbeddingReader
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
# constants
|
||||
|
||||
REPORT_METRICS_EVERY = 250 # for cosine similarity and other metric reporting during training
|
||||
|
||||
tracker = WandbTracker()
|
||||
|
||||
# helpers functions
|
||||
|
||||
def exists(val):
|
||||
val is not None
|
||||
|
||||
class Timer:
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.last_time = time.time()
|
||||
|
||||
def elapsed(self):
|
||||
return time.time() - self.last_time
|
||||
|
||||
# functions
|
||||
|
||||
def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation"):
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
total_loss = 0.
|
||||
total_samples = 0.
|
||||
|
||||
for image_embeddings, text_data in tqdm(dataloader):
|
||||
|
||||
batches = image_embeddings.shape[0]
|
||||
|
||||
input_args = dict(image_embed=image_embeddings)
|
||||
if text_conditioned:
|
||||
input_args = dict(**input_args, text = text_data)
|
||||
else:
|
||||
input_args = dict(**input_args, text_embed=text_data)
|
||||
|
||||
loss = model(**input_args)
|
||||
|
||||
total_loss += loss * batches
|
||||
total_samples += batches
|
||||
|
||||
avg_loss = (total_loss / total_samples)
|
||||
|
||||
tracker.log({f'{phase} {loss_type}': avg_loss})
|
||||
|
||||
def report_cosine_sims(diffusion_prior, dataloader, text_conditioned):
|
||||
diffusion_prior.eval()
|
||||
|
||||
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
|
||||
|
||||
for test_image_embeddings, text_data in tqdm(dataloader):
|
||||
|
||||
# we are text conditioned, we produce an embedding from the tokenized text
|
||||
if text_conditioned:
|
||||
text_embedding, text_encodings, text_mask = diffusion_prior.clip.embed_text(
|
||||
text_data)
|
||||
text_cond = dict(text_embed=text_embedding,
|
||||
text_encodings=text_encodings, mask=text_mask)
|
||||
else:
|
||||
text_embedding = text_data
|
||||
text_cond = dict(text_embed=text_embedding)
|
||||
|
||||
# make a copy of the text embeddings for shuffling
|
||||
text_embed_shuffled = text_embedding.clone()
|
||||
|
||||
# roll the text to simulate "unrelated" captions
|
||||
rolled_idx = torch.roll(torch.arange(text_embedding.shape[0]), 1)
|
||||
text_embed_shuffled = text_embed_shuffled[rolled_idx]
|
||||
text_embed_shuffled = text_embed_shuffled / \
|
||||
text_embed_shuffled.norm(dim=1, keepdim=True)
|
||||
|
||||
if text_conditioned:
|
||||
text_encodings_shuffled = text_encodings[rolled_idx]
|
||||
text_mask_shuffled = text_mask[rolled_idx]
|
||||
else:
|
||||
text_encodings_shuffled = None
|
||||
text_mask_shuffled = None
|
||||
|
||||
text_cond_shuffled = dict(text_embed=text_embed_shuffled,
|
||||
text_encodings=text_encodings_shuffled, mask=text_mask_shuffled)
|
||||
|
||||
# prepare the text embedding
|
||||
text_embed = text_embedding / text_embedding.norm(dim=1, keepdim=True)
|
||||
|
||||
# prepare image embeddings
|
||||
test_image_embeddings = test_image_embeddings / \
|
||||
test_image_embeddings.norm(dim=1, keepdim=True)
|
||||
|
||||
# predict on the unshuffled text embeddings
|
||||
predicted_image_embeddings = diffusion_prior.p_sample_loop(
|
||||
test_image_embeddings.shape, text_cond)
|
||||
predicted_image_embeddings = predicted_image_embeddings / \
|
||||
predicted_image_embeddings.norm(dim=1, keepdim=True)
|
||||
|
||||
# predict on the shuffled embeddings
|
||||
predicted_unrelated_embeddings = diffusion_prior.p_sample_loop(
|
||||
test_image_embeddings.shape, text_cond_shuffled)
|
||||
predicted_unrelated_embeddings = predicted_unrelated_embeddings / \
|
||||
predicted_unrelated_embeddings.norm(dim=1, keepdim=True)
|
||||
|
||||
# calculate similarities
|
||||
original_similarity = cos(
|
||||
text_embed, test_image_embeddings).cpu().numpy()
|
||||
predicted_similarity = cos(
|
||||
text_embed, predicted_image_embeddings).cpu().numpy()
|
||||
unrelated_similarity = cos(
|
||||
text_embed, predicted_unrelated_embeddings).cpu().numpy()
|
||||
predicted_img_similarity = cos(
|
||||
test_image_embeddings, predicted_image_embeddings).cpu().numpy()
|
||||
tracker.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity),
|
||||
"CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity),
|
||||
"CosineSimilarity(orig_image_embed,predicted_image_embed)":np.mean(predicted_img_similarity),
|
||||
"CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(unrelated_similarity),
|
||||
"Cosine similarity difference":np.mean(predicted_similarity - original_similarity)})
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option("--wandb-entity", default="laion")
|
||||
@click.option("--wandb-project", default="diffusion-prior")
|
||||
@click.option("--wandb-dataset", default="LAION-5B")
|
||||
@click.option("--wandb-arch", default="DiffusionPrior")
|
||||
@click.option("--image-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
|
||||
@click.option("--text-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
|
||||
@click.option("--meta-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/laion2B-en-metadata/")
|
||||
@click.option("--learning-rate", default=1.1e-4)
|
||||
@click.option("--weight-decay", default=6.02e-2)
|
||||
@click.option("--dropout", default=5e-2)
|
||||
@click.option("--max-grad-norm", default=0.5)
|
||||
@click.option("--num-data-points", default=250e6)
|
||||
@click.option("--batch-size", default=320)
|
||||
@click.option("--num-epochs", default=5)
|
||||
@click.option("--image-embed-dim", default=768)
|
||||
@click.option("--train-percent", default=0.9)
|
||||
@click.option("--val-percent", default=1e-7)
|
||||
@click.option("--test-percent", default=0.0999999)
|
||||
@click.option("--dpn-depth", default=12)
|
||||
@click.option("--dpn-dim-head", default=64)
|
||||
@click.option("--dpn-heads", default=12)
|
||||
@click.option("--dp-condition-on-text-encodings", default=True)
|
||||
@click.option("--dp-timesteps", default=1000)
|
||||
@click.option("--dp-normformer", default=True)
|
||||
@click.option("--dp-cond-drop-prob", default=0.1)
|
||||
@click.option("--dp-loss-type", default="l2")
|
||||
@click.option("--clip", default="ViT-L/14")
|
||||
@click.option("--amp", default=False)
|
||||
@click.option("--save-interval", default=120)
|
||||
@click.option("--save-path", default="./diffusion_prior_checkpoints")
|
||||
@click.option("--pretrained-model-path", default=None)
|
||||
@click.option("--gpu-device", default=0)
|
||||
def train(
|
||||
wandb_entity,
|
||||
wandb_project,
|
||||
wandb_dataset,
|
||||
wandb_arch,
|
||||
image_embed_url,
|
||||
text_embed_url,
|
||||
meta_url,
|
||||
learning_rate,
|
||||
weight_decay,
|
||||
dropout,
|
||||
max_grad_norm,
|
||||
num_data_points,
|
||||
batch_size,
|
||||
num_epochs,
|
||||
image_embed_dim,
|
||||
train_percent,
|
||||
val_percent,
|
||||
test_percent,
|
||||
dpn_depth,
|
||||
dpn_dim_head,
|
||||
dpn_heads,
|
||||
dp_condition_on_text_encodings,
|
||||
dp_timesteps,
|
||||
dp_normformer,
|
||||
dp_cond_drop_prob,
|
||||
dp_loss_type,
|
||||
clip,
|
||||
amp,
|
||||
save_interval,
|
||||
save_path,
|
||||
pretrained_model_path,
|
||||
gpu_device
|
||||
):
|
||||
config = {
|
||||
"learning_rate": learning_rate,
|
||||
"architecture": wandb_arch,
|
||||
"dataset": wandb_dataset,
|
||||
"weight_decay": weight_decay,
|
||||
"max_gradient_clipping_norm": max_grad_norm,
|
||||
"batch_size": batch_size,
|
||||
"epochs": num_epochs,
|
||||
"diffusion_prior_network": {
|
||||
"depth": dpn_depth,
|
||||
"dim_head": dpn_dim_head,
|
||||
"heads": dpn_heads,
|
||||
"normformer": dp_normformer
|
||||
},
|
||||
"diffusion_prior": {
|
||||
"condition_on_text_encodings": dp_condition_on_text_encodings,
|
||||
"timesteps": dp_timesteps,
|
||||
"cond_drop_prob": dp_cond_drop_prob,
|
||||
"loss_type": dp_loss_type,
|
||||
"clip": clip
|
||||
}
|
||||
}
|
||||
|
||||
# Check if DPRIOR_PATH exists(saved model path)
|
||||
|
||||
DPRIOR_PATH = pretrained_model_path
|
||||
RESUME = exists(DPRIOR_PATH)
|
||||
|
||||
if not RESUME:
|
||||
tracker.init(
|
||||
entity = wandb_entity,
|
||||
project = wandb_project,
|
||||
config = config
|
||||
)
|
||||
|
||||
# Obtain the utilized device.
|
||||
|
||||
has_cuda = torch.cuda.is_available()
|
||||
if has_cuda:
|
||||
device = torch.device(f"cuda:{gpu_device}")
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
# Training loop
|
||||
# diffusion prior network
|
||||
|
||||
prior_network = DiffusionPriorNetwork(
|
||||
dim = image_embed_dim,
|
||||
depth = dpn_depth,
|
||||
dim_head = dpn_dim_head,
|
||||
heads = dpn_heads,
|
||||
attn_dropout = dropout,
|
||||
ff_dropout = dropout,
|
||||
normformer = dp_normformer
|
||||
)
|
||||
|
||||
# Load clip model if text-conditioning
|
||||
if dp_condition_on_text_encodings:
|
||||
clip_adapter = OpenAIClipAdapter(clip)
|
||||
else:
|
||||
clip_adapter = None
|
||||
|
||||
# diffusion prior with text embeddings and image embeddings pre-computed
|
||||
|
||||
diffusion_prior = DiffusionPrior(
|
||||
net = prior_network,
|
||||
clip = clip_adapter,
|
||||
image_embed_dim = image_embed_dim,
|
||||
timesteps = dp_timesteps,
|
||||
cond_drop_prob = dp_cond_drop_prob,
|
||||
loss_type = dp_loss_type,
|
||||
condition_on_text_encodings = dp_condition_on_text_encodings
|
||||
)
|
||||
|
||||
# Load pre-trained model from DPRIOR_PATH
|
||||
|
||||
if RESUME:
|
||||
diffusion_prior, loaded_obj = load_diffusion_model(DPRIOR_PATH, device)
|
||||
tracker.init(entity = wandb_entity, project = wandb_project, config = config)
|
||||
|
||||
# diffusion prior trainer
|
||||
|
||||
trainer = DiffusionPriorTrainer(
|
||||
diffusion_prior = diffusion_prior,
|
||||
lr = learning_rate,
|
||||
wd = weight_decay,
|
||||
max_grad_norm = max_grad_norm,
|
||||
amp = amp,
|
||||
).to(device)
|
||||
|
||||
# load optimizer and scaler
|
||||
|
||||
if RESUME:
|
||||
trainer.optimizer.load_state_dict(loaded_obj['optimizer'])
|
||||
trainer.scaler.load_state_dict(loaded_obj['scaler'])
|
||||
|
||||
# Create save_path if it doesn't exist
|
||||
|
||||
Path(save_path).mkdir(exist_ok = True, parents = True)
|
||||
|
||||
# Utilize wrapper to abstract away loader logic
|
||||
print_ribbon("Downloading Embeddings")
|
||||
loader_args = dict(text_conditioned=dp_condition_on_text_encodings, batch_size=batch_size, num_data_points=num_data_points,
|
||||
train_split=train_percent, eval_split=val_percent, device=device, img_url=image_embed_url)
|
||||
|
||||
if dp_condition_on_text_encodings:
|
||||
loader_args = dict(**loader_args, meta_url=meta_url)
|
||||
else:
|
||||
loader_args = dict(**loader_args, txt_url=text_embed_url)
|
||||
|
||||
train_loader, eval_loader, test_loader = make_splits(**loader_args)
|
||||
|
||||
### Training code ###
|
||||
|
||||
step = 1
|
||||
timer = Timer()
|
||||
epochs = num_epochs
|
||||
|
||||
for _ in range(epochs):
|
||||
|
||||
for image, text in tqdm(train_loader):
|
||||
|
||||
diffusion_prior.train()
|
||||
|
||||
input_args = dict(image_embed=image)
|
||||
if dp_condition_on_text_encodings:
|
||||
input_args = dict(**input_args, text = text)
|
||||
else:
|
||||
input_args = dict(**input_args, text_embed=text)
|
||||
|
||||
loss = trainer(**input_args)
|
||||
|
||||
# Samples per second
|
||||
|
||||
samples_per_sec = batch_size * step / timer.elapsed()
|
||||
|
||||
# Save checkpoint every save_interval minutes
|
||||
if(int(timer.elapsed()) >= 60 * save_interval):
|
||||
timer.reset()
|
||||
|
||||
save_diffusion_model(
|
||||
save_path,
|
||||
diffusion_prior,
|
||||
trainer.optimizer,
|
||||
trainer.scaler,
|
||||
config,
|
||||
image_embed_dim)
|
||||
|
||||
# Log to wandb
|
||||
tracker.log({"Training loss": loss,
|
||||
"Steps": step,
|
||||
"Samples per second": samples_per_sec})
|
||||
# Log cosineSim(text_embed,predicted_image_embed) - cosineSim(text_embed,image_embed)
|
||||
# Use NUM_TEST_EMBEDDINGS samples from the test set each time
|
||||
# Get embeddings from the most recently saved model
|
||||
if(step % REPORT_METRICS_EVERY) == 0:
|
||||
report_cosine_sims(diffusion_prior, eval_loader, dp_condition_on_text_encodings)
|
||||
### Evaluate model(validation run) ###
|
||||
eval_model(diffusion_prior, eval_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Validation")
|
||||
|
||||
step += 1
|
||||
trainer.update()
|
||||
|
||||
### Test run ###
|
||||
eval_model(diffusion_prior, test_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Test")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
Reference in New Issue
Block a user