Compare commits

...

62 Commits

Author SHA1 Message Date
Phil Wang
dab106d4e5 back to no_grad for now, also keep track and restore unet devices in one_unet_in_gpu contextmanager 2022-05-16 09:36:14 -07:00
Phil Wang
bb151ca6b1 unet_number on decoder trainer only needs to be passed in if there is greater than 1 unet, so that unconditional training of a single ddpm is seamless (experiment in progress locally) 2022-05-16 09:17:17 -07:00
zion
4a59dea4cf Migrate to text-conditioned prior training (#95)
* migrate to conditioned prior

* unify reader logic with a wrapper (#1)

* separate out reader logic

* support both training methods

* Update train prior to use embedding wrapper (#3)

* Support Both Methods

* bug fixes

* small bug fixes

* embedding only wrapper bug

* use smaller val perc

* final bug fix for embedding-only

Co-authored-by: nousr <>
2022-05-15 20:16:38 -07:00
Phil Wang
ecf9e8027d make sure classifier free guidance is used only if conditional dropout is present on the DiffusionPrior and Decoder classes. also make sure prior can have a different conditional scale than decoder 2022-05-15 19:09:38 -07:00
Phil Wang
36c5079bd7 LazyLinear is not mature, make users pass in text_embed_dim if text conditioning is turned on 2022-05-15 18:56:52 -07:00
Phil Wang
4a4c7ac9e6 cond drop prob for diffusion prior network should default to 0 2022-05-15 18:47:45 -07:00
Phil Wang
fad7481479 todo 2022-05-15 17:00:25 -07:00
Phil Wang
123658d082 cite Ho et al, since cascading ddpm is now trainable 2022-05-15 16:56:53 -07:00
Phil Wang
11d4e11f10 allow for training unconditional ddpm or cascading ddpms 2022-05-15 16:54:56 -07:00
Phil Wang
99778e12de trainer classes now takes care of auto-casting numpy to torch tensors, and setting correct device based on model parameter devices 2022-05-15 15:25:45 -07:00
Phil Wang
0f0011caf0 todo 2022-05-15 14:28:35 -07:00
Phil Wang
7b7a62044a use eval vs training mode to determine whether to call backprop on trainer forward 2022-05-15 14:20:59 -07:00
Phil Wang
156fe5ed9f final cleanup for the day 2022-05-15 12:38:41 -07:00
Phil Wang
5ec34bebe1 cleanup readme 2022-05-15 12:29:26 -07:00
Phil Wang
8eaacf1ac1 remove indirection 2022-05-15 12:05:45 -07:00
Phil Wang
e66c7b0249 incorrect naming 2022-05-15 11:23:52 -07:00
Phil Wang
f7cd4a0992 product management 2022-05-15 11:21:12 -07:00
Phil Wang
68e7d2f241 make sure gradient accumulation feature works even if all arguments passed in are keyword arguments 2022-05-15 11:16:16 -07:00
Phil Wang
74f222596a remove todo 2022-05-15 11:01:35 -07:00
Phil Wang
aa6772dcff make sure optimizer and scaler is reloaded on resume for training diffusion prior script, move argparse to click 2022-05-15 10:48:10 -07:00
Phil Wang
71d0c4edae cleanup to use diffusion prior trainer 2022-05-15 10:16:05 -07:00
Phil Wang
f7eee09d8b 0.2.30 2022-05-15 09:56:59 -07:00
Phil Wang
89de5af63e experiment tracker agnostic 2022-05-15 09:56:40 -07:00
Phil Wang
4ec6d0ba81 backwards pass is not recommended under the autocast context, per pytorch docs 2022-05-14 18:26:19 -07:00
Phil Wang
aee92dba4a simplify more 2022-05-14 17:16:46 -07:00
Phil Wang
b0cd5f24b6 take care of gradient accumulation automatically for researchers, by passing in a max_batch_size on the decoder or diffusion prior trainer forward 2022-05-14 17:04:09 -07:00
Phil Wang
b494ed81d4 take care of backwards within trainer classes for diffusion prior and decoder, readying to take care of gradient accumulation as well (plus, unsure if loss should be backwards within autocast block) 2022-05-14 15:49:24 -07:00
Phil Wang
ff3474f05c normalize conditioning tokens outside of cross attention blocks 2022-05-14 14:23:52 -07:00
Phil Wang
d5293f19f1 lineup with paper 2022-05-14 13:57:00 -07:00
Phil Wang
e697183849 be able to customize adam eps 2022-05-14 13:55:04 -07:00
Phil Wang
591d37e266 lower default initial learning rate to what Jonathan Ho had in his original repo 2022-05-14 13:22:43 -07:00
Phil Wang
d1f02e8f49 always use sandwich norm for attention layer 2022-05-14 12:13:41 -07:00
Phil Wang
9faab59b23 use post-attn-branch layernorm in attempt to stabilize cross attention conditioning in decoder 2022-05-14 11:58:09 -07:00
Phil Wang
5d27029e98 make sure lowres conditioning image is properly normalized to -1 to 1 for cascading ddpm 2022-05-14 01:23:54 -07:00
Phil Wang
3115fa17b3 fix everything around normalizing images to -1 to 1 for ddpm training automatically 2022-05-14 01:17:11 -07:00
Phil Wang
124d8577c8 move the inverse normalization function called before image embeddings are derived from clip to within the diffusion prior and decoder classes 2022-05-14 00:37:52 -07:00
Phil Wang
2db0c9794c comments 2022-05-12 14:25:20 -07:00
Phil Wang
2277b47ffd make sure learned variance can work for any number of unets in the decoder, defaults to first unet, as suggested was used in the paper 2022-05-12 14:18:15 -07:00
Phil Wang
28b58e568c cleanup in preparation of option for learned variance 2022-05-12 12:04:52 -07:00
Phil Wang
924455d97d align the ema model device back after sampling from the cascading ddpm in the decoder 2022-05-11 19:56:54 -07:00
Phil Wang
6021945fc8 default to l2 loss 2022-05-11 19:24:51 -07:00
Light-V
6f76652d11 fix typo in README.md (#85)
The default config for clip from openai should be ViT-B/32
2022-05-11 13:38:16 -07:00
Phil Wang
3dda2570ed fix amp issue for https://github.com/lucidrains/DALLE2-pytorch/issues/82 2022-05-11 08:21:39 -07:00
Phil Wang
2f3c02dba8 numerical accuracy for noise schedule parameters 2022-05-10 15:28:46 -07:00
Phil Wang
908088cfea wrap up cross embed layer feature 2022-05-10 12:19:34 -07:00
Phil Wang
8dc8a3de0d product management 2022-05-10 11:51:38 -07:00
Phil Wang
35f89556ba bring in the cross embed layer from Crossformer paper for initial convolution in unet 2022-05-10 11:50:38 -07:00
Phil Wang
2b55f753b9 fix new issue with github actions and auto pypi package uploading 2022-05-10 10:51:15 -07:00
Phil Wang
fc8fce38fb make sure cascading DDPM can be trained unconditionally, to ready for CLI one command training for the public 2022-05-10 10:48:10 -07:00
Phil Wang
a1bfb03ba4 project management 2022-05-10 10:13:51 -07:00
Phil Wang
b1e7b5f6bb make sure resnet groups in unet is finely customizable 2022-05-10 10:12:50 -07:00
z
10b905b445 smol typo (#81) 2022-05-10 09:52:50 -07:00
Phil Wang
9b322ea634 patch 2022-05-09 19:46:19 -07:00
Phil Wang
ba64ea45cc 0.2.3 2022-05-09 16:50:31 -07:00
Phil Wang
64f7be1926 some cleanup 2022-05-09 16:50:21 -07:00
Phil Wang
db805e73e1 fix a bug with numerical stability in attention, sorry! 🐛 2022-05-09 16:23:37 -07:00
z
cb07b37970 Ensure Eval Mode In Metric Functions (#79)
* add eval/train toggles

* train/eval flags

* shift train toggle

Co-authored-by: nousr <z@localhost.com>
2022-05-09 16:05:40 -07:00
Phil Wang
a774bfefe2 add attention and feedforward dropouts to train_diffusion_prior script 2022-05-09 13:57:15 -07:00
Phil Wang
2ae57f0cf5 cleanup 2022-05-09 13:51:26 -07:00
Phil Wang
e46eaec817 deal the diffusion prior problem yet another blow 2022-05-09 11:08:52 -07:00
Kumar R
8647cb5e76 Val loss changes, with quite a few other changes. This is in place of the earlier PR(https://github.com/lucidrains/DALLE2-pytorch/pull/67) (#77)
* Val_loss changes - no rebased with lucidrains' master.

* Val Loss changes - now rebased with lucidrains' master

* train_diffusion_prior.py updates

* dalle2_pytorch.py updates

* __init__.py changes

* Update train_diffusion_prior.py

* Update dalle2_pytorch.py

* Update train_diffusion_prior.py

* Update train_diffusion_prior.py

* Update dalle2_pytorch.py

* Update train_diffusion_prior.py

* Update train_diffusion_prior.py

* Update train_diffusion_prior.py

* Update train_diffusion_prior.py

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md
2022-05-09 08:53:29 -07:00
Phil Wang
53c189e46a give more surface area for attention in diffusion prior 2022-05-09 08:08:11 -07:00
11 changed files with 1202 additions and 415 deletions

163
README.md
View File

@@ -508,7 +508,7 @@ To use a pretrained OpenAI CLIP, simply import `OpenAIClipAdapter` and pass it i
import torch
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, OpenAIClipAdapter
# openai pretrained clip - defaults to ViT/B-32
# openai pretrained clip - defaults to ViT-B/32
clip = OpenAIClipAdapter()
@@ -706,7 +706,7 @@ mock_image_embed = torch.randn(1, 512).cuda()
images = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)
```
## Training wrapper (wip)
## Training wrapper
### Decoder Training
@@ -732,8 +732,8 @@ clip = CLIP(
# mock data
text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
text = torch.randint(0, 49408, (32, 256)).cuda()
images = torch.randn(32, 3, 256, 256).cuda()
# decoder (with unet)
@@ -774,8 +774,12 @@ decoder_trainer = DecoderTrainer(
)
for unet_number in (1, 2):
loss = decoder_trainer(images, text = text, unet_number = unet_number) # use the decoder_trainer forward
loss.backward()
loss = decoder_trainer(
images,
text = text,
unet_number = unet_number, # which unet to train on
max_batch_size = 4 # gradient accumulation - this sets the maximum batch size in which to do forward and backwards pass - for this example 32 / 4 == 8 times
)
decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
@@ -810,8 +814,8 @@ clip = CLIP(
# mock data
text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
text = torch.randint(0, 49408, (32, 256)).cuda()
images = torch.randn(32, 3, 256, 256).cuda()
# prior networks (with transformer)
@@ -838,8 +842,7 @@ diffusion_prior_trainer = DiffusionPriorTrainer(
ema_update_every = 10,
)
loss = diffusion_prior_trainer(text, images)
loss.backward()
loss = diffusion_prior_trainer(text, images, max_batch_size = 4)
diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior
# after much of the above three lines in a loop
@@ -848,6 +851,57 @@ diffusion_prior_trainer.update() # this will update the optimizer as well as th
image_embeds = diffusion_prior_trainer.sample(text) # (4, 512) - exponential moving averaged image embeddings
```
## Bonus
### Unconditional Training
The repository also contains the means to train unconditional DDPM model, or even cascading DDPMs. You simply have to set `unconditional = True` in the `Decoder`
ex.
```python
import torch
from dalle2_pytorch import Unet, Decoder
# unet for the cascading ddpm
unet1 = Unet(
dim = 128,
dim_mults=(1, 2, 4, 8)
).cuda()
unet2 = Unet(
dim = 32,
dim_mults = (1, 2, 4, 8, 16)
).cuda()
# decoder, which contains the unets
decoder = Decoder(
unet = (unet1, unet2),
image_sizes = (256, 512), # first unet up to 256px, then second to 512px
timesteps = 1000,
unconditional = True
).cuda()
# mock images (get a lot of this)
images = torch.randn(1, 3, 512, 512).cuda()
# feed images into decoder
for i in (1, 2):
loss = decoder(images, unet_number = i)
loss.backward()
# do the above for many many many many steps
# then it will learn to generate images
images = decoder.sample(batch_size = 2) # (2, 3, 512, 512)
```
## Dataloaders
### Decoder Dataloaders
In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.
@@ -892,14 +946,14 @@ dataset = ImageEmbeddingDataset(
)
```
## Scripts
### Scripts (wip)
### Using the `train_diffusion_prior.py` script
#### `train_diffusion_prior.py`
This script allows training the DiffusionPrior on pre-computed text and image embeddings. The working example below elucidates this process.
Please note that the script internally passes text_embed and image_embed to the DiffusionPrior, unlike the example below.
### Usage
#### Usage
```bash
$ python train_diffusion_prior.py
@@ -907,27 +961,50 @@ $ python train_diffusion_prior.py
The most significant parameters for the script are as follows:
--image-embed-url, default = "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
- `image-embed-url`, default = `"https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/"`
--text-embed-url, default = "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
- `text-embed-url`, default = `"https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/"`
--image-embed-dim, default=768 - 768 corresponds to the ViT iL/14 embedding size,change it to what your chosen ViT generates
- `image-embed-dim`, default = `768` - 768 corresponds to the ViT iL/14 embedding size,change it to what your chosen ViT generates
--learning-rate, default=1.1e-4
- `learning-rate`, default = `1.1e-4`
--weight-decay, default=6.02e-2
- `weight-decay`, default = `6.02e-2`
--max-grad-norm, default=0.5
- `max-grad-norm`, default = `0.5`
--batch-size, default=10 ** 4
- `batch-size`, default = `10 ** 4`
--num-epochs, default=5
- `num-epochs`, default = `5`
--clip, default=None # Signals the prior to use pre-computed embeddings
- `clip`, default = `None` # Signals the prior to use pre-computed embeddings
### Sample wandb run log
#### Loading and Saving the DiffusionPrior model
Please find a sample wandb run log at : https://wandb.ai/laion/diffusion-prior/runs/aul0rhv5?workspace=
Two methods are provided, load_diffusion_model and save_diffusion_model, the names being self-explanatory.
```python
from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model
```
##### Loading
load_diffusion_model(dprior_path, device)
dprior_path : path to saved model(.pth)
device : the cuda device you're running on
##### Saving
save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim)
save_path : path to save at
model : object of Diffusion_Prior
optimizer : optimizer object - see train_diffusion_prior.py for how to create one.
e.g: optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
scaler : a GradScaler object.
e.g: scaler = GradScaler(enabled=amp)
config : config object created in train_diffusion_prior.py - see file for example.
image_embed_dim - the dimension of the image_embedding
e.g: 768
## CLI (wip)
@@ -966,23 +1043,29 @@ Once built, images will be saved to the same directory the command is invoked
- [x] add convnext backbone for vqgan-vae (in addition to vit [vit-vqgan] + resnet)
- [x] make sure DDPMs can be run with traditional resnet blocks (but leave convnext as an option for experimentation)
- [x] make sure for the latter unets in the cascade, one can train on crops for learning super resolution (constrain the unet to be only convolutions in that case, or allow conv-like attention with rel pos bias)
- [x] offer setting in diffusion prior to split time and image embeddings into multiple tokens, configurable, for more surface area during attention
- [x] make sure resnet hyperparameters can be configurable across unet depth (groups and expansion factor)
- [x] pull logic for training diffusion prior into a class DiffusionPriorTrainer, for eventual script based + CLI based training
- [x] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images
- [x] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
- [x] cross embed layers for downsampling, as an option
- [x] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
- [ ] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
- [ ] pull logic for training diffusion prior into a class DiffusionPriorTrainer, for eventual script based + CLI based training
- [ ] train on a toy task, offer in colab
- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
- [ ] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
- [ ] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
- [ ] make sure resnet hyperparameters can be configurable across unet depth (groups and expansion factor)
- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
- [ ] offer setting in diffusion prior to split time and image embeddings into multiple tokens, configurable, for more surface area during attention
- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
- [ ] decoder needs one day worth of refactor for tech debt
- [ ] allow for unet to be able to condition non-cross attention style as well
- [ ] for all model classes with hyperparameters that changes the network architecture, make it requirement that they must expose a config property, and write a simple function that asserts that it restores the object correctly
- [ ] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
- [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89
## Citations
@@ -1060,4 +1143,24 @@ Once built, images will be saved to the same directory the command is invoked
}
```
```bibtex
@misc{wang2021crossformer,
title = {CrossFormer: A Versatile Vision Transformer Hinging on Cross-scale Attention},
author = {Wenxiao Wang and Lu Yao and Long Chen and Binbin Lin and Deng Cai and Xiaofei He and Wei Liu},
year = {2021},
eprint = {2108.00154},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@article{ho2021cascaded,
title = {Cascaded Diffusion Models for High Fidelity Image Generation},
author = {Ho, Jonathan and Saharia, Chitwan and Chan, William and Fleet, David J and Norouzi, Mohammad and Salimans, Tim},
journal = {arXiv preprint arXiv:2106.15282},
year = {2021}
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>

View File

@@ -1,6 +1,6 @@
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
from dalle2_pytorch.train import DecoderTrainer, DiffusionPriorTrainer
from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
from dalle2_pytorch.vqgan_vae import VQGanVAE
from x_clip import CLIP

View File

@@ -1,9 +1,10 @@
import math
from tqdm import tqdm
from inspect import isfunction
from functools import partial
from functools import partial, wraps
from contextlib import contextmanager
from collections import namedtuple
from pathlib import Path
import torch
import torch.nn.functional as F
@@ -32,6 +33,10 @@ from rotary_embedding_torch import RotaryEmbedding
from x_clip import CLIP
from coca_pytorch import CoCa
# constants
NAT = 1. / math.log(2.)
# helper functions
def exists(val):
@@ -40,6 +45,14 @@ def exists(val):
def identity(t, *args, **kwargs):
return t
def maybe(fn):
@wraps(fn)
def inner(x):
if not exists(x):
return x
return fn(x)
return inner
def default(val, d):
if exists(val):
return val
@@ -90,6 +103,9 @@ def freeze_model_and_make_eval_(model):
# tensor helpers
def log(t, eps = 1e-12):
return torch.log(t.clamp(min = eps))
def l2norm(t):
return F.normalize(t, dim = -1)
@@ -106,10 +122,10 @@ def resize_image_to(image, target_image_size):
# ddpms expect images to be in the range of -1 to 1
# but CLIP may otherwise
def normalize_img(img):
def normalize_neg_one_to_one(img):
return img * 2 - 1
def unnormalize_img(normed_img):
def unnormalize_zero_to_one(normed_img):
return (normed_img + 1) * 0.5
# clip related adapters
@@ -270,7 +286,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
def embed_image(self, image):
assert not self.cleared
image = resize_image_to(image, self.image_size)
image = self.clip_normalize(unnormalize_img(image))
image = self.clip_normalize(image)
image_embed = self.clip.encode_image(image)
return EmbeddedImage(l2norm(image_embed.float()), None)
@@ -296,13 +312,43 @@ def noise_like(shape, device, repeat=False):
noise = lambda: torch.randn(shape, device=device)
return repeat_noise() if repeat else noise()
def meanflat(x):
return x.mean(dim = tuple(range(1, len(x.shape))))
def normal_kl(mean1, logvar1, mean2, logvar2):
return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2))
def approx_standard_normal_cdf(x):
return 0.5 * (1.0 + torch.tanh(((2.0 / math.pi) ** 0.5) * (x + 0.044715 * (x ** 3))))
def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
assert x.shape == means.shape == log_scales.shape
centered_x = x - means
inv_stdv = torch.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 1. / 255.)
cdf_plus = approx_standard_normal_cdf(plus_in)
min_in = inv_stdv * (centered_x - 1. / 255.)
cdf_min = approx_standard_normal_cdf(min_in)
log_cdf_plus = log(cdf_plus)
log_one_minus_cdf_min = log(1. - cdf_min)
cdf_delta = cdf_plus - cdf_min
log_probs = torch.where(x < -thres,
log_cdf_plus,
torch.where(x > thres,
log_one_minus_cdf_min,
log(cdf_delta)))
return log_probs
def cosine_beta_schedule(timesteps, s = 0.008):
"""
cosine schedule
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
"""
steps = timesteps + 1
x = torch.linspace(0, timesteps, steps)
x = torch.linspace(0, timesteps, steps, dtype = torch.float64)
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
@@ -313,21 +359,21 @@ def linear_beta_schedule(timesteps):
scale = 1000 / timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
return torch.linspace(beta_start, beta_end, timesteps)
return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)
def quadratic_beta_schedule(timesteps):
scale = 1000 / timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
return torch.linspace(beta_start**2, beta_end**2, timesteps) ** 2
return torch.linspace(beta_start**2, beta_end**2, timesteps, dtype = torch.float64) ** 2
def sigmoid_beta_schedule(timesteps):
scale = 1000 / timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
betas = torch.linspace(-6, 6, timesteps)
betas = torch.linspace(-6, 6, timesteps, dtype = torch.float64)
return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
@@ -367,17 +413,21 @@ class BaseGaussianDiffusion(nn.Module):
self.loss_type = loss_type
self.loss_fn = loss_fn
self.register_buffer('betas', betas)
self.register_buffer('alphas_cumprod', alphas_cumprod)
self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
# register buffer helper function to cast double back to float
register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
register_buffer('betas', betas)
register_buffer('alphas_cumprod', alphas_cumprod)
register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
# calculations for posterior q(x_{t-1} | x_t, x_0)
@@ -385,19 +435,13 @@ class BaseGaussianDiffusion(nn.Module):
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.register_buffer('posterior_variance', posterior_variance)
register_buffer('posterior_variance', posterior_variance)
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
def q_mean_variance(self, x_start, t):
mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
return mean, variance, log_variance
register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
def q_posterior(self, x_start, x_t, t):
posterior_mean = (
@@ -570,7 +614,6 @@ class Attention(nn.Module):
heads = 8,
dropout = 0.,
causal = False,
post_norm = False,
rotary_emb = None
):
super().__init__()
@@ -580,7 +623,6 @@ class Attention(nn.Module):
self.causal = causal
self.norm = LayerNorm(dim)
self.post_norm = LayerNorm(dim) # sandwich norm from Coqview paper + Normformer
self.dropout = nn.Dropout(dropout)
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
@@ -591,7 +633,7 @@ class Attention(nn.Module):
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim, bias = False),
LayerNorm(dim) if post_norm else nn.Identity()
LayerNorm(dim)
)
def forward(self, x, mask = None, attn_bias = None):
@@ -639,7 +681,7 @@ class Attention(nn.Module):
# attention
sim = sim - sim.amax(dim = -1, keepdim = True)
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
attn = sim.softmax(dim = -1)
attn = self.dropout(attn)
@@ -648,8 +690,7 @@ class Attention(nn.Module):
out = einsum('b h i j, b j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
out = self.to_out(out)
return self.post_norm(out)
return self.to_out(out)
class CausalTransformer(nn.Module):
def __init__(
@@ -675,7 +716,7 @@ class CausalTransformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer, rotary_emb = rotary_emb),
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb),
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
]))
@@ -703,10 +744,31 @@ class DiffusionPriorNetwork(nn.Module):
self,
dim,
num_timesteps = None,
num_time_embeds = 1,
num_image_embeds = 1,
num_text_embeds = 1,
**kwargs
):
super().__init__()
self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
self.num_time_embeds = num_time_embeds
self.num_image_embeds = num_image_embeds
self.num_text_embeds = num_text_embeds
self.to_text_embeds = nn.Sequential(
nn.Linear(dim, dim * num_text_embeds) if num_text_embeds > 1 else nn.Identity(),
Rearrange('b (n d) -> b n d', n = num_text_embeds)
)
self.to_time_embeds = nn.Sequential(
nn.Embedding(num_timesteps, dim * num_time_embeds) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim * num_time_embeds)), # also offer a continuous version of timestep embeddings, with a 2 layer MLP
Rearrange('b (n d) -> b n d', n = num_time_embeds)
)
self.to_image_embeds = nn.Sequential(
nn.Linear(dim, dim * num_image_embeds) if num_image_embeds > 1 else nn.Identity(),
Rearrange('b (n d) -> b n d', n = num_image_embeds)
)
self.learned_query = nn.Parameter(torch.randn(dim))
self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
@@ -732,14 +794,17 @@ class DiffusionPriorNetwork(nn.Module):
text_embed,
text_encodings = None,
mask = None,
cond_drop_prob = 0.2
cond_drop_prob = 0.
):
batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
num_time_embeds, num_image_embeds, num_text_embeds = self.num_time_embeds, self.num_image_embeds, self.num_text_embeds
# in section 2.2, last paragraph
# "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
text_embed, image_embed = rearrange_many((text_embed, image_embed), 'b d -> b 1 d')
text_embed = self.to_text_embeds(text_embed)
image_embed = self.to_image_embeds(image_embed)
# make text encodings optional
# although the paper seems to suggest it is present <--
@@ -759,16 +824,17 @@ class DiffusionPriorNetwork(nn.Module):
# whether text embedding is masked or not depends on the classifier free guidance conditional masking
keep_mask = repeat(keep_mask, 'b 1 -> b n', n = num_text_embeds)
mask = torch.cat((mask, keep_mask), dim = 1)
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
# but let's just do it right
if exists(mask):
mask = F.pad(mask, (0, 3), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
attend_padding = 1 + num_time_embeds + num_image_embeds # 1 for learned queries + number of image embeds + time embeds
mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
time_embed = self.time_embeddings(diffusion_timesteps)
time_embed = rearrange(time_embed, 'b d -> b 1 d')
time_embed = self.to_time_embeds(diffusion_timesteps)
learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
@@ -801,7 +867,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
image_channels = 3,
timesteps = 1000,
cond_drop_prob = 0.,
loss_type = "l1",
loss_type = "l2",
predict_x_start = True,
beta_schedule = "cosine",
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
@@ -835,6 +901,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
self.channels = default(image_channels, lambda: clip.image_channels)
self.cond_drop_prob = cond_drop_prob
self.can_classifier_guidance = cond_drop_prob > 0.
self.condition_on_text_encodings = condition_on_text_encodings
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
@@ -848,8 +915,10 @@ class DiffusionPrior(BaseGaussianDiffusion):
self.training_clamp_l2norm = training_clamp_l2norm
self.init_image_embed_l2norm = init_image_embed_l2norm
def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
pred = self.net(x, t, **text_cond)
def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.):
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, **text_cond)
if self.predict_x_start:
x_recon = pred
@@ -867,17 +936,17 @@ class DiffusionPrior(BaseGaussianDiffusion):
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
return model_mean, posterior_variance, posterior_log_variance
@torch.inference_mode()
def p_sample(self, x, t, text_cond = None, clip_denoised = True, repeat_noise = False):
@torch.no_grad()
def p_sample(self, x, t, text_cond = None, clip_denoised = True, repeat_noise = False, cond_scale = 1.):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised)
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.inference_mode()
def p_sample_loop(self, shape, text_cond):
@torch.no_grad()
def p_sample_loop(self, shape, text_cond, cond_scale = 1.):
device = self.betas.device
b = shape[0]
@@ -888,7 +957,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
times = torch.full((b,), i, device = device, dtype = torch.long)
image_embed = self.p_sample(image_embed, times, text_cond = text_cond)
image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)
return image_embed
@@ -912,21 +981,21 @@ class DiffusionPrior(BaseGaussianDiffusion):
loss = self.loss_fn(pred, target)
return loss
@torch.inference_mode()
@torch.no_grad()
@eval_decorator
def sample_batch_size(self, batch_size, text_cond):
def sample_batch_size(self, batch_size, text_cond, cond_scale = 1.):
device = self.betas.device
shape = (batch_size, self.image_embed_dim)
img = torch.randn(shape, device = device)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
img = self.p_sample(img, torch.full((batch_size,), i, device = device, dtype = torch.long), text_cond = text_cond)
img = self.p_sample(img, torch.full((batch_size,), i, device = device, dtype = torch.long), text_cond = text_cond, cond_scale = cond_scale)
return img
@torch.inference_mode()
@torch.no_grad()
@eval_decorator
def sample(self, text, num_samples_per_batch = 2):
def sample(self, text, num_samples_per_batch = 2, cond_scale = 1.):
# in the paper, what they did was
# sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP
text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch)
@@ -941,7 +1010,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
if self.condition_on_text_encodings:
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond, cond_scale = cond_scale)
# retrieve original unscaled image embed
@@ -1097,6 +1166,7 @@ class CrossAttention(nn.Module):
dim_head = 64,
heads = 8,
dropout = 0.,
norm_context = False
):
super().__init__()
self.scale = dim_head ** -0.5
@@ -1106,13 +1176,17 @@ class CrossAttention(nn.Module):
context_dim = default(context_dim, dim)
self.norm = LayerNorm(dim)
self.norm_context = LayerNorm(context_dim)
self.norm_context = LayerNorm(context_dim) if norm_context else nn.Identity()
self.dropout = nn.Dropout(dropout)
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim, bias = False),
LayerNorm(dim)
)
def forward(self, x, context, mask = None):
b, n, device = *x.shape[:2], x.device
@@ -1141,7 +1215,7 @@ class CrossAttention(nn.Module):
mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, max_neg_value)
sim = sim - sim.amax(dim = -1, keepdim = True)
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
attn = sim.softmax(dim = -1)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
@@ -1202,12 +1276,39 @@ class LinearAttention(nn.Module):
out = self.nonlin(out)
return self.to_out(out)
class CrossEmbedLayer(nn.Module):
def __init__(
self,
dim_in,
kernel_sizes,
dim_out = None,
stride = 2
):
super().__init__()
assert all([*map(lambda t: (t % 2) == (stride % 2), kernel_sizes)])
dim_out = default(dim_out, dim_in)
kernel_sizes = sorted(kernel_sizes)
num_scales = len(kernel_sizes)
# calculate the dimension at each scale
dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
dim_scales = [*dim_scales, dim_out - sum(dim_scales)]
self.convs = nn.ModuleList([])
for kernel, dim_scale in zip(kernel_sizes, dim_scales):
self.convs.append(nn.Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2))
def forward(self, x):
fmaps = tuple(map(lambda conv: conv(x), self.convs))
return torch.cat(fmaps, dim = 1)
class Unet(nn.Module):
def __init__(
self,
dim,
*,
image_embed_dim,
image_embed_dim = None,
text_embed_dim = None,
cond_dim = None,
num_image_tokens = 4,
@@ -1215,6 +1316,7 @@ class Unet(nn.Module):
out_dim = None,
dim_mults=(1, 2, 4, 8),
channels = 3,
channels_out = None,
attn_dim_head = 32,
attn_heads = 16,
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
@@ -1225,8 +1327,10 @@ class Unet(nn.Module):
cond_on_image_embeds = False,
init_dim = None,
init_conv_kernel_size = 7,
block_type = 'resnet',
block_resnet_groups = 8,
resnet_groups = 8,
init_cross_embed_kernel_sizes = (3, 7, 15),
cross_embed_downsample = False,
cross_embed_downsample_kernel_sizes = (2, 4),
**kwargs
):
super().__init__()
@@ -1243,12 +1347,12 @@ class Unet(nn.Module):
# determine dimensions
self.channels = channels
self.channels_out = default(channels_out, channels)
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
init_dim = default(init_dim, dim // 2)
init_dim = default(init_dim, dim // 3 * 2)
assert (init_conv_kernel_size % 2) == 1
self.init_conv = nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1)
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
@@ -1276,14 +1380,18 @@ class Unet(nn.Module):
self.image_to_cond = nn.Sequential(
nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
Rearrange('b (n d) -> b n d', n = num_image_tokens)
) if image_embed_dim != cond_dim else nn.Identity()
) if cond_on_image_embeds and image_embed_dim != cond_dim else nn.Identity()
self.norm_cond = nn.LayerNorm(cond_dim)
self.norm_mid_cond = nn.LayerNorm(cond_dim)
# text encoding conditioning (optional)
self.text_to_cond = None
if cond_on_text_encodings:
self.text_to_cond = nn.LazyLinear(cond_dim) if not exists(text_embed_dim) else nn.Linear(text_embed_dim, cond_dim)
assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text_encodings is True'
self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)
# finer control over whether to condition on image embeddings and text encodings
# so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
@@ -1304,7 +1412,15 @@ class Unet(nn.Module):
# resnet block klass
block_klass = partial(ResnetBlock, groups = block_resnet_groups)
resnet_groups = cast_tuple(resnet_groups, len(in_out))
assert len(resnet_groups) == len(in_out)
# downsample klass
downsample_klass = Downsample
if cross_embed_downsample:
downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes)
# layers
@@ -1312,39 +1428,38 @@ class Unet(nn.Module):
self.ups = nn.ModuleList([])
num_resolutions = len(in_out)
for ind, (dim_in, dim_out) in enumerate(in_out):
for ind, ((dim_in, dim_out), groups) in enumerate(zip(in_out, resnet_groups)):
is_first = ind == 0
is_last = ind >= (num_resolutions - 1)
layer_cond_dim = cond_dim if not is_first else None
self.downs.append(nn.ModuleList([
block_klass(dim_in, dim_out, time_cond_dim = time_cond_dim),
ResnetBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
block_klass(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
Downsample(dim_out) if not is_last else nn.Identity()
ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
downsample_klass(dim_out) if not is_last else nn.Identity()
]))
mid_dim = dims[-1]
self.mid_block1 = block_klass(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
self.mid_block2 = block_klass(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
for ind, ((dim_in, dim_out), groups) in enumerate(zip(reversed(in_out[1:]), reversed(resnet_groups))):
is_last = ind >= (num_resolutions - 2)
layer_cond_dim = cond_dim if not is_last else None
self.ups.append(nn.ModuleList([
block_klass(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
ResnetBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
block_klass(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
Upsample(dim_in)
]))
out_dim = default(out_dim, channels)
self.final_conv = nn.Sequential(
block_klass(dim, dim),
nn.Conv2d(dim, out_dim, 1)
ResnetBlock(dim, dim, groups = resnet_groups[0]),
nn.Conv2d(dim, self.channels_out, 1)
)
# if the current settings for the unet are not correct
@@ -1354,12 +1469,25 @@ class Unet(nn.Module):
*,
lowres_cond,
channels,
cond_on_image_embeds
channels_out,
cond_on_image_embeds,
cond_on_text_encodings
):
if lowres_cond == self.lowres_cond and channels == self.channels and cond_on_image_embeds == self.cond_on_image_embeds:
if lowres_cond == self.lowres_cond and \
channels == self.channels and \
cond_on_image_embeds == self.cond_on_image_embeds and \
cond_on_text_encodings == self.cond_on_text_encodings and \
channels_out == self.channels_out:
return self
updated_kwargs = {'lowres_cond': lowres_cond, 'channels': channels, 'cond_on_image_embeds': cond_on_image_embeds}
updated_kwargs = dict(
lowres_cond = lowres_cond,
channels = channels,
channels_out = channels_out,
cond_on_image_embeds = cond_on_image_embeds,
cond_on_text_encodings = cond_on_text_encodings
)
return self.__class__(**{**self._locals, **updated_kwargs})
def forward_with_cond_scale(
@@ -1424,11 +1552,12 @@ class Unet(nn.Module):
if self.cond_on_image_embeds:
image_tokens = self.image_to_cond(image_embed)
null_image_embed = self.null_image_embed.to(image_tokens.dtype) # for some reason pytorch AMP not working
image_tokens = torch.where(
image_keep_mask,
image_tokens,
self.null_image_embed
null_image_embed
)
# take care of text encodings (optional)
@@ -1452,10 +1581,12 @@ class Unet(nn.Module):
text_mask = rearrange(text_mask, 'b n -> b n 1')
text_keep_mask = text_mask & text_keep_mask
null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working
text_tokens = torch.where(
text_keep_mask,
text_tokens,
self.null_text_embed
null_text_embed
)
# main conditioning tokens (c)
@@ -1470,6 +1601,11 @@ class Unet(nn.Module):
mid_c = c if not exists(text_tokens) else torch.cat((c, text_tokens), dim = -2)
# normalize conditioning tokens
c = self.norm_cond(c)
mid_c = self.norm_mid_cond(mid_c)
# go through the layers of the unet, down and up
hiddens = []
@@ -1543,7 +1679,7 @@ class Decoder(BaseGaussianDiffusion):
timesteps = 1000,
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5,
loss_type = 'l1',
loss_type = 'l2',
beta_schedule = 'cosine',
predict_x_start = False,
predict_x_start_for_latent_diffusion = False,
@@ -1555,7 +1691,10 @@ class Decoder(BaseGaussianDiffusion):
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
clip_denoised = True,
clip_x_start = True,
clip_adapter_overrides = dict()
clip_adapter_overrides = dict(),
learned_variance = True,
vb_loss_weight = 0.001,
unconditional = False
):
super().__init__(
beta_schedule = beta_schedule,
@@ -1563,7 +1702,10 @@ class Decoder(BaseGaussianDiffusion):
loss_type = loss_type
)
assert exists(clip) ^ exists(image_size), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
self.unconditional = unconditional
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
assert self.unconditional or (exists(clip) ^ exists(image_size)), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
self.clip = None
if exists(clip):
@@ -1590,10 +1732,18 @@ class Decoder(BaseGaussianDiffusion):
unets = cast_tuple(unet)
vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))
# whether to use learned variance, defaults to True for the first unet in the cascade, as in paper
learned_variance = pad_tuple_to_length(cast_tuple(learned_variance), len(unets), fillvalue = False)
self.learned_variance = learned_variance
self.vb_loss_weight = vb_loss_weight
# construct unets and vaes
self.unets = nn.ModuleList([])
self.vaes = nn.ModuleList([])
for ind, (one_unet, one_vae) in enumerate(zip(unets, vaes)):
for ind, (one_unet, one_vae, one_unet_learned_var) in enumerate(zip(unets, vaes, learned_variance)):
assert isinstance(one_unet, Unet)
assert isinstance(one_vae, (VQGanVAE, NullVQGanVAE))
@@ -1601,11 +1751,14 @@ class Decoder(BaseGaussianDiffusion):
latent_dim = one_vae.encoded_dim if exists(one_vae) else None
unet_channels = default(latent_dim, self.channels)
unet_channels_out = unet_channels * (1 if not one_unet_learned_var else 2)
one_unet = one_unet.cast_model_parameters(
lowres_cond = not is_first,
cond_on_image_embeds = is_first,
channels = unet_channels
cond_on_image_embeds = is_first and not unconditional,
cond_on_text_encodings = one_unet.cond_on_text_encodings and not unconditional,
channels = unet_channels,
channels_out = unet_channels_out
)
self.unets.append(one_unet)
@@ -1643,6 +1796,7 @@ class Decoder(BaseGaussianDiffusion):
self.image_cond_drop_prob = image_cond_drop_prob
self.text_cond_drop_prob = text_cond_drop_prob
self.can_classifier_guidance = image_cond_drop_prob > 0. or text_cond_drop_prob > 0.
# whether to clip when sampling
@@ -1662,14 +1816,23 @@ class Decoder(BaseGaussianDiffusion):
unet = self.get_unet(unet_number)
self.cuda()
devices = [next(unet.parameters()).device for unet in self.unets]
self.unets.cpu()
unet.cuda()
yield
unet.cpu()
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
yield
for unet, device in zip(self.unets, devices):
unet.to(device)
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img))
if learned_variance:
pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1)
if predict_x_start:
x_recon = pred
@@ -1680,24 +1843,38 @@ class Decoder(BaseGaussianDiffusion):
x_recon.clamp_(-1., 1.)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
if learned_variance:
# if learned variance, posterio variance and posterior log variance are predicted by the network
# by an interpolation of the max and min log beta values
# eq 15 - https://arxiv.org/abs/2102.09672
min_log = extract(self.posterior_log_variance_clipped, t, x.shape)
max_log = extract(torch.log(self.betas), t, x.shape)
var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)
posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
posterior_variance = posterior_log_variance.exp()
return model_mean, posterior_variance, posterior_log_variance
@torch.inference_mode()
def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, clip_denoised = True, repeat_noise = False):
@torch.no_grad()
def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True, repeat_noise = False):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start)
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, learned_variance = learned_variance)
noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.inference_mode()
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
@torch.no_grad()
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
device = self.betas.device
b = shape[0]
img = torch.randn(shape, device = device)
lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
img = self.p_sample(
unet,
@@ -1709,17 +1886,26 @@ class Decoder(BaseGaussianDiffusion):
cond_scale = cond_scale,
lowres_cond_img = lowres_cond_img,
predict_x_start = predict_x_start,
learned_variance = learned_variance,
clip_denoised = clip_denoised
)
return img
unnormalize_img = unnormalize_zero_to_one(img)
return unnormalize_img
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None):
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False):
noise = default(noise, lambda: torch.randn_like(x_start))
# normalize to [-1, 1]
x_start = normalize_neg_one_to_one(x_start)
lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
# get x_t
x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)
pred = unet(
model_output = unet(
x_noisy,
times,
image_embed = image_embed,
@@ -1730,21 +1916,63 @@ class Decoder(BaseGaussianDiffusion):
text_cond_drop_prob = self.text_cond_drop_prob,
)
if learned_variance:
pred, _ = model_output.chunk(2, dim = 1)
else:
pred = model_output
target = noise if not predict_x_start else x_start
loss = self.loss_fn(pred, target)
return loss
@torch.inference_mode()
if not learned_variance:
# return simple loss if not using learned variance
return loss
# most of the code below is transcribed from
# https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/diffusion_utils_2.py
# the Improved DDPM paper then further modified it so that the mean is detached (shown a couple lines before), and weighted to be smaller than the l1 or l2 "simple" loss
# it is questionable whether this is really needed, looking at some of the figures in the paper, but may as well stay faithful to their implementation
# if learning the variance, also include the extra weight kl loss
true_mean, _, true_log_variance_clipped = self.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
# kl loss with detached model predicted mean, for stability reasons as in paper
detached_model_mean = model_mean.detach()
kl = normal_kl(true_mean, true_log_variance_clipped, detached_model_mean, model_log_variance)
kl = meanflat(kl) * NAT
decoder_nll = -discretized_gaussian_log_likelihood(x_start, means = detached_model_mean, log_scales = 0.5 * model_log_variance)
decoder_nll = meanflat(decoder_nll) * NAT
# at the first timestep return the decoder NLL, otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
vb_losses = torch.where(times == 0, decoder_nll, kl)
# weight the vb loss smaller, for stability, as in the paper (recommended 0.001)
vb_loss = vb_losses.mean() * self.vb_loss_weight
return loss + vb_loss
@torch.no_grad()
@eval_decorator
def sample(
self,
image_embed,
image_embed = None,
text = None,
batch_size = 1,
cond_scale = 1.,
stop_at_unet_number = None
):
batch_size = image_embed.shape[0]
assert self.unconditional or exists(image_embed), 'image embed must be present on sampling from decoder unless if trained unconditionally'
if not self.unconditional:
batch_size = image_embed.shape[0]
text_encodings = text_mask = None
if exists(text):
@@ -1754,10 +1982,11 @@ class Decoder(BaseGaussianDiffusion):
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
img = None
is_cuda = next(self.parameters()).is_cuda
for unet_number, unet, vae, channel, image_size, predict_x_start in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance)):
context = self.one_unet_in_gpu(unet = unet) if image_embed.is_cuda else null_context()
context = self.one_unet_in_gpu(unet = unet) if is_cuda else null_context()
with context:
lowres_cond_img = None
@@ -1770,8 +1999,7 @@ class Decoder(BaseGaussianDiffusion):
image_size = vae.get_encoded_fmap_size(image_size)
shape = (batch_size, vae.encoded_dim, image_size, image_size)
if exists(lowres_cond_img):
lowres_cond_img = vae.encode(lowres_cond_img)
lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
img = self.p_sample_loop(
unet,
@@ -1781,6 +2009,7 @@ class Decoder(BaseGaussianDiffusion):
text_mask = text_mask,
cond_scale = cond_scale,
predict_x_start = predict_x_start,
learned_variance = learned_variance,
clip_denoised = not is_latent_diffusion,
lowres_cond_img = lowres_cond_img
)
@@ -1810,6 +2039,7 @@ class Decoder(BaseGaussianDiffusion):
target_image_size = self.image_sizes[unet_index]
predict_x_start = self.predict_x_start[unet_index]
random_crop_size = self.random_crop_sizes[unet_index]
learned_variance = self.learned_variance[unet_index]
b, c, h, w, device, = *image.shape, image.device
check_shape(image, 'b c h w', c = self.channels)
@@ -1817,12 +2047,12 @@ class Decoder(BaseGaussianDiffusion):
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
if not exists(image_embed):
if not exists(image_embed) and not self.unconditional:
assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
image_embed, _ = self.clip.embed_image(image)
text_encodings = text_mask = None
if exists(text) and not exists(text_encodings):
if exists(text) and not exists(text_encodings) and not self.unconditional:
assert exists(self.clip), 'if you are passing in raw text, you need to supply `clip` to the decoder'
_, text_encodings, text_mask = self.clip.embed_text(text)
@@ -1843,11 +2073,9 @@ class Decoder(BaseGaussianDiffusion):
vae.eval()
with torch.no_grad():
image = vae.encode(image)
lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
if exists(lowres_cond_img):
lowres_cond_img = vae.encode(lowres_cond_img)
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start)
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance)
# main class
@@ -1870,12 +2098,13 @@ class DALLE2(nn.Module):
self.to_pil = T.ToPILImage()
@torch.inference_mode()
@torch.no_grad()
@eval_decorator
def forward(
self,
text,
cond_scale = 1.,
prior_cond_scale = 1.,
return_pil_images = False
):
device = next(self.parameters()).device
@@ -1885,7 +2114,7 @@ class DALLE2(nn.Module):
text = [text] if not isinstance(text, (list, tuple)) else text
text = tokenizer.tokenize(text).to(device)
image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples, cond_scale = prior_cond_scale)
text_cond = text if self.decoder_need_text_cond else None
images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale)

