Compare commits

...

52 Commits

Author SHA1 Message Date
Phil Wang
924455d97d align the ema model device back after sampling from the cascading ddpm in the decoder 2022-05-11 19:56:54 -07:00
Phil Wang
6021945fc8 default to l2 loss 2022-05-11 19:24:51 -07:00
Light-V
6f76652d11 fix typo in README.md (#85)
The default config for clip from openai should be ViT-B/32
2022-05-11 13:38:16 -07:00
Phil Wang
3dda2570ed fix amp issue for https://github.com/lucidrains/DALLE2-pytorch/issues/82 2022-05-11 08:21:39 -07:00
Phil Wang
2f3c02dba8 numerical accuracy for noise schedule parameters 2022-05-10 15:28:46 -07:00
Phil Wang
908088cfea wrap up cross embed layer feature 2022-05-10 12:19:34 -07:00
Phil Wang
8dc8a3de0d product management 2022-05-10 11:51:38 -07:00
Phil Wang
35f89556ba bring in the cross embed layer from Crossformer paper for initial convolution in unet 2022-05-10 11:50:38 -07:00
Phil Wang
2b55f753b9 fix new issue with github actions and auto pypi package uploading 2022-05-10 10:51:15 -07:00
Phil Wang
fc8fce38fb make sure cascading DDPM can be trained unconditionally, to ready for CLI one command training for the public 2022-05-10 10:48:10 -07:00
Phil Wang
a1bfb03ba4 project management 2022-05-10 10:13:51 -07:00
Phil Wang
b1e7b5f6bb make sure resnet groups in unet is finely customizable 2022-05-10 10:12:50 -07:00
z
10b905b445 smol typo (#81) 2022-05-10 09:52:50 -07:00
Phil Wang
9b322ea634 patch 2022-05-09 19:46:19 -07:00
Phil Wang
ba64ea45cc 0.2.3 2022-05-09 16:50:31 -07:00
Phil Wang
64f7be1926 some cleanup 2022-05-09 16:50:21 -07:00
Phil Wang
db805e73e1 fix a bug with numerical stability in attention, sorry! 🐛 2022-05-09 16:23:37 -07:00
z
cb07b37970 Ensure Eval Mode In Metric Functions (#79)
* add eval/train toggles

* train/eval flags

* shift train toggle

Co-authored-by: nousr <z@localhost.com>
2022-05-09 16:05:40 -07:00
Phil Wang
a774bfefe2 add attention and feedforward dropouts to train_diffusion_prior script 2022-05-09 13:57:15 -07:00
Phil Wang
2ae57f0cf5 cleanup 2022-05-09 13:51:26 -07:00
Phil Wang
e46eaec817 deal the diffusion prior problem yet another blow 2022-05-09 11:08:52 -07:00
Kumar R
8647cb5e76 Val loss changes, with quite a few other changes. This is in place of the earlier PR(https://github.com/lucidrains/DALLE2-pytorch/pull/67) (#77)
* Val_loss changes - no rebased with lucidrains' master.

* Val Loss changes - now rebased with lucidrains' master

* train_diffusion_prior.py updates

* dalle2_pytorch.py updates

* __init__.py changes

* Update train_diffusion_prior.py

* Update dalle2_pytorch.py

* Update train_diffusion_prior.py

* Update train_diffusion_prior.py

* Update dalle2_pytorch.py

* Update train_diffusion_prior.py

* Update train_diffusion_prior.py

* Update train_diffusion_prior.py

* Update train_diffusion_prior.py

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md
2022-05-09 08:53:29 -07:00
Phil Wang
53c189e46a give more surface area for attention in diffusion prior 2022-05-09 08:08:11 -07:00
Phil Wang
dde51fd362 revert restriction for classifier free guidance for diffusion prior, given @crowsonkb advice 2022-05-07 20:55:41 -07:00
Nasir Khalid
2eac7996fa Additional image_embed metric (#75)
Added metric to track image_embed vs predicted_image_embed
2022-05-07 14:32:33 -07:00
Phil Wang
4010aec033 turn off classifier free guidance if predicting x_start for diffusion prior 2022-05-07 09:38:17 -07:00
Phil Wang
c87b84a259 todo 2022-05-07 09:21:08 -07:00
Phil Wang
8b05468653 todo 2022-05-07 08:33:45 -07:00
Phil Wang
830afd3c15 sinusoidal embed time embeddings for diffusion prior as well, for continuous version 2022-05-07 08:32:43 -07:00
Phil Wang
8f93729d19 when in doubt, make it a hyperparameter 2022-05-07 07:52:17 -07:00
z
cd5f2c1de4 simulate unrelated captions as a training metric (#66)
* add unrelated embedding metric

* change to torch.roll

Co-authored-by: nousr <z@localhost.com>
Co-authored-by: nousr <>
2022-05-07 05:34:59 -07:00
Phil Wang
85ed77d512 fix a potentially huge bug thanks to @CiaoHe https://github.com/lucidrains/DALLE2-pytorch/issues/71 2022-05-07 05:05:54 -07:00
Piero Rolando
fd53fa17db Fix a typo in README (#70)
Change "pyhon" for "python" (correct)
2022-05-06 16:53:36 -07:00
Phil Wang
3676ef4d49 make sure vqgan-vae trainer supports mixed precision 2022-05-06 10:44:16 -07:00
Phil Wang
28e944f328 make sure openai clip adapter outputs l2normed embeddings 2022-05-06 10:12:03 -07:00
Phil Wang
14e63a3f67 also offer l2norm clamping in diffusion prior during training, if one were using predict x0 objective 2022-05-06 10:05:14 -07:00
Phil Wang
09e9eaa5a6 project management 2022-05-06 09:00:22 -07:00
Phil Wang
e6d752cf4a reprioritize 2022-05-06 08:55:26 -07:00
Phil Wang
ad20a14a4d bring in rotary embeddings for diffusion prior causal transformer (the most powerful relative positional encoding, used in PaLM) - 0.1.0 because of breaking change 2022-05-06 08:45:30 -07:00
Phil Wang
0be1e0d64c support CoCa, which seems to be better than CLIP (has an autoregressive text encoder) https://arxiv.org/abs/2205.01917 2022-05-06 08:27:12 -07:00
Phil Wang
98df1ba51e add diffusion prior trainer, which automatically takes care of the exponential moving average (training and sampling), as well as mixed precision, gradient clipping 2022-05-06 08:11:09 -07:00
Phil Wang
878b555ef7 fix training with clip 2022-05-06 07:37:57 -07:00
Phil Wang
63029f7388 remove l2norm output from train_diffusion_prior.py 2022-05-05 19:07:58 -07:00
Phil Wang
c76a964fd6 allow for CLIP to be optional in Decoder, and allow DecoderTrainer to work off training pre-encoded image embeddings 2022-05-05 08:11:01 -07:00
Phil Wang
79fabc4341 reorg readme 2022-05-05 07:54:12 -07:00
Kumar R
f7ef4bde38 Added some documentation for the diffusion prior in README.md (#62)
* Delete README.md

* Create README.md

* Update README.md

* Update README.md
2022-05-05 07:51:31 -07:00
Phil Wang
93ba019069 product management 2022-05-05 07:39:51 -07:00
Phil Wang
8518684ae9 does not make much sense, as researchers may want to try predicting noise with diffusionprior instead of predicting x0 2022-05-05 07:37:00 -07:00
Phil Wang
1d5dc08810 take @crowsonkb 's suggestion at https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132 2022-05-05 07:28:53 -07:00
Phil Wang
d8d8b6caf1 dataloaders for decoder training, from @Veldrovive 2022-05-05 07:09:45 -07:00
Aidan Dempster
15acc03bd4 Add a dataloader for training the decoder (#57)
* Added dataloader and updated requirements

* Added option to set embedding shard width separately from webdataset shard length.
There must be a better way to do this.

* Changed embedding loader to read using fsspec

* Moved the loader into a more compatible location

* Removed unnecessary package

* Fixed typo (Embeding -> Embedding)

* Simplified example embedding finder code to remove unnecessary get_file_list function

* Added example usage of ImageEmbeddingDataset

* Changed the name of create_dataloader to be more verbose
Added a dataloaders __init__.py
2022-05-05 07:08:45 -07:00
Phil Wang
896f19786d remove convnext blocks, they are illsuited for generative work, validated by early experimental results at https://github.com/lucidrains/video-diffusion-pytorch 2022-05-05 07:07:21 -07:00
10 changed files with 935 additions and 354 deletions

221
README.md
View File

@@ -508,7 +508,7 @@ To use a pretrained OpenAI CLIP, simply import `OpenAIClipAdapter` and pass it i
import torch
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, OpenAIClipAdapter
# openai pretrained clip - defaults to ViT/B-32
# openai pretrained clip - defaults to ViT-B/32
clip = OpenAIClipAdapter()
@@ -786,6 +786,181 @@ mock_image_embed = torch.randn(4, 512).cuda()
images = decoder_trainer.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
```
### Diffusion Prior Training
Similarly, one can use the `DiffusionPriorTrainer` to automatically instantiate and keep track of an exponential moving averaged prior.
```python
import torch
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, DiffusionPriorTrainer, Unet, Decoder, CLIP
clip = CLIP(
dim_text = 512,
dim_image = 512,
dim_latent = 512,
num_text_tokens = 49408,
text_enc_depth = 6,
text_seq_len = 256,
text_heads = 8,
visual_enc_depth = 6,
visual_image_size = 256,
visual_patch_size = 32,
visual_heads = 8
).cuda()
# mock data
text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
# prior networks (with transformer)
prior_network = DiffusionPriorNetwork(
dim = 512,
depth = 6,
dim_head = 64,
heads = 8
).cuda()
diffusion_prior = DiffusionPrior(
net = prior_network,
clip = clip,
timesteps = 100,
cond_drop_prob = 0.2
).cuda()
diffusion_prior_trainer = DiffusionPriorTrainer(
diffusion_prior,
lr = 3e-4,
wd = 1e-2,
ema_beta = 0.99,
ema_update_after_step = 1000,
ema_update_every = 10,
)
loss = diffusion_prior_trainer(text, images)
loss.backward()
diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior
# after much of the above three lines in a loop
# you can sample from the exponential moving average of the diffusion prior identically to how you do so for DiffusionPrior
image_embeds = diffusion_prior_trainer.sample(text) # (4, 512) - exponential moving averaged image embeddings
```
### Decoder Dataloaders
In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.
#### Decoder: Image Embedding Dataset
When training the decoder (and up samplers if training together) in isolation, you will need to load images and corresponding image embeddings. This dataset can read two similar types of datasets. First, it can read a [webdataset](https://github.com/webdataset/webdataset) that contains `.jpg` and `.npy` files in the `.tar`s that contain the images and associated image embeddings respectively. Alternatively, you can also specify a source for the embeddings outside of the webdataset. In this case, the path to the embeddings should contain `.npy` files with the same shard numbers as the webdataset and there should be a correspondence between the filename of the `.jpg` and the index of the embedding in the `.npy`. So, for example, `0001.tar` from the webdataset with image `00010509.jpg` (the first 4 digits are the shard number and the last 4 are the index) in it should be paralleled by a `img_emb_0001.npy` which contains a NumPy array with the embedding at index 509.
Generating a dataset of this type:
1. Use [img2dataset](https://github.com/rom1504/img2dataset) to generate a webdataset.
2. Use [clip-retrieval](https://github.com/rom1504/clip-retrieval) to convert the images to embeddings.
3. Use [embedding-dataset-reordering](https://github.com/Veldrovive/embedding-dataset-reordering) to reorder the embeddings into the expected format.
Usage:
```python
from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embedding_dataloader
# Create a dataloader directly.
dataloader = create_image_embedding_dataloader(
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses braket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
num_workers=4,
batch_size=32,
shard_width=4, # If a file in the webdataset shard 3 is named 0003039.jpg, we know the shard width is 4 and the last three digits are the index
shuffle_num=200, # Does a shuffle of the data with a buffer size of 200
shuffle_shards=True, # Shuffle the order the shards are read in
resample_shards=False, # Sample shards with replacement. If true, an epoch will be infinite unless stopped manually
)
for img, emb in dataloader:
print(img.shape) # torch.Size([32, 3, 256, 256])
print(emb.shape) # torch.Size([32, 512])
# Train decoder only as shown above
# Or create a dataset without a loader so you can configure it manually
dataset = ImageEmbeddingDataset(
urls="/path/or/url/to/webdataset/{0000..9999}.tar",
embedding_folder_url="path/or/url/to/embeddings/folder",
shard_width=4,
shuffle_shards=True,
resample=False
)
```
## Scripts
### Using the `train_diffusion_prior.py` script
This script allows training the DiffusionPrior on pre-computed text and image embeddings. The working example below elucidates this process.
Please note that the script internally passes text_embed and image_embed to the DiffusionPrior, unlike the example below.
### Usage
```bash
$ python train_diffusion_prior.py
```
The most significant parameters for the script are as follows:
--image-embed-url, default = "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
--text-embed-url, default = "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
--image-embed-dim, default=768 - 768 corresponds to the ViT iL/14 embedding size,change it to what your chosen ViT generates
--learning-rate, default=1.1e-4
--weight-decay, default=6.02e-2
--max-grad-norm, default=0.5
--batch-size, default=10 ** 4
--num-epochs, default=5
--clip, default=None # Signals the prior to use pre-computed embeddings
### Sample wandb run log
Please find a sample wandb run log at : https://wandb.ai/laion/diffusion-prior/runs/1blxu24j
### Loading and saving the Diffusion Prior model
Two methods are provided, load_diffusion_model and save_diffusion_model, the names being self-explanatory.
## from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model
load_diffusion_model(dprior_path, device)
dprior_path : path to saved model(.pth)
device : the cuda device you're running on
save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim)
save_path : path to save at
model : object of Diffusion_Prior
optimizer : optimizer object - see train_diffusion_prior.py for how to create one.
e.g: optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
scaler : a GradScaler object.
e.g: scaler = GradScaler(enabled=amp)
config : config object created in train_diffusion_prior.py - see file for example.
image_embed_dim - the dimension of the image_embedding
e.g: 768
## CLI (wip)
```bash
@@ -823,20 +998,25 @@ Once built, images will be saved to the same directory the command is invoked
- [x] add convnext backbone for vqgan-vae (in addition to vit [vit-vqgan] + resnet)
- [x] make sure DDPMs can be run with traditional resnet blocks (but leave convnext as an option for experimentation)
- [x] make sure for the latter unets in the cascade, one can train on crops for learning super resolution (constrain the unet to be only convolutions in that case, or allow conv-like attention with rel pos bias)
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo)
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
- [x] offer setting in diffusion prior to split time and image embeddings into multiple tokens, configurable, for more surface area during attention
- [x] make sure resnet hyperparameters can be configurable across unet depth (groups and expansion factor)
- [x] pull logic for training diffusion prior into a class DiffusionPriorTrainer, for eventual script based + CLI based training
- [x] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images
- [x] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
- [x] cross embed layers for downsampling, as an option
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
- [ ] pull logic for training diffusion prior into a class DiffusionPriorTrainer, for eventual script based + CLI based training
- [ ] train on a toy task, offer in colab
- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
- [ ] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
- [ ] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
- [ ] make sure resnet | convnext block hyperparameters can be configurable across unet depth (groups and expansion factor)
- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
- [ ] decoder needs one day worth of refactor for tech debt
## Citations
@@ -866,14 +1046,6 @@ Once built, images will be saved to the same directory the command is invoked
}
```
```bibtex
@inproceedings{Liu2022ACF,
title = {A ConvNet for the 2020s},
author = {Zhuang Liu and Hanzi Mao and Chaozheng Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie},
year = {2022}
}
```
```bibtex
@article{shen2019efficient,
author = {Zhuoran Shen and Mingyuan Zhang and Haiyu Zhao and Shuai Yi and Hongsheng Li},
@@ -912,4 +1084,25 @@ Once built, images will be saved to the same directory the command is invoked
}
```
```bibtex
@article{Yu2022CoCaCC,
title = {CoCa: Contrastive Captioners are Image-Text Foundation Models},
author = {Jiahui Yu and Zirui Wang and Vijay Vasudevan and Legg Yeung and Mojtaba Seyedhosseini and Yonghui Wu},
journal = {ArXiv},
year = {2022},
volume = {abs/2205.01917}
}
```
```bibtex
@misc{wang2021crossformer,
title = {CrossFormer: A Versatile Vision Transformer Hinging on Cross-scale Attention},
author = {Wenxiao Wang and Lu Yao and Long Chen and Binbin Lin and Deng Cai and Xiaofei He and Wei Liu},
year = {2021},
eprint = {2108.00154},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>

View File

@@ -1,6 +1,6 @@
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
from dalle2_pytorch.train import DecoderTrainer
from dalle2_pytorch.train import DecoderTrainer, DiffusionPriorTrainer
from dalle2_pytorch.vqgan_vae import VQGanVAE
from x_clip import CLIP

View File

@@ -4,6 +4,7 @@ from inspect import isfunction
from functools import partial
from contextlib import contextmanager
from collections import namedtuple
from pathlib import Path
import torch
import torch.nn.functional as F
@@ -23,9 +24,14 @@ from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE
from resize_right import resize
# rotary embeddings
from rotary_embedding_torch import RotaryEmbedding
# use x-clip
from x_clip import CLIP
from coca_pytorch import CoCa
# helper functions
@@ -113,9 +119,10 @@ EmbeddedText = namedtuple('EmbedTextReturn', ['text_embed', 'text_encodings', 't
EmbeddedImage = namedtuple('EmbedImageReturn', ['image_embed', 'image_encodings'])
class BaseClipAdapter(nn.Module):
def __init__(self, clip):
def __init__(self, clip, **kwargs):
super().__init__()
self.clip = clip
self.overrides = kwargs
@property
def dim_latent(self):
@@ -173,6 +180,39 @@ class XClipAdapter(BaseClipAdapter):
image_embed = self.clip.to_visual_latent(image_cls)
return EmbeddedImage(l2norm(image_embed), image_encodings)
class CoCaAdapter(BaseClipAdapter):
@property
def dim_latent(self):
return self.clip.dim
@property
def image_size(self):
assert 'image_size' in self.overrides
return self.overrides['image_size']
@property
def image_channels(self):
assert 'image_channels' in self.overrides
return self.overrides['image_channels']
@property
def max_text_len(self):
assert 'max_text_len' in self.overrides
return self.overrides['max_text_len']
@torch.no_grad()
def embed_text(self, text):
text = text[..., :self.max_text_len]
text_mask = text != 0
text_embed, text_encodings = self.clip.embed_text(text)
return EmbeddedText(text_embed, text_encodings, text_mask)
@torch.no_grad()
def embed_image(self, image):
image = resize_image_to(image, self.image_size)
image_embed, image_encodings = self.clip.embed_image(image)
return EmbeddedImage(image_embed, image_encodings)
class OpenAIClipAdapter(BaseClipAdapter):
def __init__(
self,
@@ -225,7 +265,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
text_embed = self.clip.encode_text(text)
text_encodings = self.text_encodings
del self.text_encodings
return EmbeddedText(text_embed.float(), text_encodings.float(), text_mask)
return EmbeddedText(l2norm(text_embed.float()), text_encodings.float(), text_mask)
@torch.no_grad()
def embed_image(self, image):
@@ -233,7 +273,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
image = resize_image_to(image, self.image_size)
image = self.clip_normalize(unnormalize_img(image))
image_embed = self.clip.encode_image(image)
return EmbeddedImage(image_embed.float(), None)
return EmbeddedImage(l2norm(image_embed.float()), None)
# classifier free guidance functions
@@ -263,7 +303,7 @@ def cosine_beta_schedule(timesteps, s = 0.008):
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
"""
steps = timesteps + 1
x = torch.linspace(0, timesteps, steps)
x = torch.linspace(0, timesteps, steps, dtype = torch.float64)
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
@@ -274,21 +314,21 @@ def linear_beta_schedule(timesteps):
scale = 1000 / timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
return torch.linspace(beta_start, beta_end, timesteps)
return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)
def quadratic_beta_schedule(timesteps):
scale = 1000 / timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
return torch.linspace(beta_start**2, beta_end**2, timesteps) ** 2
return torch.linspace(beta_start**2, beta_end**2, timesteps, dtype = torch.float64) ** 2
def sigmoid_beta_schedule(timesteps):
scale = 1000 / timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
betas = torch.linspace(-6, 6, timesteps)
betas = torch.linspace(-6, 6, timesteps, dtype = torch.float64)
return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
@@ -328,17 +368,21 @@ class BaseGaussianDiffusion(nn.Module):
self.loss_type = loss_type
self.loss_fn = loss_fn
self.register_buffer('betas', betas)
self.register_buffer('alphas_cumprod', alphas_cumprod)
self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
# register buffer helper function to cast double back to float
register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
register_buffer('betas', betas)
register_buffer('alphas_cumprod', alphas_cumprod)
register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
# calculations for posterior q(x_{t-1} | x_t, x_0)
@@ -346,13 +390,13 @@ class BaseGaussianDiffusion(nn.Module):
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.register_buffer('posterior_variance', posterior_variance)
register_buffer('posterior_variance', posterior_variance)
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
def q_mean_variance(self, x_start, t):
mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
@@ -531,7 +575,8 @@ class Attention(nn.Module):
heads = 8,
dropout = 0.,
causal = False,
post_norm = False
post_norm = False,
rotary_emb = None
):
super().__init__()
self.scale = dim_head ** -0.5
@@ -547,6 +592,8 @@ class Attention(nn.Module):
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
self.rotary_emb = rotary_emb
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim, bias = False),
LayerNorm(dim) if post_norm else nn.Identity()
@@ -559,6 +606,12 @@ class Attention(nn.Module):
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))
q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
q = q * self.scale
# rotary embeddings
if exists(self.rotary_emb):
q, k = map(self.rotary_emb.rotate_queries_or_keys, (q, k))
# add null key / value for classifier free guidance in prior net
@@ -566,7 +619,7 @@ class Attention(nn.Module):
k = torch.cat((nk, k), dim = -2)
v = torch.cat((nv, v), dim = -2)
q = q * self.scale
# calculate query / key similarities
sim = einsum('b h i d, b j d -> b h i j', q, k)
@@ -591,7 +644,7 @@ class Attention(nn.Module):
# attention
sim = sim - sim.amax(dim = -1, keepdim = True)
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
attn = sim.softmax(dim = -1)
attn = self.dropout(attn)
@@ -616,15 +669,18 @@ class CausalTransformer(nn.Module):
attn_dropout = 0.,
ff_dropout = 0.,
final_proj = True,
normformer = False
normformer = False,
rotary_emb = True
):
super().__init__()
self.rel_pos_bias = RelPosBias(heads = heads)
rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer),
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer, rotary_emb = rotary_emb),
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
]))
@@ -652,14 +708,33 @@ class DiffusionPriorNetwork(nn.Module):
self,
dim,
num_timesteps = None,
l2norm_output = False, # whether to restrict image embedding output with l2norm at the end (may make it easier to learn?)
num_time_embeds = 1,
num_image_embeds = 1,
num_text_embeds = 1,
**kwargs
):
super().__init__()
self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(Rearrange('b -> b 1'), MLP(1, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
self.num_time_embeds = num_time_embeds
self.num_image_embeds = num_image_embeds
self.num_text_embeds = num_text_embeds
self.to_text_embeds = nn.Sequential(
nn.Linear(dim, dim * num_text_embeds) if num_text_embeds > 1 else nn.Identity(),
Rearrange('b (n d) -> b n d', n = num_text_embeds)
)
self.to_time_embeds = nn.Sequential(
nn.Embedding(num_timesteps, dim * num_time_embeds) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim * num_time_embeds)), # also offer a continuous version of timestep embeddings, with a 2 layer MLP
Rearrange('b (n d) -> b n d', n = num_time_embeds)
)
self.to_image_embeds = nn.Sequential(
nn.Linear(dim, dim * num_image_embeds) if num_image_embeds > 1 else nn.Identity(),
Rearrange('b (n d) -> b n d', n = num_image_embeds)
)
self.learned_query = nn.Parameter(torch.randn(dim))
self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
self.l2norm_output = l2norm_output
def forward_with_cond_scale(
self,
@@ -687,10 +762,13 @@ class DiffusionPriorNetwork(nn.Module):
):
batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
num_time_embeds, num_image_embeds, num_text_embeds = self.num_time_embeds, self.num_image_embeds, self.num_text_embeds
# in section 2.2, last paragraph
# "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
text_embed, image_embed = rearrange_many((text_embed, image_embed), 'b d -> b 1 d')
text_embed = self.to_text_embeds(text_embed)
image_embed = self.to_image_embeds(image_embed)
# make text encodings optional
# although the paper seems to suggest it is present <--
@@ -710,16 +788,17 @@ class DiffusionPriorNetwork(nn.Module):
# whether text embedding is masked or not depends on the classifier free guidance conditional masking
keep_mask = repeat(keep_mask, 'b 1 -> b n', n = num_text_embeds)
mask = torch.cat((mask, keep_mask), dim = 1)
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
# but let's just do it right
if exists(mask):
mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
attend_padding = 1 + num_time_embeds + num_image_embeds # 1 for learned queries + number of image embeds + time embeds
mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
time_embed = self.time_embeddings(diffusion_timesteps)
time_embed = rearrange(time_embed, 'b d -> b 1 d')
time_embed = self.to_time_embeds(diffusion_timesteps)
learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
@@ -727,6 +806,7 @@ class DiffusionPriorNetwork(nn.Module):
text_encodings,
text_embed,
time_embed,
image_embed,
learned_queries
), dim = -2)
@@ -738,8 +818,7 @@ class DiffusionPriorNetwork(nn.Module):
pred_image_embed = tokens[..., -1, :]
output_fn = l2norm if self.l2norm_output else identity
return output_fn(pred_image_embed)
return pred_image_embed
class DiffusionPrior(BaseGaussianDiffusion):
def __init__(
@@ -751,12 +830,16 @@ class DiffusionPrior(BaseGaussianDiffusion):
image_size = None,
image_channels = 3,
timesteps = 1000,
cond_drop_prob = 0.2,
loss_type = "l1",
cond_drop_prob = 0.,
loss_type = "l2",
predict_x_start = True,
beta_schedule = "cosine",
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
sampling_clamp_l2norm = False
sampling_clamp_l2norm = False,
training_clamp_l2norm = False,
init_image_embed_l2norm = False,
image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
clip_adapter_overrides = dict()
):
super().__init__(
beta_schedule = beta_schedule,
@@ -766,7 +849,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
if exists(clip):
if isinstance(clip, CLIP):
clip = XClipAdapter(clip)
clip = XClipAdapter(clip, **clip_adapter_overrides)
elif isinstance(clip, CoCa):
clip = CoCaAdapter(clip, **clip_adapter_overrides)
assert isinstance(clip, BaseClipAdapter)
freeze_model_and_make_eval_(clip)
@@ -782,11 +867,16 @@ class DiffusionPrior(BaseGaussianDiffusion):
self.cond_drop_prob = cond_drop_prob
self.condition_on_text_encodings = condition_on_text_encodings
self.predict_x_start = predict_x_start
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
self.predict_x_start = predict_x_start
# @crowsonkb 's suggestion - https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
self.image_embed_scale = default(image_embed_scale, self.image_embed_dim ** 0.5)
# whether to force an l2norm, similar to clipping denoised, when sampling
self.sampling_clamp_l2norm = sampling_clamp_l2norm
self.training_clamp_l2norm = training_clamp_l2norm
self.init_image_embed_l2norm = init_image_embed_l2norm
def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
pred = self.net(x, t, **text_cond)
@@ -802,7 +892,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
x_recon.clamp_(-1., 1.)
if self.predict_x_start and self.sampling_clamp_l2norm:
x_recon = l2norm(x_recon)
x_recon = l2norm(x_recon) * self.image_embed_scale
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
return model_mean, posterior_variance, posterior_log_variance
@@ -821,11 +911,16 @@ class DiffusionPrior(BaseGaussianDiffusion):
device = self.betas.device
b = shape[0]
img = torch.randn(shape, device=device)
image_embed = torch.randn(shape, device=device)
if self.init_image_embed_l2norm:
image_embed = l2norm(image_embed) * self.image_embed_scale
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), text_cond = text_cond)
return img
times = torch.full((b,), i, device = device, dtype = torch.long)
image_embed = self.p_sample(image_embed, times, text_cond = text_cond)
return image_embed
def p_losses(self, image_embed, times, text_cond, noise = None):
noise = default(noise, lambda: torch.randn_like(image_embed))
@@ -839,11 +934,26 @@ class DiffusionPrior(BaseGaussianDiffusion):
**text_cond
)
if self.predict_x_start and self.training_clamp_l2norm:
pred = l2norm(pred) * self.image_embed_scale
target = noise if not self.predict_x_start else image_embed
loss = self.loss_fn(pred, target)
return loss
@torch.inference_mode()
@eval_decorator
def sample_batch_size(self, batch_size, text_cond):
device = self.betas.device
shape = (batch_size, self.image_embed_dim)
img = torch.randn(shape, device = device)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
img = self.p_sample(img, torch.full((batch_size,), i, device = device, dtype = torch.long), text_cond = text_cond)
return img
@torch.inference_mode()
@eval_decorator
def sample(self, text, num_samples_per_batch = 2):
@@ -862,6 +972,11 @@ class DiffusionPrior(BaseGaussianDiffusion):
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
# retrieve original unscaled image embed
image_embeds /= self.image_embed_scale
text_embeds = text_cond['text_embed']
text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
@@ -909,6 +1024,10 @@ class DiffusionPrior(BaseGaussianDiffusion):
batch, device = image_embed.shape[0], image_embed.device
times = torch.randint(0, self.num_timesteps, (batch,), device = device, dtype = torch.long)
# scale image embed (Katherine)
image_embed *= self.image_embed_scale
# calculate forward loss
return self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)
@@ -999,68 +1118,6 @@ class ResnetBlock(nn.Module):
h = self.block2(h)
return h + self.res_conv(x)
class ConvNextBlock(nn.Module):
""" https://arxiv.org/abs/2201.03545 """
def __init__(
self,
dim,
dim_out,
*,
cond_dim = None,
time_cond_dim = None,
mult = 2
):
super().__init__()
need_projection = dim != dim_out
self.cross_attn = None
if exists(cond_dim):
self.cross_attn = EinopsToAndFrom(
'b c h w',
'b (h w) c',
CrossAttention(
dim = dim,
context_dim = cond_dim
)
)
self.time_mlp = None
if exists(time_cond_dim):
self.time_mlp = nn.Sequential(
nn.GELU(),
nn.Linear(time_cond_dim, dim)
)
self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim)
inner_dim = int(dim_out * mult)
self.net = nn.Sequential(
ChanLayerNorm(dim),
nn.Conv2d(dim, inner_dim, 3, padding = 1),
nn.GELU(),
nn.Conv2d(inner_dim, dim_out, 3, padding = 1)
)
self.res_conv = nn.Conv2d(dim, dim_out, 1) if need_projection else nn.Identity()
def forward(self, x, cond = None, time = None):
h = self.ds_conv(x)
if exists(time) and exists(self.time_mlp):
t = self.time_mlp(time)
h = rearrange(t, 'b c -> b c 1 1') + h
if exists(self.cross_attn):
assert exists(cond)
h = self.cross_attn(h, context = cond) + h
h = self.net(h)
return h + self.res_conv(x)
class CrossAttention(nn.Module):
def __init__(
self,
@@ -1114,7 +1171,7 @@ class CrossAttention(nn.Module):
mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, max_neg_value)
sim = sim - sim.amax(dim = -1, keepdim = True)
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
attn = sim.softmax(dim = -1)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
@@ -1175,6 +1232,33 @@ class LinearAttention(nn.Module):
out = self.nonlin(out)
return self.to_out(out)
class CrossEmbedLayer(nn.Module):
def __init__(
self,
dim_in,
kernel_sizes,
dim_out = None,
stride = 2
):
super().__init__()
assert all([*map(lambda t: (t % 2) == (stride % 2), kernel_sizes)])
dim_out = default(dim_out, dim_in)
kernel_sizes = sorted(kernel_sizes)
num_scales = len(kernel_sizes)
# calculate the dimension at each scale
dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
dim_scales = [*dim_scales, dim_out - sum(dim_scales)]
self.convs = nn.ModuleList([])
for kernel, dim_scale in zip(kernel_sizes, dim_scales):
self.convs.append(nn.Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2))
def forward(self, x):
fmaps = tuple(map(lambda conv: conv(x), self.convs))
return torch.cat(fmaps, dim = 1)
class Unet(nn.Module):
def __init__(
self,
@@ -1198,9 +1282,10 @@ class Unet(nn.Module):
cond_on_image_embeds = False,
init_dim = None,
init_conv_kernel_size = 7,
block_type = 'resnet',
block_resnet_groups = 8,
block_convnext_mult = 2,
resnet_groups = 8,
init_cross_embed_kernel_sizes = (3, 7, 15),
cross_embed_downsample = False,
cross_embed_downsample_kernel_sizes = (2, 4),
**kwargs
):
super().__init__()
@@ -1219,10 +1304,9 @@ class Unet(nn.Module):
self.channels = channels
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
init_dim = default(init_dim, dim // 2)
init_dim = default(init_dim, dim // 3 * 2)
assert (init_conv_kernel_size % 2) == 1
self.init_conv = nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1)
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
@@ -1276,14 +1360,17 @@ class Unet(nn.Module):
attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
# whether to use resnet or the (improved?) convnext blocks
# resnet block klass
if block_type == 'resnet':
block_klass = partial(ResnetBlock, groups = block_resnet_groups)
elif block_type == 'convnext':
block_klass = partial(ConvNextBlock, mult = block_convnext_mult)
else:
raise ValueError(f'unimplemented block type {block_type}')
resnet_groups = cast_tuple(resnet_groups, len(in_out))
assert len(resnet_groups) == len(in_out)
# downsample klass
downsample_klass = Downsample
if cross_embed_downsample:
downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes)
# layers
@@ -1291,38 +1378,39 @@ class Unet(nn.Module):
self.ups = nn.ModuleList([])
num_resolutions = len(in_out)
for ind, (dim_in, dim_out) in enumerate(in_out):
for ind, ((dim_in, dim_out), groups) in enumerate(zip(in_out, resnet_groups)):
is_first = ind == 0
is_last = ind >= (num_resolutions - 1)
layer_cond_dim = cond_dim if not is_first else None
self.downs.append(nn.ModuleList([
block_klass(dim_in, dim_out, time_cond_dim = time_cond_dim),
ResnetBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
block_klass(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
Downsample(dim_out) if not is_last else nn.Identity()
ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
downsample_klass(dim_out) if not is_last else nn.Identity()
]))
mid_dim = dims[-1]
self.mid_block1 = block_klass(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
self.mid_block2 = block_klass(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
for ind, ((dim_in, dim_out), groups) in enumerate(zip(reversed(in_out[1:]), reversed(resnet_groups))):
is_last = ind >= (num_resolutions - 2)
layer_cond_dim = cond_dim if not is_last else None
self.ups.append(nn.ModuleList([
block_klass(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
ResnetBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
block_klass(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
Upsample(dim_in)
]))
out_dim = default(out_dim, channels)
self.final_conv = nn.Sequential(
block_klass(dim, dim),
ResnetBlock(dim, dim, groups = resnet_groups[0]),
nn.Conv2d(dim, out_dim, 1)
)
@@ -1333,12 +1421,13 @@ class Unet(nn.Module):
*,
lowres_cond,
channels,
cond_on_image_embeds
cond_on_image_embeds,
cond_on_text_encodings
):
if lowres_cond == self.lowres_cond and channels == self.channels and cond_on_image_embeds == self.cond_on_image_embeds:
if lowres_cond == self.lowres_cond and channels == self.channels and cond_on_image_embeds == self.cond_on_image_embeds and cond_on_text_encodings == self.cond_on_text_encodings:
return self
updated_kwargs = {'lowres_cond': lowres_cond, 'channels': channels, 'cond_on_image_embeds': cond_on_image_embeds}
updated_kwargs = {'lowres_cond': lowres_cond, 'channels': channels, 'cond_on_image_embeds': cond_on_image_embeds, 'cond_on_text_encodings': cond_on_text_encodings}
return self.__class__(**{**self._locals, **updated_kwargs})
def forward_with_cond_scale(
@@ -1403,11 +1492,12 @@ class Unet(nn.Module):
if self.cond_on_image_embeds:
image_tokens = self.image_to_cond(image_embed)
null_image_embed = self.null_image_embed.to(image_tokens.dtype) # for some reason pytorch AMP not working
image_tokens = torch.where(
image_keep_mask,
image_tokens,
self.null_image_embed
null_image_embed
)
# take care of text encodings (optional)
@@ -1431,10 +1521,12 @@ class Unet(nn.Module):
text_mask = rearrange(text_mask, 'b n -> b n 1')
text_keep_mask = text_mask & text_keep_mask
null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working
text_tokens = torch.where(
text_keep_mask,
text_tokens,
self.null_text_embed
null_text_embed
)
# main conditioning tokens (c)
@@ -1515,12 +1607,14 @@ class Decoder(BaseGaussianDiffusion):
self,
unet,
*,
clip,
clip = None,
image_size = None,
channels = 3,
vae = tuple(),
timesteps = 1000,
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5,
loss_type = 'l1',
loss_type = 'l2',
beta_schedule = 'cosine',
predict_x_start = False,
predict_x_start_for_latent_diffusion = False,
@@ -1531,7 +1625,9 @@ class Decoder(BaseGaussianDiffusion):
blur_kernel_size = 3, # cascading ddpm - blur kernel size
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
clip_denoised = True,
clip_x_start = True
clip_x_start = True,
clip_adapter_overrides = dict(),
unconditional = False
):
super().__init__(
beta_schedule = beta_schedule,
@@ -1539,15 +1635,27 @@ class Decoder(BaseGaussianDiffusion):
loss_type = loss_type
)
if isinstance(clip, CLIP):
clip = XClipAdapter(clip)
self.unconditional = unconditional
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
freeze_model_and_make_eval_(clip)
assert isinstance(clip, BaseClipAdapter)
assert exists(clip) ^ exists(image_size), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
self.clip = clip
self.clip_image_size = clip.image_size
self.channels = clip.image_channels
self.clip = None
if exists(clip):
if isinstance(clip, CLIP):
clip = XClipAdapter(clip, **clip_adapter_overrides)
elif isinstance(clip, CoCa):
clip = CoCaAdapter(clip, **clip_adapter_overrides)
freeze_model_and_make_eval_(clip)
assert isinstance(clip, BaseClipAdapter)
self.clip = clip
self.clip_image_size = clip.image_size
self.channels = clip.image_channels
else:
self.clip_image_size = image_size
self.channels = channels
self.condition_on_text_encodings = condition_on_text_encodings
@@ -1571,7 +1679,8 @@ class Decoder(BaseGaussianDiffusion):
one_unet = one_unet.cast_model_parameters(
lowres_cond = not is_first,
cond_on_image_embeds = is_first,
cond_on_image_embeds = is_first and not unconditional,
cond_on_text_encodings = one_unet.cond_on_text_encodings and not unconditional,
channels = unet_channels
)
@@ -1580,7 +1689,7 @@ class Decoder(BaseGaussianDiffusion):
# unet image sizes
image_sizes = default(image_sizes, (clip.image_size,))
image_sizes = default(image_sizes, (self.clip_image_size,))
image_sizes = tuple(sorted(set(image_sizes)))
assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
@@ -1706,12 +1815,16 @@ class Decoder(BaseGaussianDiffusion):
@eval_decorator
def sample(
self,
image_embed,
image_embed = None,
text = None,
batch_size = 1,
cond_scale = 1.,
stop_at_unet_number = None
):
batch_size = image_embed.shape[0]
assert self.unconditional or exists(image_embed), 'image embed must be present on sampling from decoder unless if trained unconditionally'
if not self.unconditional:
batch_size = image_embed.shape[0]
text_encodings = text_mask = None
if exists(text):
@@ -1721,10 +1834,11 @@ class Decoder(BaseGaussianDiffusion):
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
img = None
is_cuda = next(self.parameters()).is_cuda
for unet_number, unet, vae, channel, image_size, predict_x_start in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
context = self.one_unet_in_gpu(unet = unet) if image_embed.is_cuda else null_context()
context = self.one_unet_in_gpu(unet = unet) if is_cuda else null_context()
with context:
lowres_cond_img = None
@@ -1785,10 +1899,12 @@ class Decoder(BaseGaussianDiffusion):
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
if not exists(image_embed):
assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
image_embed, _ = self.clip.embed_image(image)
text_encodings = text_mask = None
if exists(text) and not exists(text_encodings):
assert exists(self.clip), 'if you are passing in raw text, you need to supply `clip` to the decoder'
_, text_encodings, text_mask = self.clip.embed_text(text)
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
@@ -1862,3 +1978,4 @@ class DALLE2(nn.Module):
return images[0]
return images

View File

@@ -0,0 +1 @@
from dalle2_pytorch.dataloaders.decoder_loader import ImageEmbeddingDataset, create_image_embedding_dataloader

View File

@@ -0,0 +1,170 @@
import os
import webdataset as wds
import torch
import numpy as np
import fsspec
def get_shard(filename):
"""
Filenames with shards in them have a consistent structure that we can take advantage of
Standard structure: path/to/file/prefix_string_00001.ext
"""
try:
return filename.split("_")[-1].split(".")[0]
except ValueError:
raise RuntimeError(f"Could not find shard for filename {filename}")
def get_example_file(fs, path, file_format):
"""
Given a file system and a file extension, return the example file
"""
return fs.glob(os.path.join(path, f"*.{file_format}"))[0]
def embedding_inserter(samples, embeddings_url, shard_width, handler=wds.handlers.reraise_exception):
"""Given a datum of {"__key__": str, "__url__": str, ...} adds the cooresponding embedding and yields"""
previous_tar_url = None
current_embeddings = None
# Get a reference to an abstract file system where the embeddings are stored
embeddings_fs, embeddings_path = fsspec.core.url_to_fs(embeddings_url)
example_embedding_file = get_example_file(embeddings_fs, embeddings_path, "npy")
example_embedding_shard = get_shard(example_embedding_file)
emb_shard_width = len(example_embedding_shard)
# Easier to get the basename without the shard once than search through for the correct file every time
embedding_file_basename = '_'.join(example_embedding_file.split("_")[:-1]) + "_"
def load_corresponding_embeds(tar_url):
"""Finds and reads the npy files that contains embeddings for the given webdataset tar"""
shard = int(tar_url.split("/")[-1].split(".")[0])
embedding_url = embedding_file_basename + str(shard).zfill(emb_shard_width) + '.npy'
with embeddings_fs.open(embedding_url) as f:
data = np.load(f)
return torch.from_numpy(data)
for sample in samples:
try:
tar_url = sample["__url__"]
key = sample["__key__"]
if tar_url != previous_tar_url:
# If the tar changed, we need to download new embeddings
# This means if we shuffle before inserting it will load many more files than we expect and be very inefficient.
previous_tar_url = tar_url
current_embeddings = load_corresponding_embeds(tar_url)
embedding_index = int(key[shard_width:])
sample["npy"] = current_embeddings[embedding_index]
yield sample
except Exception as exn: # From wds implementation
if handler(exn):
continue
else:
break
insert_embedding = wds.filters.pipelinefilter(embedding_inserter)
def verify_keys(samples, handler=wds.handlers.reraise_exception):
"""
Requires that both the image and embedding are present in the sample
This is important to do as a user may forget they do not have embeddings in their webdataset and neglect to add them using the embedding_folder_url parameter.
"""
for sample in samples:
try:
assert "jpg" in sample, f"Sample {sample['__key__']} missing image"
assert "npy" in sample, f"Sample {sample['__key__']} missing embedding. Did you set embedding_folder_url?"
yield sample
except Exception as exn: # From wds implementation
if handler(exn):
continue
else:
break
class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
"""
A fluid interface wrapper for DataPipline that returns image embedding pairs
Reads embeddings as npy files from the webdataset if they exist. If embedding_folder_url is set, they will be inserted in from the alternate source.
"""
def __init__(
self,
urls,
embedding_folder_url=None,
shard_width=None,
handler=wds.handlers.reraise_exception,
resample=False,
shuffle_shards=True
):
"""
Modeled directly off of the WebDataset constructor
:param urls: A url pointing to the tar files of the webdataset formatted as /path/to/webdataset/{0000..9999}.tar
:param embedding_folder_url: Required if webdataset does not contain embeddings. A url pointing to the npy files of the embeddings. Should have the same number of shards as the webdataset.
Webdataset image keys should align with the index of the embedding. This means missing image indices must have a corresponding embedding of all zeros.
:param shard_width: The number of digits in the shard number. This is used to align the embedding index with the image index.
For example, if a file in the webdataset shard 3 is named 0003039.jpg, we know the shard with this 4 and the last three digits are the index.
:param handler: A webdataset handler.
:param resample: If true, resample webdataset shards with replacement. You need to set your own epoch size if this is true since it will resample infinitely.
:param shuffle_shards: If true, shuffle the shards before resampling. This cannot be true if resample is true.
"""
super().__init__()
# Add the shardList and randomize or resample if requested
if resample:
assert not shuffle_shards, "Cannot both resample and shuffle"
self.append(wds.ResampledShards(urls))
else:
self.append(wds.SimpleShardList(urls))
if shuffle_shards:
self.append(wds.filters.shuffle(1000))
self.append(wds.split_by_node)
self.append(wds.split_by_worker)
self.append(wds.tarfile_to_samples(handler=handler))
self.append(wds.decode("torchrgb"))
if embedding_folder_url is not None:
assert shard_width is not None, "Reading embeddings separately requires shard length to be given"
self.append(insert_embedding(embeddings_url=embedding_folder_url, shard_width=shard_width, handler=handler))
self.append(verify_keys)
self.append(wds.to_tuple("jpg", "npy"))
def create_image_embedding_dataloader(
tar_url,
num_workers,
batch_size,
embeddings_url=None,
shard_width=None,
shuffle_num = None,
shuffle_shards = True,
resample_shards = False,
handler=wds.handlers.warn_and_continue
):
"""
Convenience function to create an image embedding dataseta and dataloader in one line
:param tar_url: A url pointing to the tar files of the webdataset formatted as /path/to/webdataset/{0000..9999}.tar
:param num_workers: The number of workers to use for the dataloader
:param batch_size: The batch size to use for the dataloader
:param embeddings_url: Required if webdataset does not contain embeddings. A url pointing to the npy files of the embeddings. Should have the same number of shards as the webdataset.
Webdataset image keys should align with the index of the embedding. This means missing image indices must have a corresponding embedding of all zeros.
:param shard_width: The number of digits in the shard number. This is used to align the embedding index with the image index.
For example, if a file in the webdataset shard 3 is named 0003039.jpg, we know the shard width is 4 and the last three digits are the index.
:param shuffle_num: If not None, shuffle the dataset with this size buffer after sampling.
:param shuffle_shards: If true, shuffle the shards before sampling. This cannot be true if resample is true.
:param resample_shards: If true, resample webdataset shards with replacement. You need to set your own epoch size if this is true since it will resample infinitely.
:param handler: A webdataset handler.
"""
ds = ImageEmbeddingDataset(
tar_url,
embeddings_url,
shard_width=shard_width,
shuffle_shards=shuffle_shards,
resample=resample_shards,
handler=handler
)
if shuffle_num is not None and shuffle_num > 0:
ds.shuffle(1000)
return wds.WebLoader(
ds,
num_workers=num_workers,
batch_size=batch_size,
prefetch_factor=2, # This might be good to have high so the next npy file is prefetched
pin_memory=True,
shuffle=False
)

View File

@@ -1,3 +1,4 @@
import time
import copy
from functools import partial
@@ -5,7 +6,7 @@ import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler
from dalle2_pytorch.dalle2_pytorch import Decoder
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
from dalle2_pytorch.optimizer import get_optimizer
# helper functions
@@ -39,6 +40,50 @@ def groupby_prefix_and_trim(prefix, d):
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs
# print helpers
def print_ribbon(s, symbol = '=', repeat = 40):
flank = symbol * repeat
return f'{flank} {s} {flank}'
# saving and loading functions
# for diffusion prior
def load_diffusion_model(dprior_path, device):
dprior_path = Path(dprior_path)
assert dprior_path.exists(), 'Dprior model file does not exist'
loaded_obj = torch.load(str(dprior_path), map_location='cpu')
# Get hyperparameters of loaded model
dpn_config = loaded_obj['hparams']['diffusion_prior_network']
dp_config = loaded_obj['hparams']['diffusion_prior']
image_embed_dim = loaded_obj['image_embed_dim']['image_embed_dim']
# Create DiffusionPriorNetwork and DiffusionPrior with loaded hyperparameters
# DiffusionPriorNetwork
prior_network = DiffusionPriorNetwork( dim = image_embed_dim, **dpn_config).to(device)
# DiffusionPrior with text embeddings and image embeddings pre-computed
diffusion_prior = DiffusionPrior(net = prior_network, **dp_config, image_embed_dim = image_embed_dim).to(device)
# Load state dict from saved model
diffusion_prior.load_state_dict(loaded_obj['model'])
return diffusion_prior
def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim):
# Saving State Dict
print_ribbon('Saving checkpoint')
state_dict = dict(model=model.state_dict(),
optimizer=optimizer.state_dict(),
scaler=scaler.state_dict(),
hparams = config,
image_embed_dim = {"image_embed_dim":image_embed_dim})
torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
# exponential moving average wrapper
class EMA(nn.Module):
@@ -60,6 +105,10 @@ class EMA(nn.Module):
self.register_buffer('initted', torch.Tensor([False]))
self.register_buffer('step', torch.tensor([0.]))
def restore_ema_model_device(self):
device = self.initted.device
self.ema_model.to(device)
def update(self):
self.step += 1
@@ -89,7 +138,83 @@ class EMA(nn.Module):
def __call__(self, *args, **kwargs):
return self.ema_model(*args, **kwargs)
# trainers
# diffusion prior trainer
class DiffusionPriorTrainer(nn.Module):
def __init__(
self,
diffusion_prior,
use_ema = True,
lr = 3e-4,
wd = 1e-2,
max_grad_norm = None,
amp = False,
**kwargs
):
super().__init__()
assert isinstance(diffusion_prior, DiffusionPrior)
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
self.diffusion_prior = diffusion_prior
# exponential moving average
self.use_ema = use_ema
if self.use_ema:
self.ema_diffusion_prior = EMA(diffusion_prior, **ema_kwargs)
# optimizer and mixed precision stuff
self.amp = amp
self.scaler = GradScaler(enabled = amp)
self.optimizer = get_optimizer(
diffusion_prior.parameters(),
lr = lr,
wd = wd,
**kwargs
)
# gradient clipping if needed
self.max_grad_norm = max_grad_norm
def update(self):
if exists(self.max_grad_norm):
self.scaler.unscale_(self.optimizer)
nn.utils.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
self.scaler.step(self.optimizer)
self.scaler.update()
self.optimizer.zero_grad()
if self.use_ema:
self.ema_diffusion_prior.update()
@torch.inference_mode()
def p_sample_loop(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
@torch.inference_mode()
def sample(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
@torch.inference_mode()
def sample_batch_size(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
def forward(
self,
*args,
divisor = 1,
**kwargs
):
with autocast(enabled = self.amp):
loss = self.diffusion_prior(*args, **kwargs)
return self.scaler.scale(loss / divisor)
# decoder trainer
class DecoderTrainer(nn.Module):
def __init__(
@@ -184,6 +309,11 @@ class DecoderTrainer(nn.Module):
if self.use_ema:
self.decoder.unets = trainable_unets # restore original training unets
# cast the ema_model unets back to original device
for ema in self.ema_unets:
ema.restore_ema_model_device()
return output
def forward(

View File

@@ -3,14 +3,15 @@ import copy
from random import choice
from pathlib import Path
from shutil import rmtree
from PIL import Image
import torch
from torch import nn
from PIL import Image
from torchvision.datasets import ImageFolder
import torchvision.transforms as T
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as T
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid, save_image
from einops import rearrange
@@ -99,6 +100,7 @@ class VQGanVAETrainer(nn.Module):
ema_update_after_step = 2000,
ema_update_every = 10,
apply_grad_penalty_every = 4,
amp = False
):
super().__init__()
assert isinstance(vae, VQGanVAE), 'vae must be instance of VQGanVAE'
@@ -120,6 +122,10 @@ class VQGanVAETrainer(nn.Module):
self.optim = get_optimizer(vae_parameters, lr = lr, wd = wd)
self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd)
self.amp = amp
self.scaler = GradScaler(enabled = amp)
self.discr_scaler = GradScaler(enabled = amp)
# create dataset
self.ds = ImageDataset(folder, image_size = image_size)
@@ -178,20 +184,22 @@ class VQGanVAETrainer(nn.Module):
img = next(self.dl)
img = img.to(device)
loss = self.vae(
img,
return_loss = True,
apply_grad_penalty = apply_grad_penalty
)
with autocast(enabled = self.amp):
loss = self.vae(
img,
return_loss = True,
apply_grad_penalty = apply_grad_penalty
)
self.scaler.scale(loss / self.grad_accum_every).backward()
accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
(loss / self.grad_accum_every).backward()
self.optim.step()
self.scaler.step(self.optim)
self.scaler.update()
self.optim.zero_grad()
# update discriminator
if exists(self.vae.discr):
@@ -200,12 +208,15 @@ class VQGanVAETrainer(nn.Module):
img = next(self.dl)
img = img.to(device)
loss = self.vae(img, return_discr_loss = True)
with autocast(enabled = self.amp):
loss = self.vae(img, return_discr_loss = True)
self.discr_scaler.scale(loss / self.grad_accum_every).backward()
accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every})
(loss / self.grad_accum_every).backward()
self.discr_optim.step()
self.discr_scaler.step(self.discr_optim)
self.discr_scaler.update()
self.discr_optim.zero_grad()
# log

View File

@@ -331,112 +331,6 @@ class ResBlock(nn.Module):
def forward(self, x):
return self.net(x) + x
# convnext enc dec
class ChanLayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.eps = eps
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
def forward(self, x):
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g
class ConvNext(nn.Module):
def __init__(self, dim, mult = 4, kernel_size = 3, ds_kernel_size = 7):
super().__init__()
inner_dim = int(dim * mult)
self.net = nn.Sequential(
nn.Conv2d(dim, dim, ds_kernel_size, padding = ds_kernel_size // 2, groups = dim),
ChanLayerNorm(dim),
nn.Conv2d(dim, inner_dim, kernel_size, padding = kernel_size // 2),
nn.GELU(),
nn.Conv2d(inner_dim, dim, kernel_size, padding = kernel_size // 2)
)
def forward(self, x):
return self.net(x) + x
class ConvNextEncDec(nn.Module):
def __init__(
self,
dim,
*,
channels = 3,
layers = 4,
layer_mults = None,
num_blocks = 1,
first_conv_kernel_size = 5,
use_attn = True,
attn_dim_head = 64,
attn_heads = 8,
attn_dropout = 0.,
):
super().__init__()
self.layers = layers
self.encoders = MList([])
self.decoders = MList([])
layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'
layer_dims = [dim * mult for mult in layer_mults]
dims = (dim, *layer_dims)
self.encoded_dim = dims[-1]
dim_pairs = zip(dims[:-1], dims[1:])
append = lambda arr, t: arr.append(t)
prepend = lambda arr, t: arr.insert(0, t)
if not isinstance(num_blocks, tuple):
num_blocks = (*((0,) * (layers - 1)), num_blocks)
if not isinstance(use_attn, tuple):
use_attn = (*((False,) * (layers - 1)), use_attn)
assert len(num_blocks) == layers, 'number of blocks config must be equal to number of layers'
assert len(use_attn) == layers
for layer_index, (dim_in, dim_out), layer_num_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_blocks, use_attn):
append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
prepend(self.decoders, nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu()))
if layer_use_attn:
prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
for _ in range(layer_num_blocks):
append(self.encoders, ConvNext(dim_out))
prepend(self.decoders, ConvNext(dim_out))
if layer_use_attn:
append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
append(self.decoders, nn.Conv2d(dim, channels, 1))
def get_encoded_fmap_size(self, image_size):
return image_size // (2 ** self.layers)
@property
def last_dec_layer(self):
return self.decoders[-1].weight
def encode(self, x):
for enc in self.encoders:
x = enc(x)
return x
def decode(self, x):
for dec in self.decoders:
x = dec(x)
return x
# vqgan attention layer
class VQGanAttention(nn.Module):
@@ -682,8 +576,6 @@ class VQGanVAE(nn.Module):
enc_dec_klass = ResnetEncDec
elif vae_type == 'vit':
enc_dec_klass = ViTEncDec
elif vae_type == 'convnext':
enc_dec_klass = ConvNextEncDec
else:
raise ValueError(f'{vae_type} not valid')

View File

@@ -10,11 +10,12 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.100',
version = '0.2.12',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
author_email = 'lucidrains@gmail.com',
long_description_content_type = 'text/markdown',
url = 'https://github.com/lucidrains/dalle2-pytorch',
keywords = [
'artificial intelligence',
@@ -24,19 +25,22 @@ setup(
install_requires=[
'click',
'clip-anytorch',
'coca-pytorch>=0.0.5',
'einops>=0.4',
'einops-exts>=0.0.3',
'embedding-reader',
'kornia>=0.5.4',
'pillow',
'resize-right>=0.0.2',
'rotary-embedding-torch',
'torch>=1.10',
'torchvision',
'tqdm',
'vector-quantize-pytorch',
'webdataset',
'x-clip>=0.5.1',
'youtokentome'
'x-clip>=0.4.4',
'youtokentome',
'webdataset>=0.2.5',
'fsspec>=2022.1.0'
],
classifiers=[
'Development Status :: 4 - Beta',

View File

@@ -7,6 +7,7 @@ import torch
from torch import nn
from embedding_reader import EmbeddingReader
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model, print_ribbon
from dalle2_pytorch.optimizer import get_optimizer
from torch.cuda.amp import autocast,GradScaler
@@ -41,37 +42,56 @@ def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_t
avg_loss = (total_loss / total_samples)
wandb.log({f'{phase} {loss_type}': avg_loss})
def save_model(save_path, state_dict):
# Saving State Dict
print("====================================== Saving checkpoint ======================================")
torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,NUM_TEST_EMBEDDINGS,device):
diffusion_prior.eval()
def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,val_set_size,NUM_TEST_EMBEDDINGS,device):
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
tstart = train_set_size+val_set_size
tend = train_set_size+val_set_size+NUM_TEST_EMBEDDINGS
for embt, embi in zip(text_reader(batch_size = NUM_TEST_EMBEDDINGS, start=tstart, end = tend),image_reader(batch_size = NUM_TEST_EMBEDDINGS, start=tstart, end = tend)):
text_embed = torch.tensor(embt[0]).to(device)
text_embed = text_embed / text_embed.norm(dim=1, keepdim=True)
test_text_cond = dict(text_embed = text_embed)
test_image_embeddings = torch.tensor(embi[0]).to(device)
test_image_embeddings = test_image_embeddings / test_image_embeddings.norm(dim=1, keepdim=True)
predicted_image_embeddings = diffusion_prior.p_sample_loop((NUM_TEST_EMBEDDINGS, 768), text_cond = test_text_cond)
predicted_image_embeddings = predicted_image_embeddings / predicted_image_embeddings.norm(dim=1, keepdim=True)
original_similarity = cos(text_embed,test_image_embeddings).cpu().numpy()
predicted_similarity = cos(text_embed,predicted_image_embeddings).cpu().numpy()
wandb.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity)})
wandb.log({"CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity)})
return np.mean(predicted_similarity - original_similarity)
tstart = train_set_size
tend = train_set_size+NUM_TEST_EMBEDDINGS
for embt, embi in zip(text_reader(batch_size=NUM_TEST_EMBEDDINGS, start=tstart, end=tend),
image_reader(batch_size=NUM_TEST_EMBEDDINGS, start=tstart, end=tend)):
# make a copy of the text embeddings for shuffling
text_embed = torch.tensor(embt[0]).to(device)
text_embed_shuffled = text_embed.clone()
# roll the text embeddings to simulate "unrelated" captions
rolled_idx = torch.roll(torch.arange(NUM_TEST_EMBEDDINGS), 1)
text_embed_shuffled = text_embed_shuffled[rolled_idx]
text_embed_shuffled = text_embed_shuffled / \
text_embed_shuffled.norm(dim=1, keepdim=True)
test_text_shuffled_cond = dict(text_embed=text_embed_shuffled)
# prepare the text embedding
text_embed = text_embed / text_embed.norm(dim=1, keepdim=True)
test_text_cond = dict(text_embed=text_embed)
# prepare image embeddings
test_image_embeddings = torch.tensor(embi[0]).to(device)
test_image_embeddings = test_image_embeddings / \
test_image_embeddings.norm(dim=1, keepdim=True)
# predict on the unshuffled text embeddings
predicted_image_embeddings = diffusion_prior.p_sample_loop(
(NUM_TEST_EMBEDDINGS, 768), text_cond=test_text_cond)
predicted_image_embeddings = predicted_image_embeddings / \
predicted_image_embeddings.norm(dim=1, keepdim=True)
# predict on the shuffled embeddings
predicted_unrelated_embeddings = diffusion_prior.p_sample_loop(
(NUM_TEST_EMBEDDINGS, 768), text_cond=test_text_shuffled_cond)
predicted_unrelated_embeddings = predicted_unrelated_embeddings / \
predicted_unrelated_embeddings.norm(dim=1, keepdim=True)
# calculate similarities
original_similarity = cos(
text_embed, test_image_embeddings).cpu().numpy()
predicted_similarity = cos(
text_embed, predicted_image_embeddings).cpu().numpy()
unrelated_similarity = cos(
text_embed, predicted_unrelated_embeddings).cpu().numpy()
predicted_img_similarity = cos(
test_image_embeddings, predicted_image_embeddings).cpu().numpy()
wandb.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity),
"CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity),
"CosineSimilarity(orig_image_embed,predicted_image_embed)":np.mean(predicted_img_similarity),
"CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(unrelated_similarity),
"Cosine similarity difference":np.mean(predicted_similarity - original_similarity)})
def train(image_embed_dim,
image_embed_url,
@@ -85,7 +105,6 @@ def train(image_embed_dim,
clip,
dp_condition_on_text_encodings,
dp_timesteps,
dp_l2norm_output,
dp_normformer,
dp_cond_drop_prob,
dpn_depth,
@@ -94,9 +113,15 @@ def train(image_embed_dim,
save_interval,
save_path,
device,
RESUME,
DPRIOR_PATH,
config,
wandb_entity,
wandb_project,
learning_rate=0.001,
max_grad_norm=0.5,
weight_decay=0.01,
dropout=0.05,
amp=False):
# DiffusionPriorNetwork
@@ -105,8 +130,9 @@ def train(image_embed_dim,
depth = dpn_depth,
dim_head = dpn_dim_head,
heads = dpn_heads,
normformer = dp_normformer,
l2norm_output = dp_l2norm_output).to(device)
attn_dropout = dropout,
ff_dropout = dropout,
normformer = dp_normformer).to(device)
# DiffusionPrior with text embeddings and image embeddings pre-computed
diffusion_prior = DiffusionPrior(
@@ -118,16 +144,21 @@ def train(image_embed_dim,
loss_type = dp_loss_type,
condition_on_text_encodings = dp_condition_on_text_encodings).to(device)
# Get image and text embeddings from the servers
print("==============Downloading embeddings - image and text====================")
image_reader = EmbeddingReader(embeddings_folder=image_embed_url, file_format="npy")
text_reader = EmbeddingReader(embeddings_folder=text_embed_url, file_format="npy")
num_data_points = text_reader.count
# Load pre-trained model from DPRIOR_PATH
if RESUME:
diffusion_prior=load_diffusion_model(DPRIOR_PATH,device)
wandb.init( entity=wandb_entity, project=wandb_project, config=config)
# Create save_path if it doesn't exist
if not os.path.exists(save_path):
os.makedirs(save_path)
# Get image and text embeddings from the servers
print_ribbon("Downloading embeddings - image and text")
image_reader = EmbeddingReader(embeddings_folder=image_embed_url, file_format="npy")
text_reader = EmbeddingReader(embeddings_folder=text_embed_url, file_format="npy")
num_data_points = text_reader.count
### Training code ###
scaler = GradScaler(enabled=amp)
optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
@@ -138,12 +169,15 @@ def train(image_embed_dim,
train_set_size = int(train_percent*num_data_points)
val_set_size = int(val_percent*num_data_points)
eval_start = train_set_size
for _ in range(epochs):
diffusion_prior.train()
for emb_images,emb_text in zip(image_reader(batch_size=batch_size, start=0, end=train_set_size),
text_reader(batch_size=batch_size, start=0, end=train_set_size)):
diffusion_prior.train()
emb_images_tensor = torch.tensor(emb_images[0]).to(device)
emb_text_tensor = torch.tensor(emb_text[0]).to(device)
@@ -158,9 +192,13 @@ def train(image_embed_dim,
if(int(time.time()-t) >= 60*save_interval):
t = time.time()
save_model(
save_diffusion_model(
save_path,
dict(model=diffusion_prior.state_dict(), optimizer=optimizer.state_dict(), scaler=scaler.state_dict()))
diffusion_prior,
optimizer,
scaler,
config,
image_embed_dim)
# Log to wandb
wandb.log({"Training loss": loss.item(),
@@ -170,14 +208,22 @@ def train(image_embed_dim,
# Use NUM_TEST_EMBEDDINGS samples from the test set each time
# Get embeddings from the most recently saved model
if(step % REPORT_METRICS_EVERY) == 0:
diff_cosine_sim = report_cosine_sims(diffusion_prior,
report_cosine_sims(diffusion_prior,
image_reader,
text_reader,
train_set_size,
val_set_size,
NUM_TEST_EMBEDDINGS,
device)
wandb.log({"Cosine similarity difference": diff_cosine_sim})
### Evaluate model(validation run) ###
eval_model(diffusion_prior,
device,
image_reader,
text_reader,
eval_start,
eval_start+NUM_TEST_EMBEDDINGS,
NUM_TEST_EMBEDDINGS,
dp_loss_type,
phase="Validation")
scaler.unscale_(optimizer)
nn.utils.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm)
@@ -186,11 +232,6 @@ def train(image_embed_dim,
scaler.update()
optimizer.zero_grad()
### Evaluate model(validation run) ###
start = train_set_size
end=start+val_set_size
eval_model(diffusion_prior,device,image_reader,text_reader,start,end,batch_size,dp_loss_type,phase="Validation")
### Test run ###
test_set_size = int(test_percent*train_set_size)
start=train_set_size+val_set_size
@@ -202,7 +243,6 @@ def main():
# Logging
parser.add_argument("--wandb-entity", type=str, default="laion")
parser.add_argument("--wandb-project", type=str, default="diffusion-prior")
parser.add_argument("--wandb-name", type=str, default="laion-dprior")
parser.add_argument("--wandb-dataset", type=str, default="LAION-5B")
parser.add_argument("--wandb-arch", type=str, default="DiffusionPrior")
# URLs for embeddings
@@ -211,6 +251,7 @@ def main():
# Hyperparameters
parser.add_argument("--learning-rate", type=float, default=1.1e-4)
parser.add_argument("--weight-decay", type=float, default=6.02e-2)
parser.add_argument("--dropout", type=float, default=5e-2)
parser.add_argument("--max-grad-norm", type=float, default=0.5)
parser.add_argument("--batch-size", type=int, default=10**4)
parser.add_argument("--num-epochs", type=int, default=5)
@@ -228,7 +269,6 @@ def main():
# DiffusionPrior(dp) parameters
parser.add_argument("--dp-condition-on-text-encodings", type=bool, default=False)
parser.add_argument("--dp-timesteps", type=int, default=100)
parser.add_argument("--dp-l2norm-output", type=bool, default=False)
parser.add_argument("--dp-normformer", type=bool, default=False)
parser.add_argument("--dp-cond-drop-prob", type=float, default=0.1)
parser.add_argument("--dp-loss-type", type=str, default="l2")
@@ -237,22 +277,40 @@ def main():
# Model checkpointing interval(minutes)
parser.add_argument("--save-interval", type=int, default=30)
parser.add_argument("--save-path", type=str, default="./diffusion_prior_checkpoints")
# Saved model path
parser.add_argument("--pretrained-model-path", type=str, default=None)
args = parser.parse_args()
print("Setting up wandb logging... Please wait...")
config = ({"learning_rate": args.learning_rate,
"architecture": args.wandb_arch,
"dataset": args.wandb_dataset,
"weight_decay":args.weight_decay,
"max_gradient_clipping_norm":args.max_grad_norm,
"batch_size":args.batch_size,
"epochs": args.num_epochs,
"diffusion_prior_network":{"depth":args.dpn_depth,
"dim_head":args.dpn_dim_head,
"heads":args.dpn_heads,
"normformer":args.dp_normformer},
"diffusion_prior":{"condition_on_text_encodings": args.dp_condition_on_text_encodings,
"timesteps": args.dp_timesteps,
"cond_drop_prob":args.dp_cond_drop_prob,
"loss_type":args.dp_loss_type,
"clip":args.clip}
})
wandb.init(
entity=args.wandb_entity,
project=args.wandb_project,
config={
"learning_rate": args.learning_rate,
"architecture": args.wandb_arch,
"dataset": args.wandb_dataset,
"epochs": args.num_epochs,
})
RESUME = False
# Check if DPRIOR_PATH exists(saved model path)
DPRIOR_PATH = args.pretrained_model_path
if(DPRIOR_PATH is not None):
RESUME = True
else:
wandb.init(
entity=args.wandb_entity,
project=args.wandb_project,
config=config)
print("wandb logging setup done!")
# Obtain the utilized device.
has_cuda = torch.cuda.is_available()
@@ -273,7 +331,6 @@ def main():
args.clip,
args.dp_condition_on_text_encodings,
args.dp_timesteps,
args.dp_l2norm_output,
args.dp_normformer,
args.dp_cond_drop_prob,
args.dpn_depth,
@@ -282,9 +339,15 @@ def main():
args.save_interval,
args.save_path,
device,
RESUME,
DPRIOR_PATH,
config,
args.wandb_entity,
args.wandb_project,
args.learning_rate,
args.max_grad_norm,
args.weight_decay,
args.dropout,
args.amp)
if __name__ == "__main__":