View File

@@ -1 +1,2 @@
from dalle2_pytorch.dataloaders.decoder_loader import ImageEmbeddingDataset, create_image_embedding_dataloader
from dalle2_pytorch.dataloaders.decoder_loader import ImageEmbeddingDataset, create_image_embedding_dataloader
from dalle2_pytorch.dataloaders.embedding_wrapper import make_splits

View File

@@ -0,0 +1,180 @@
from torch.utils.data import IterableDataset
from torch import from_numpy
from clip import tokenize
from embedding_reader import EmbeddingReader
class PriorEmbeddingLoader(IterableDataset):
def __init__(
self,
text_conditioned: bool,
batch_size: int,
start: int,
stop: int,
image_reader,
text_reader: EmbeddingReader = None,
device: str = "cpu",
) -> None:
super(PriorEmbeddingLoader).__init__()
self.text_conditioned = text_conditioned
if not self.text_conditioned:
self.text_reader = text_reader
self.image_reader = image_reader
self.batch_size = batch_size
self.start = start
self.stop = stop
self.device = device
def __iter__(self):
self.n = 0
loader_args = dict(
batch_size=self.batch_size,
start=self.start,
end=self.stop,
show_progress=False,
)
if self.text_conditioned:
self.loader = self.image_reader(**loader_args)
else:
self.loader = zip(
self.image_reader(**loader_args), self.text_reader(**loader_args)
)
return self
def __next__(self):
try:
return self.get_sample()
except StopIteration:
raise StopIteration
def get_sample(self):
"""
pre-proocess data from either reader into a common format
"""
self.n += 1
if self.text_conditioned:
image_embedding, caption = next(self.loader)
image_embedding = from_numpy(image_embedding).to(self.device)
tokenized_caption = tokenize(
caption["caption"].to_list(), truncate=True
).to(self.device)
return image_embedding, tokenized_caption
else:
(image_embedding, _), (text_embedding, _) = next(self.loader)
image_embedding = from_numpy(image_embedding).to(self.device)
text_embedding = from_numpy(text_embedding).to(self.device)
return image_embedding, text_embedding
def make_splits(
text_conditioned: bool,
batch_size: int,
num_data_points: int,
train_split: float,
eval_split: float,
device: str,
img_url: str,
meta_url: str = None,
txt_url: str = None,
):
assert img_url is not None, "Must supply some image embeddings"
if text_conditioned:
assert meta_url is not None, "Must supply metadata url if text-conditioning"
image_reader = EmbeddingReader(
embeddings_folder=img_url,
file_format="parquet_npy",
meta_columns=["caption"],
metadata_folder=meta_url,
)
# compute split points
if num_data_points > image_reader.count:
print("Specified point count is larger than the number of points available...defaulting to max length of reader.")
num_data_points = image_reader.count
train_set_size = int(train_split * num_data_points)
eval_set_size = int(eval_split * num_data_points)
eval_stop = int(train_set_size + eval_set_size)
train_loader = PriorEmbeddingLoader(
text_conditioned=text_conditioned,
image_reader=image_reader,
batch_size=batch_size,
start=0,
stop=train_set_size,
device=device,
)
eval_loader = PriorEmbeddingLoader(
text_conditioned=text_conditioned,
image_reader=image_reader,
batch_size=batch_size,
start=train_set_size,
stop=eval_stop,
device=device,
)
test_loader = PriorEmbeddingLoader(
text_conditioned=text_conditioned,
image_reader=image_reader,
batch_size=batch_size,
start=eval_stop,
stop=int(num_data_points),
device=device,
)
else:
assert (
txt_url is not None
), "Must supply text embedding url if not text-conditioning"
image_reader = EmbeddingReader(img_url, file_format="npy")
text_reader = EmbeddingReader(txt_url, file_format="npy")
# compute split points
if num_data_points > image_reader.count:
print("Specified point count is larger than the number of points available...defaulting to max length of reader.")
num_data_points = image_reader.count
train_set_size = int(train_split * num_data_points)
eval_set_size = int(eval_split * num_data_points)
eval_stop = int(train_set_size + eval_set_size)
train_loader = PriorEmbeddingLoader(
text_conditioned=text_conditioned,
image_reader=image_reader,
text_reader=text_reader,
batch_size=batch_size,
start=0,
stop=train_set_size,
device=device,
)
eval_loader = PriorEmbeddingLoader(
text_conditioned=text_conditioned,
image_reader=image_reader,
text_reader=text_reader,
batch_size=batch_size,
start=train_set_size,
stop=eval_stop,
device=device,
)
test_loader = PriorEmbeddingLoader(
text_conditioned=text_conditioned,
image_reader=image_reader,
text_reader=text_reader,
batch_size=batch_size,
start=eval_stop,
stop=int(num_data_points),
device=device,
)
return train_loader, eval_loader, test_loader

View File

@@ -7,16 +7,17 @@ def separate_weight_decayable_params(params):
def get_optimizer(
params,
lr = 3e-4,
lr = 2e-5,
wd = 1e-2,
betas = (0.9, 0.999),
eps = 1e-8,
filter_by_requires_grad = False
):
if filter_by_requires_grad:
params = list(filter(lambda t: t.requires_grad, params))
if wd == 0:
return Adam(params, lr = lr, betas = betas)
return Adam(params, lr = lr, betas = betas, eps = eps)
params = set(params)
wd_params, no_wd_params = separate_weight_decayable_params(params)
@@ -26,4 +27,4 @@ def get_optimizer(
{'params': list(no_wd_params), 'weight_decay': 0},
]
return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas)
return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas, eps = eps)

View File

@@ -0,0 +1,49 @@
import os
import torch
from torch import nn
# helper functions
def exists(val):
return val is not None
# base class
class BaseTracker(nn.Module):
def __init__(self):
super().__init__()
def init(self, config, **kwargs):
raise NotImplementedError
def log(self, log, **kwargs):
raise NotImplementedError
# basic stdout class
class ConsoleTracker(BaseTracker):
def init(self, **config):
print(config)
def log(self, log, **kwargs):
print(log)
# basic wandb class
class WandbTracker(BaseTracker):
def __init__(self):
super().__init__()
try:
import wandb
except ImportError as e:
print('`pip install wandb` to use the wandb experiment tracker')
raise e
os.environ["WANDB_SILENT"] = "true"
self.wandb = wandb
def init(self, **config):
self.wandb.init(**config)
def log(self, log, **kwargs):
self.wandb.log(log, **kwargs)

View File

@@ -1,5 +1,8 @@
import time
import copy
from functools import partial
from math import ceil
from functools import partial, wraps
from collections.abc import Iterable
import torch
from torch import nn
@@ -8,11 +11,16 @@ from torch.cuda.amp import autocast, GradScaler
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
from dalle2_pytorch.optimizer import get_optimizer
import numpy as np
# helper functions
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
@@ -39,13 +47,130 @@ def groupby_prefix_and_trim(prefix, d):
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs
# decorators
def cast_torch_tensor(fn):
@wraps(fn)
def inner(model, *args, **kwargs):
device = kwargs.pop('_device', next(model.parameters()).device)
cast_device = kwargs.pop('_cast_device', True)
kwargs_keys = kwargs.keys()
all_args = (*args, *kwargs.values())
split_kwargs_index = len(all_args) - len(kwargs_keys)
all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args))
if cast_device:
all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))
out = fn(model, *args, **kwargs)
return out
return inner
# gradient accumulation functions
def split_iterable(it, split_size):
accum = []
for ind in range(ceil(len(it) / split_size)):
start_index = ind * split_size
accum.append(it[start_index: (start_index + split_size)])
return accum
def split(t, split_size = None):
if not exists(split_size):
return t
if isinstance(t, torch.Tensor):
return t.split(split_size, dim = 0)
if isinstance(t, Iterable):
return split_iterable(t, split_size)
return TypeError
def find_first(cond, arr):
for el in arr:
if cond(el):
return el
return None
def split_args_and_kwargs(*args, split_size = None, **kwargs):
all_args = (*args, *kwargs.values())
len_all_args = len(all_args)
first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args)
assert exists(first_tensor)
batch_size = len(first_tensor)
split_size = default(split_size, batch_size)
num_chunks = ceil(batch_size / split_size)
dict_len = len(kwargs)
dict_keys = kwargs.keys()
split_kwargs_index = len_all_args - dict_len
split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args]
chunk_sizes = tuple(map(len, split_all_args[0]))
for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:]
chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
chunk_size_frac = chunk_size / batch_size
yield chunk_size_frac, (chunked_args, chunked_kwargs)
# print helpers
def print_ribbon(s, symbol = '=', repeat = 40):
flank = symbol * repeat
return f'{flank} {s} {flank}'
# saving and loading functions
# for diffusion prior
def load_diffusion_model(dprior_path, device):
dprior_path = Path(dprior_path)
assert dprior_path.exists(), 'Dprior model file does not exist'
loaded_obj = torch.load(str(dprior_path), map_location='cpu')
# Get hyperparameters of loaded model
dpn_config = loaded_obj['hparams']['diffusion_prior_network']
dp_config = loaded_obj['hparams']['diffusion_prior']
image_embed_dim = loaded_obj['image_embed_dim']['image_embed_dim']
# Create DiffusionPriorNetwork and DiffusionPrior with loaded hyperparameters
# DiffusionPriorNetwork
prior_network = DiffusionPriorNetwork( dim = image_embed_dim, **dpn_config).to(device)
# DiffusionPrior with text embeddings and image embeddings pre-computed
diffusion_prior = DiffusionPrior(net = prior_network, **dp_config, image_embed_dim = image_embed_dim).to(device)
# Load state dict from saved model
diffusion_prior.load_state_dict(loaded_obj['model'])
return diffusion_prior, loaded_obj
def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim):
# Saving State Dict
print_ribbon('Saving checkpoint')
state_dict = dict(model=model.state_dict(),
optimizer=optimizer.state_dict(),
scaler=scaler.state_dict(),
hparams = config,
image_embed_dim = {"image_embed_dim":image_embed_dim})
torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
# exponential moving average wrapper
class EMA(nn.Module):
def __init__(
self,
model,
beta = 0.99,
beta = 0.9999,
update_after_step = 1000,
update_every = 10,
):
@@ -60,6 +185,10 @@ class EMA(nn.Module):
self.register_buffer('initted', torch.Tensor([False]))
self.register_buffer('step', torch.tensor([0.]))
def restore_ema_model_device(self):
device = self.initted.device
self.ema_model.to(device)
def update(self):
self.step += 1
@@ -98,6 +227,7 @@ class DiffusionPriorTrainer(nn.Module):
use_ema = True,
lr = 3e-4,
wd = 1e-2,
eps = 1e-6,
max_grad_norm = None,
amp = False,
**kwargs
@@ -124,6 +254,7 @@ class DiffusionPriorTrainer(nn.Module):
diffusion_prior.parameters(),
lr = lr,
wd = wd,
eps = eps,
**kwargs
)
@@ -131,6 +262,8 @@ class DiffusionPriorTrainer(nn.Module):
self.max_grad_norm = max_grad_norm
self.register_buffer('step', torch.tensor([0.]))
def update(self):
if exists(self.max_grad_norm):
self.scaler.unscale_(self.optimizer)
@@ -143,27 +276,42 @@ class DiffusionPriorTrainer(nn.Module):
if self.use_ema:
self.ema_diffusion_prior.update()
@torch.inference_mode()
self.step += 1
@torch.no_grad()
@cast_torch_tensor
def p_sample_loop(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
@torch.inference_mode()
@torch.no_grad()
@cast_torch_tensor
def sample(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
@torch.inference_mode()
@torch.no_grad()
def sample_batch_size(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
@cast_torch_tensor
def forward(
self,
*args,
divisor = 1,
max_batch_size = None,
**kwargs
):
with autocast(enabled = self.amp):
loss = self.diffusion_prior(*args, **kwargs)
return self.scaler.scale(loss / divisor)
total_loss = 0.
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
with autocast(enabled = self.amp):
loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
loss = loss * chunk_size_frac
total_loss += loss.item()
if self.training:
self.scaler.scale(loss).backward()
return total_loss
# decoder trainer
@@ -172,8 +320,9 @@ class DecoderTrainer(nn.Module):
self,
decoder,
use_ema = True,
lr = 3e-4,
lr = 2e-5,
wd = 1e-2,
eps = 1e-8,
max_grad_norm = None,
amp = False,
**kwargs
@@ -186,11 +335,6 @@ class DecoderTrainer(nn.Module):
self.num_unets = len(self.decoder.unets)
self.use_ema = use_ema
if use_ema:
has_lazy_linear = any([type(module) == nn.LazyLinear for module in decoder.modules()])
assert not has_lazy_linear, 'you must set the text_embed_dim on your u-nets if you plan on doing automatic exponential moving average'
self.ema_unets = nn.ModuleList([])
self.amp = amp
@@ -198,13 +342,14 @@ class DecoderTrainer(nn.Module):
# be able to finely customize learning rate, weight decay
# per unet
lr, wd = map(partial(cast_tuple, length = self.num_unets), (lr, wd))
lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))
for ind, (unet, unet_lr, unet_wd) in enumerate(zip(self.decoder.unets, lr, wd)):
for ind, (unet, unet_lr, unet_wd, unet_eps) in enumerate(zip(self.decoder.unets, lr, wd, eps)):
optimizer = get_optimizer(
unet.parameters(),
lr = unet_lr,
wd = unet_wd,
eps = unet_eps,
**kwargs
)
@@ -220,6 +365,8 @@ class DecoderTrainer(nn.Module):
self.max_grad_norm = max_grad_norm
self.register_buffer('step', torch.tensor([0.]))
@property
def unets(self):
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
@@ -230,8 +377,11 @@ class DecoderTrainer(nn.Module):
scaler = getattr(self, f'scaler{index}')
return scaler.scale(loss)
def update(self, unet_number):
assert 1 <= unet_number <= self.num_unets
def update(self, unet_number = None):
if self.num_unets == 1:
unet_number = default(unet_number, 1)
assert exists(unet_number) and 1 <= unet_number <= self.num_unets
index = unet_number - 1
unet = self.decoder.unets[index]
@@ -250,7 +400,10 @@ class DecoderTrainer(nn.Module):
ema_unet = self.ema_unets[index]
ema_unet.update()
self.step += 1
@torch.no_grad()
@cast_torch_tensor
def sample(self, *args, **kwargs):
if self.use_ema:
trainable_unets = self.decoder.unets
@@ -260,16 +413,34 @@ class DecoderTrainer(nn.Module):
if self.use_ema:
self.decoder.unets = trainable_unets # restore original training unets
# cast the ema_model unets back to original device
for ema in self.ema_unets:
ema.restore_ema_model_device()
return output
@cast_torch_tensor
def forward(
self,
x,
*,
unet_number,
divisor = 1,
*args,
unet_number = None,
max_batch_size = None,
**kwargs
):
with autocast(enabled = self.amp):
loss = self.decoder(x, unet_number = unet_number, **kwargs)
return self.scale(loss / divisor, unet_number = unet_number)
if self.num_unets == 1:
unet_number = default(unet_number, 1)
total_loss = 0.
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
with autocast(enabled = self.amp):
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
loss = loss * chunk_size_frac
total_loss += loss.item()
if self.training:
self.scale(loss, unet_number = unet_number).backward()
return total_loss

View File

@@ -10,11 +10,12 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.1.10',
version = '0.2.40',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
author_email = 'lucidrains@gmail.com',
long_description_content_type = 'text/markdown',
url = 'https://github.com/lucidrains/dalle2-pytorch',
keywords = [
'artificial intelligence',
@@ -29,6 +30,7 @@ setup(
'einops-exts>=0.0.3',
'embedding-reader',
'kornia>=0.5.4',
'numpy',
'pillow',
'resize-right>=0.0.2',
'rotary-embedding-torch',

View File

@@ -1,324 +1,375 @@
import os
from pathlib import Path
import click
import math
import argparse
import time
import numpy as np
import torch
import clip
from torch import nn
from embedding_reader import EmbeddingReader
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
from dalle2_pytorch.optimizer import get_optimizer
from torch.cuda.amp import autocast,GradScaler
import time
from dalle2_pytorch.dataloaders import make_splits
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
from embedding_reader import EmbeddingReader
from tqdm import tqdm
import wandb
os.environ["WANDB_SILENT"] = "true"
NUM_TEST_EMBEDDINGS = 100 # for cosine similarity reporting during training
REPORT_METRICS_EVERY = 100 # for cosine similarity and other metric reporting during training
# constants
REPORT_METRICS_EVERY = 250 # for cosine similarity and other metric reporting during training
def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_type,phase="Validation"):
tracker = WandbTracker()
# helpers functions
def exists(val):
val is not None
class Timer:
def __init__(self):
self.reset()
def reset(self):
self.last_time = time.time()
def elapsed(self):
return time.time() - self.last_time
# functions
def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation"):
model.eval()
with torch.no_grad():
total_loss = 0.
total_samples = 0.
for emb_images, emb_text in zip(image_reader(batch_size=batch_size, start=start, end=end),
text_reader(batch_size=batch_size, start=start, end=end)):
for image_embeddings, text_data in tqdm(dataloader):
emb_images_tensor = torch.tensor(emb_images[0]).to(device)
emb_text_tensor = torch.tensor(emb_text[0]).to(device)
batches = image_embeddings.shape[0]
batches = emb_images_tensor.shape[0]
input_args = dict(image_embed=image_embeddings)
if text_conditioned:
input_args = dict(**input_args, text = text_data)
else:
input_args = dict(**input_args, text_embed=text_data)
loss = model(text_embed = emb_text_tensor, image_embed = emb_images_tensor)
loss = model(**input_args)
total_loss += loss.item() * batches
total_loss += loss * batches
total_samples += batches
avg_loss = (total_loss / total_samples)
wandb.log({f'{phase} {loss_type}': avg_loss})
def save_model(save_path, state_dict):
# Saving State Dict
print("====================================== Saving checkpoint ======================================")
torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
tracker.log({f'{phase} {loss_type}': avg_loss})
def report_cosine_sims(diffusion_prior, dataloader, text_conditioned):
diffusion_prior.eval()
def report_cosine_sims(diffusion_prior, image_reader, text_reader, train_set_size, val_set_size, NUM_TEST_EMBEDDINGS, device):
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
tstart = train_set_size+val_set_size
tend = train_set_size+val_set_size+NUM_TEST_EMBEDDINGS
for test_image_embeddings, text_data in tqdm(dataloader):
# we are text conditioned, we produce an embedding from the tokenized text
if text_conditioned:
text_embedding, text_encodings, text_mask = diffusion_prior.clip.embed_text(
text_data)
text_cond = dict(text_embed=text_embedding,
text_encodings=text_encodings, mask=text_mask)
else:
text_embedding = text_data
text_cond = dict(text_embed=text_embedding)
for embt, embi in zip(text_reader(batch_size=NUM_TEST_EMBEDDINGS, start=tstart, end=tend), image_reader(batch_size=NUM_TEST_EMBEDDINGS, start=tstart, end=tend)):
# make a copy of the text embeddings for shuffling
text_embed = torch.tensor(embt[0]).to(device)
text_embed_shuffled = text_embed.clone()
text_embed_shuffled = text_embedding.clone()
# roll the text embeddings to simulate "unrelated" captions
rolled_idx = torch.roll(torch.arange(NUM_TEST_EMBEDDINGS), 1)
# roll the text to simulate "unrelated" captions
rolled_idx = torch.roll(torch.arange(text_embedding.shape[0]), 1)
text_embed_shuffled = text_embed_shuffled[rolled_idx]
text_embed_shuffled = text_embed_shuffled / \
text_embed_shuffled.norm(dim=1, keepdim=True)
test_text_shuffled_cond = dict(text_embed=text_embed_shuffled)
if text_conditioned:
text_encodings_shuffled = text_encodings[rolled_idx]
text_mask_shuffled = text_mask[rolled_idx]
else:
text_encodings_shuffled = None
text_mask_shuffled = None
text_cond_shuffled = dict(text_embed=text_embed_shuffled,
text_encodings=text_encodings_shuffled, mask=text_mask_shuffled)
# prepare the text embedding
text_embed = text_embed / text_embed.norm(dim=1, keepdim=True)
test_text_cond = dict(text_embed=text_embed)
text_embed = text_embedding / text_embedding.norm(dim=1, keepdim=True)
# prepare image embeddings
test_image_embeddings = torch.tensor(embi[0]).to(device)
test_image_embeddings = test_image_embeddings / \
test_image_embeddings.norm(dim=1, keepdim=True)
# predict on the unshuffled text embeddings
predicted_image_embeddings = diffusion_prior.p_sample_loop(
(NUM_TEST_EMBEDDINGS, 768), text_cond=test_text_cond)
test_image_embeddings.shape, text_cond)
predicted_image_embeddings = predicted_image_embeddings / \
predicted_image_embeddings.norm(dim=1, keepdim=True)
# predict on the shuffled embeddings
predicted_unrelated_embeddings = diffusion_prior.p_sample_loop(
(NUM_TEST_EMBEDDINGS, 768), text_cond=test_text_shuffled_cond)
test_image_embeddings.shape, text_cond_shuffled)
predicted_unrelated_embeddings = predicted_unrelated_embeddings / \
predicted_unrelated_embeddings.norm(dim=1, keepdim=True)
# calculate similarities
original_similarity = cos(
text_embed, test_image_embeddings).cpu().numpy()
text_embed, test_image_embeddings).cpu().numpy()
predicted_similarity = cos(
text_embed, predicted_image_embeddings).cpu().numpy()
text_embed, predicted_image_embeddings).cpu().numpy()
unrelated_similarity = cos(
text_embed, predicted_unrelated_embeddings).cpu().numpy()
text_embed, predicted_unrelated_embeddings).cpu().numpy()
predicted_img_similarity = cos(
test_image_embeddings, predicted_image_embeddings).cpu().numpy()
wandb.log(
{"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity)})
wandb.log({"CosineSimilarity(text_embed,predicted_image_embed)": np.mean(
predicted_similarity)})
wandb.log({"CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(
unrelated_similarity)})
wandb.log({"CosineSimilarity(image_embed,predicted_image_embed)": np.mean(
predicted_img_similarity)})
return np.mean(predicted_similarity - original_similarity)
test_image_embeddings, predicted_image_embeddings).cpu().numpy()
tracker.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity),
"CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity),
"CosineSimilarity(orig_image_embed,predicted_image_embed)":np.mean(predicted_img_similarity),
"CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(unrelated_similarity),
"Cosine similarity difference":np.mean(predicted_similarity - original_similarity)})
@click.command()
@click.option("--wandb-entity", default="laion")
@click.option("--wandb-project", default="diffusion-prior")
@click.option("--wandb-dataset", default="LAION-5B")
@click.option("--wandb-arch", default="DiffusionPrior")
@click.option("--image-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
@click.option("--text-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
@click.option("--meta-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/laion2B-en-metadata/")
@click.option("--learning-rate", default=1.1e-4)
@click.option("--weight-decay", default=6.02e-2)
@click.option("--dropout", default=5e-2)
@click.option("--max-grad-norm", default=0.5)
@click.option("--num-data-points", default=250e6)
@click.option("--batch-size", default=320)
@click.option("--num-epochs", default=5)
@click.option("--image-embed-dim", default=768)
@click.option("--train-percent", default=0.9)
@click.option("--val-percent", default=1e-7)
@click.option("--test-percent", default=0.0999999)
@click.option("--dpn-depth", default=12)
@click.option("--dpn-dim-head", default=64)
@click.option("--dpn-heads", default=12)
@click.option("--dp-condition-on-text-encodings", default=True)
@click.option("--dp-timesteps", default=1000)
@click.option("--dp-normformer", default=True)
@click.option("--dp-cond-drop-prob", default=0.1)
@click.option("--dp-loss-type", default="l2")
@click.option("--clip", default="ViT-L/14")
@click.option("--amp", default=False)
@click.option("--save-interval", default=120)
@click.option("--save-path", default="./diffusion_prior_checkpoints")
@click.option("--pretrained-model-path", default=None)
@click.option("--gpu-device", default=0)
def train(
wandb_entity,
wandb_project,
wandb_dataset,
wandb_arch,
image_embed_url,
text_embed_url,
meta_url,
learning_rate,
weight_decay,
dropout,
max_grad_norm,
num_data_points,
batch_size,
num_epochs,
image_embed_dim,
train_percent,
val_percent,
test_percent,
dpn_depth,
dpn_dim_head,
dpn_heads,
dp_condition_on_text_encodings,
dp_timesteps,
dp_normformer,
dp_cond_drop_prob,
dp_loss_type,
clip,
amp,
save_interval,
save_path,
pretrained_model_path,
gpu_device
):
config = {
"learning_rate": learning_rate,
"architecture": wandb_arch,
"dataset": wandb_dataset,
"weight_decay": weight_decay,
"max_gradient_clipping_norm": max_grad_norm,
"batch_size": batch_size,
"epochs": num_epochs,
"diffusion_prior_network": {
"depth": dpn_depth,
"dim_head": dpn_dim_head,
"heads": dpn_heads,
"normformer": dp_normformer
},
"diffusion_prior": {
"condition_on_text_encodings": dp_condition_on_text_encodings,
"timesteps": dp_timesteps,
"cond_drop_prob": dp_cond_drop_prob,
"loss_type": dp_loss_type,
"clip": clip
}
}
def train(image_embed_dim,
image_embed_url,
text_embed_url,
batch_size,
train_percent,
val_percent,
test_percent,
num_epochs,
dp_loss_type,
clip,
dp_condition_on_text_encodings,
dp_timesteps,
dp_normformer,
dp_cond_drop_prob,
dpn_depth,
dpn_dim_head,
dpn_heads,
save_interval,
save_path,
device,
learning_rate=0.001,
max_grad_norm=0.5,
weight_decay=0.01,
amp=False):
# Check if DPRIOR_PATH exists(saved model path)
DPRIOR_PATH = pretrained_model_path
RESUME = exists(DPRIOR_PATH)
if not RESUME:
tracker.init(
entity = wandb_entity,
project = wandb_project,
config = config
)
# Obtain the utilized device.
has_cuda = torch.cuda.is_available()
if has_cuda:
device = torch.device(f"cuda:{gpu_device}")
torch.cuda.set_device(device)
# Training loop
# diffusion prior network
# DiffusionPriorNetwork
prior_network = DiffusionPriorNetwork(
dim = image_embed_dim,
depth = dpn_depth,
dim_head = dpn_dim_head,
heads = dpn_heads,
normformer = dp_normformer).to(device)
dim = image_embed_dim,
depth = dpn_depth,
dim_head = dpn_dim_head,
heads = dpn_heads,
attn_dropout = dropout,
ff_dropout = dropout,
normformer = dp_normformer
)
# DiffusionPrior with text embeddings and image embeddings pre-computed
diffusion_prior = DiffusionPrior(
net = prior_network,
clip = clip,
image_embed_dim = image_embed_dim,
timesteps = dp_timesteps,
cond_drop_prob = dp_cond_drop_prob,
loss_type = dp_loss_type,
condition_on_text_encodings = dp_condition_on_text_encodings).to(device)
# Load clip model if text-conditioning
if dp_condition_on_text_encodings:
clip_adapter = OpenAIClipAdapter(clip)
else:
clip_adapter = None
# diffusion prior with text embeddings and image embeddings pre-computed
# Get image and text embeddings from the servers
print("==============Downloading embeddings - image and text====================")
image_reader = EmbeddingReader(embeddings_folder=image_embed_url, file_format="npy")
text_reader = EmbeddingReader(embeddings_folder=text_embed_url, file_format="npy")
num_data_points = text_reader.count
diffusion_prior = DiffusionPrior(
net = prior_network,
clip = clip_adapter,
image_embed_dim = image_embed_dim,
timesteps = dp_timesteps,
cond_drop_prob = dp_cond_drop_prob,
loss_type = dp_loss_type,
condition_on_text_encodings = dp_condition_on_text_encodings
)
# Load pre-trained model from DPRIOR_PATH
if RESUME:
diffusion_prior, loaded_obj = load_diffusion_model(DPRIOR_PATH, device)
tracker.init(entity = wandb_entity, project = wandb_project, config = config)
# diffusion prior trainer
trainer = DiffusionPriorTrainer(
diffusion_prior = diffusion_prior,
lr = learning_rate,
wd = weight_decay,
max_grad_norm = max_grad_norm,
amp = amp,
).to(device)
# load optimizer and scaler
if RESUME:
trainer.optimizer.load_state_dict(loaded_obj['optimizer'])
trainer.scaler.load_state_dict(loaded_obj['scaler'])
# Create save_path if it doesn't exist
if not os.path.exists(save_path):
os.makedirs(save_path)
Path(save_path).mkdir(exist_ok = True, parents = True)
# Utilize wrapper to abstract away loader logic
print_ribbon("Downloading Embeddings")
loader_args = dict(text_conditioned=dp_condition_on_text_encodings, batch_size=batch_size, num_data_points=num_data_points,
train_split=train_percent, eval_split=val_percent, device=device, img_url=image_embed_url)
if dp_condition_on_text_encodings:
loader_args = dict(**loader_args, meta_url=meta_url)
else:
loader_args = dict(**loader_args, txt_url=text_embed_url)
train_loader, eval_loader, test_loader = make_splits(**loader_args)
### Training code ###
scaler = GradScaler(enabled=amp)
optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
step = 1
timer = Timer()
epochs = num_epochs
step = 0
t = time.time()
train_set_size = int(train_percent*num_data_points)
val_set_size = int(val_percent*num_data_points)
for _ in range(epochs):
diffusion_prior.train()
for emb_images,emb_text in zip(image_reader(batch_size=batch_size, start=0, end=train_set_size),
text_reader(batch_size=batch_size, start=0, end=train_set_size)):
emb_images_tensor = torch.tensor(emb_images[0]).to(device)
emb_text_tensor = torch.tensor(emb_text[0]).to(device)
for image, text in tqdm(train_loader):
diffusion_prior.train()
input_args = dict(image_embed=image)
if dp_condition_on_text_encodings:
input_args = dict(**input_args, text = text)
else:
input_args = dict(**input_args, text_embed=text)
with autocast(enabled=amp):
loss = diffusion_prior(text_embed = emb_text_tensor,image_embed = emb_images_tensor)
scaler.scale(loss).backward()
loss = trainer(**input_args)
# Samples per second
step+=1
samples_per_sec = batch_size*step/(time.time()-t)
# Save checkpoint every save_interval minutes
if(int(time.time()-t) >= 60*save_interval):
t = time.time()
save_model(
samples_per_sec = batch_size * step / timer.elapsed()
# Save checkpoint every save_interval minutes
if(int(timer.elapsed()) >= 60 * save_interval):
timer.reset()
save_diffusion_model(
save_path,
dict(model=diffusion_prior.state_dict(), optimizer=optimizer.state_dict(), scaler=scaler.state_dict()))
diffusion_prior,
trainer.optimizer,
trainer.scaler,
config,
image_embed_dim)
# Log to wandb
wandb.log({"Training loss": loss.item(),
tracker.log({"Training loss": loss,
"Steps": step,
"Samples per second": samples_per_sec})
# Log cosineSim(text_embed,predicted_image_embed) - cosineSim(text_embed,image_embed)
# Use NUM_TEST_EMBEDDINGS samples from the test set each time
# Get embeddings from the most recently saved model
if(step % REPORT_METRICS_EVERY) == 0:
diff_cosine_sim = report_cosine_sims(diffusion_prior,
image_reader,
text_reader,
train_set_size,
val_set_size,
NUM_TEST_EMBEDDINGS,
device)
wandb.log({"Cosine similarity difference": diff_cosine_sim})
report_cosine_sims(diffusion_prior, eval_loader, dp_condition_on_text_encodings)
### Evaluate model(validation run) ###
eval_model(diffusion_prior, eval_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Validation")
scaler.unscale_(optimizer)
nn.utils.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
### Evaluate model(validation run) ###
start = train_set_size
end=start+val_set_size
eval_model(diffusion_prior,device,image_reader,text_reader,start,end,batch_size,dp_loss_type,phase="Validation")
step += 1
trainer.update()
### Test run ###
test_set_size = int(test_percent*train_set_size)
start=train_set_size+val_set_size
end=num_data_points
eval_model(diffusion_prior,device,image_reader,text_reader,start,end,batch_size,dp_loss_type,phase="Test")
eval_model(diffusion_prior, test_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Test")
def main():
parser = argparse.ArgumentParser()
# Logging
parser.add_argument("--wandb-entity", type=str, default="laion")
parser.add_argument("--wandb-project", type=str, default="diffusion-prior")
parser.add_argument("--wandb-name", type=str, default="laion-dprior")
parser.add_argument("--wandb-dataset", type=str, default="LAION-5B")
parser.add_argument("--wandb-arch", type=str, default="DiffusionPrior")
# URLs for embeddings
parser.add_argument("--image-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
parser.add_argument("--text-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
# Hyperparameters
parser.add_argument("--learning-rate", type=float, default=1.1e-4)
parser.add_argument("--weight-decay", type=float, default=6.02e-2)
parser.add_argument("--max-grad-norm", type=float, default=0.5)
parser.add_argument("--batch-size", type=int, default=10**4)
parser.add_argument("--num-epochs", type=int, default=5)
# Image embed dimension
parser.add_argument("--image-embed-dim", type=int, default=768)
# Train-test split
parser.add_argument("--train-percent", type=float, default=0.7)
parser.add_argument("--val-percent", type=float, default=0.2)
parser.add_argument("--test-percent", type=float, default=0.1)
# LAION training(pre-computed embeddings)
# DiffusionPriorNetwork(dpn) parameters
parser.add_argument("--dpn-depth", type=int, default=6)
parser.add_argument("--dpn-dim-head", type=int, default=64)
parser.add_argument("--dpn-heads", type=int, default=8)
# DiffusionPrior(dp) parameters
parser.add_argument("--dp-condition-on-text-encodings", type=bool, default=False)
parser.add_argument("--dp-timesteps", type=int, default=100)
parser.add_argument("--dp-l2norm-output", type=bool, default=False)
parser.add_argument("--dp-normformer", type=bool, default=False)
parser.add_argument("--dp-cond-drop-prob", type=float, default=0.1)
parser.add_argument("--dp-loss-type", type=str, default="l2")
parser.add_argument("--clip", type=str, default=None)
parser.add_argument("--amp", type=bool, default=False)
# Model checkpointing interval(minutes)
parser.add_argument("--save-interval", type=int, default=30)
parser.add_argument("--save-path", type=str, default="./diffusion_prior_checkpoints")
args = parser.parse_args()
print("Setting up wandb logging... Please wait...")
wandb.init(
entity=args.wandb_entity,
project=args.wandb_project,
config={
"learning_rate": args.learning_rate,
"architecture": args.wandb_arch,
"dataset": args.wandb_dataset,
"epochs": args.num_epochs,
})
print("wandb logging setup done!")
# Obtain the utilized device.
has_cuda = torch.cuda.is_available()
if has_cuda:
device = torch.device("cuda:0")
torch.cuda.set_device(device)
# Training loop
train(args.image_embed_dim,
args.image_embed_url,
args.text_embed_url,
args.batch_size,
args.train_percent,
args.val_percent,
args.test_percent,
args.num_epochs,
args.dp_loss_type,
args.clip,
args.dp_condition_on_text_encodings,
args.dp_timesteps,
args.dp_normformer,
args.dp_cond_drop_prob,
args.dpn_depth,
args.dpn_dim_head,
args.dpn_heads,
args.save_interval,
args.save_path,
device,
args.learning_rate,
args.max_grad_norm,
args.weight_decay,
args.amp)
if __name__ == "__main__":
main()
train()