Compare commits

...

43 Commits

Author SHA1 Message Date
Phil Wang
c30544b73a no CLIP altogether for training DiffusionPrior 2022-04-26 10:23:41 -07:00
Phil Wang
bdf5e9c009 todo 2022-04-26 09:56:54 -07:00
Phil Wang
9878be760b have researcher explicitly state upfront whether to condition with text encodings in cascading ddpm decoder, have DALLE-2 class take care of passing in text if feature turned on 2022-04-26 09:47:09 -07:00
Phil Wang
7ba6357c05 allow for training the Prior network with precomputed CLIP embeddings (or text encodings) 2022-04-26 09:29:51 -07:00
Phil Wang
76e063e8b7 refactor so that the causal transformer in the diffusion prior network can be conditioned without text encodings (for Laions parallel efforts, although it seems from the paper it is needed) 2022-04-26 09:00:11 -07:00
Phil Wang
4d25976f33 make sure non-latent diffusion still works 2022-04-26 08:36:00 -07:00
Phil Wang
0b28ee0d01 revert back to old upsampling, paper does not work 2022-04-26 07:39:04 -07:00
Phil Wang
45262a4bb7 bring in the exponential moving average wrapper, to get ready for training 2022-04-25 19:24:13 -07:00
Phil Wang
13a58a78c4 scratch off todo 2022-04-25 19:01:30 -07:00
Phil Wang
f75d49c781 start a file for all attention-related modules, use attention-based upsampling in the unets in dalle-2 2022-04-25 18:59:10 -07:00
Phil Wang
3b520dfa85 bring in attention-based upsampling to strengthen vqgan-vae, seems to work as advertised in initial experiments in GAN 2022-04-25 17:27:45 -07:00
Phil Wang
79198c6ae4 keep readme simple for reader 2022-04-25 17:21:45 -07:00
Phil Wang
77a246b1b9 todo 2022-04-25 08:48:28 -07:00
Phil Wang
f93a3f6ed8 reprioritize 2022-04-25 08:44:27 -07:00
Phil Wang
8f2a0c7e00 better naming 2022-04-25 07:44:33 -07:00
Phil Wang
863f4ef243 just take care of the logic for setting all latent diffusion to predict x0, if needed 2022-04-24 10:06:42 -07:00
Phil Wang
fb8a66a2de just in case latent diffusion performs better with prediction of x0 instead of epsilon, open up the research avenue 2022-04-24 10:04:22 -07:00
Phil Wang
579d4b42dd does not seem right to clip for the prior diffusion part 2022-04-24 09:51:18 -07:00
Phil Wang
473808850a some outlines to the eventual CLI endpoint 2022-04-24 09:27:15 -07:00
Phil Wang
d5318aef4f todo 2022-04-23 08:23:08 -07:00
Phil Wang
f82917e1fd prepare for turning off gradient penalty, as shown in GAN literature, GP needs to be only applied 1 out of 4 iterations 2022-04-23 07:52:10 -07:00
Phil Wang
05b74be69a use null container pattern to cleanup some conditionals, save more cleanup for next week 2022-04-22 15:23:18 -07:00
Phil Wang
a8b5d5d753 last tweak of readme 2022-04-22 14:16:43 -07:00
Phil Wang
976ef7f87c project management 2022-04-22 14:15:42 -07:00
Phil Wang
fd175bcc0e readme 2022-04-22 14:13:33 -07:00
Phil Wang
76b32f18b3 first pass at complete DALL-E2 + Latent Diffusion integration, latent diffusion on any layer(s) of the cascading ddpm in the decoder. 2022-04-22 13:53:13 -07:00
Phil Wang
f2d5b87677 todo 2022-04-22 11:39:58 -07:00
Phil Wang
461347c171 fix vqgan-vae for latent diffusion 2022-04-22 11:38:57 -07:00
Phil Wang
46cef31c86 optional projection out for prior network causal transformer 2022-04-22 11:16:30 -07:00
Phil Wang
59b1a77d4d be a bit more conservative and stick with layernorm (without bias) for now, given @borisdayma results https://twitter.com/borisdayma/status/1517227191477571585 2022-04-22 11:14:54 -07:00
Phil Wang
7f338319fd makes more sense for blur augmentation to happen before the upsampling 2022-04-22 11:10:47 -07:00
Phil Wang
2c6c91829d refactor blurring training augmentation to be taken care of by the decoder, with option to downsample to previous resolution before upsampling (cascading ddpm). this opens up the possibility of cascading latent ddpm 2022-04-22 11:09:17 -07:00
Phil Wang
ad17c69ab6 prepare for latent diffusion in the first DDPM of the cascade in the Decoder 2022-04-21 17:54:31 -07:00
Phil Wang
0b4ec34efb todo 2022-04-20 12:24:23 -07:00
Phil Wang
f027b82e38 remove wip as main networks (prior and decoder) are completed 2022-04-20 12:12:16 -07:00
Phil Wang
8cc9016cb0 Merge pull request #17 from kashif/patch-2
added diffusion-gan thoughts
2022-04-20 12:10:26 -07:00
Kashif Rasul
1d8f37befe added diffusion-gan thoughts
https://github.com/NVlabs/denoising-diffusion-gan
2022-04-20 21:01:11 +02:00
Phil Wang
faebf4c8b8 from my vision transformer experience, dimension of attention head of 32 is sufficient for image feature maps 2022-04-20 11:40:32 -07:00
Phil Wang
b8e8d3c164 thoughts 2022-04-20 11:34:51 -07:00
Phil Wang
8e2416b49b commit to generalizing latent diffusion to one model 2022-04-20 11:27:42 -07:00
Phil Wang
f37c26e856 cleanup and DRY a little 2022-04-20 10:56:32 -07:00
Phil Wang
27a33e1b20 complete contextmanager method for keeping only one unet in GPU during training or inference 2022-04-20 10:46:13 -07:00
Phil Wang
6f941a219a give time tokens a surface area of 2 tokens as default, make it so researcher can customize which unet actually is conditioned on image embeddings and/or text encodings 2022-04-20 10:04:47 -07:00
9 changed files with 1434 additions and 170 deletions

287
README.md
View File

@@ -1,6 +1,6 @@
<img src="./dalle2.png" width="450px"></img>
## DALL-E 2 - Pytorch (wip)
## DALL-E 2 - Pytorch
Implementation of <a href="https://openai.com/dall-e-2/">DALL-E 2</a>, OpenAI's updated text-to-image synthesis neural network, in Pytorch.
@@ -10,11 +10,9 @@ The main novelty seems to be an extra layer of indirection with the prior networ
This model is SOTA for text-to-image for now.
It may also explore an extension of using <a href="https://huggingface.co/spaces/multimodalart/latentdiffusion">latent diffusion</a> in the decoder from Rombach et al.
Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication
There was enough interest for a Jax version. It will be completed after the Pytorch version shows signs of life on my toy tasks. <a href="https://github.com/lucidrains/dalle2-jax">Placeholder repository</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
There was enough interest for a <a href="https://github.com/lucidrains/dalle2-jax">Jax version</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
## Install
@@ -248,13 +246,6 @@ loss = decoder(images, unet_number = 2)
loss.backward()
# do the above for many steps for both unets
# then it will learn to generate images based on the CLIP image embeddings
# chaining the unets from lowest resolution to highest resolution (thus cascading)
mock_image_embed = torch.randn(1, 512).cuda()
images = decoder.sample(mock_image_embed) # (1, 3, 512, 512)
```
Finally, to generate the DALL-E2 images from text. Insert the trained `DiffusionPrior` as well as the `Decoder` (which wraps `CLIP`, the causal transformer, and unet(s))
@@ -357,7 +348,8 @@ decoder = Decoder(
image_sizes = (128, 256),
clip = clip,
timesteps = 100,
cond_drop_prob = 0.2
cond_drop_prob = 0.2,
condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
).cuda()
for unet_number in (1, 2):
@@ -385,7 +377,245 @@ You can also train the decoder on images of greater than the size (say 512x512)
For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.
## CLI Usage (work in progress)
## Training on Preprocessed CLIP Embeddings
It is likely, when scaling up, that you would first preprocess your images and text into corresponding embeddings before training the prior network. You can do so easily by simply passing in `image_embed`, `text_embed`, and optionally `text_encodings` and `text_mask`
Working example below
```python
import torch
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, CLIP
# get trained CLIP from step one
clip = CLIP(
dim_text = 512,
dim_image = 512,
dim_latent = 512,
num_text_tokens = 49408,
text_enc_depth = 6,
text_seq_len = 256,
text_heads = 8,
visual_enc_depth = 6,
visual_image_size = 256,
visual_patch_size = 32,
visual_heads = 8,
).cuda()
# setup prior network, which contains an autoregressive transformer
prior_network = DiffusionPriorNetwork(
dim = 512,
depth = 6,
dim_head = 64,
heads = 8
).cuda()
# diffusion prior network, which contains the CLIP and network (with transformer) above
diffusion_prior = DiffusionPrior(
net = prior_network,
clip = clip,
timesteps = 100,
cond_drop_prob = 0.2,
condition_on_text_encodings = False # this probably should be true, but just to get Laion started
).cuda()
# mock data
text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
# precompute the text and image embeddings
# here using the diffusion prior class, but could be done with CLIP alone
clip_image_embeds = diffusion_prior.get_image_embed(images)
clip_text_embeds = diffusion_prior.get_text_cond(text).get('text_embed')
# feed text and images into diffusion prior network
loss = diffusion_prior(
text_embed = clip_text_embeds,
image_embed = clip_image_embeds
)
loss.backward()
# do the above for many many many steps
# now the diffusion prior can generate image embeddings from the text embeddings
```
You can also completely go `CLIP`-less, in which case you will need to pass in the `image_embed_dim` into the `DiffusionPrior` on initialization
```python
import torch
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior
# setup prior network, which contains an autoregressive transformer
prior_network = DiffusionPriorNetwork(
dim = 512,
depth = 6,
dim_head = 64,
heads = 8
).cuda()
# diffusion prior network, which contains the CLIP and network (with transformer) above
diffusion_prior = DiffusionPrior(
net = prior_network,
image_embed_dim = 512, # this needs to be set
timesteps = 100,
cond_drop_prob = 0.2,
condition_on_text_encodings = False # this probably should be true, but just to get Laion started
).cuda()
# mock data
text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
# precompute the text and image embeddings
# here using the diffusion prior class, but could be done with CLIP alone
clip_image_embeds = torch.randn(4, 512).cuda()
clip_text_embeds = torch.randn(4, 512).cuda()
# feed text and images into diffusion prior network
loss = diffusion_prior(
text_embed = clip_text_embeds,
image_embed = clip_image_embeds
)
loss.backward()
# do the above for many many many steps
# now the diffusion prior can generate image embeddings from the text embeddings
```
## Experimental
### DALL-E2 with Latent Diffusion
This repository decides to take the next step and offer DALL-E2 combined with <a href="https://huggingface.co/spaces/multimodalart/latentdiffusion">latent diffusion</a>, from Rombach et al.
You can use it as follows. Latent diffusion can be limited to just the first U-Net in the cascade, or to any number you wish.
```python
import torch
from dalle2_pytorch import Unet, Decoder, CLIP, VQGanVAE
# trained clip from step 1
clip = CLIP(
dim_text = 512,
dim_image = 512,
dim_latent = 512,
num_text_tokens = 49408,
text_enc_depth = 1,
text_seq_len = 256,
text_heads = 8,
visual_enc_depth = 1,
visual_image_size = 256,
visual_patch_size = 32,
visual_heads = 8
)
# 3 unets for the decoder (a la cascading DDPM)
# first two unets are doing latent diffusion
# vqgan-vae must be trained before hand
vae1 = VQGanVAE(
dim = 32,
image_size = 256,
layers = 3,
layer_mults = (1, 2, 4)
)
vae2 = VQGanVAE(
dim = 32,
image_size = 512,
layers = 3,
layer_mults = (1, 2, 4)
)
unet1 = Unet(
dim = 32,
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
sparse_attn = True,
sparse_attn_window = 2,
dim_mults = (1, 2, 4, 8)
)
unet2 = Unet(
dim = 32,
image_embed_dim = 512,
channels = 3,
dim_mults = (1, 2, 4, 8, 16),
cond_on_image_embeds = True,
cond_on_text_encodings = False
)
unet3 = Unet(
dim = 32,
image_embed_dim = 512,
channels = 3,
dim_mults = (1, 2, 4, 8, 16),
cond_on_image_embeds = True,
cond_on_text_encodings = False,
attend_at_middle = False
)
# decoder, which contains the unet(s) and clip
decoder = Decoder(
clip = clip,
vae = (vae1, vae2), # latent diffusion for unet1 (vae1) and unet2 (vae2), but not for the last unet3
unet = (unet1, unet2, unet3), # insert unets in order of low resolution to highest resolution (you can have as many stages as you want here)
image_sizes = (256, 512, 1024), # resolutions, 256 for first unet, 512 for second, 1024 for third
timesteps = 100,
cond_drop_prob = 0.2
).cuda()
# mock images (get a lot of this)
images = torch.randn(1, 3, 1024, 1024).cuda()
# feed images into decoder, specifying which unet you want to train
# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme
with decoder.one_unet_in_gpu(1):
loss = decoder(images, unet_number = 1)
loss.backward()
with decoder.one_unet_in_gpu(2):
loss = decoder(images, unet_number = 2)
loss.backward()
with decoder.one_unet_in_gpu(3):
loss = decoder(images, unet_number = 3)
loss.backward()
# do the above for many steps for both unets
# then it will learn to generate images based on the CLIP image embeddings
# chaining the unets from lowest resolution to highest resolution (thus cascading)
mock_image_embed = torch.randn(1, 512).cuda()
images = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)
```
## Training wrapper (wip)
Offer training wrappers
## CLI (wip)
```bash
$ dream 'sharing a sunset at the summit of mount everest with my dog'
@@ -393,9 +623,7 @@ $ dream 'sharing a sunset at the summit of mount everest with my dog'
Once built, images will be saved to the same directory the command is invoked
## Training wrapper (wip)
Offer training wrappers
<a href="https://github.com/lucidrains/big-sleep">template</a>
## Training CLI (wip)
@@ -410,11 +638,19 @@ Offer training wrappers
- [x] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)
- [x] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions
- [x] add efficient attention in unet
- [ ] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning)
- [ ] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
- [ ] build out latent diffusion architecture in separate file, as it is not faithful to dalle-2 (but offer it as as setting)
- [x] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning)
- [x] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
- [x] build out latent diffusion architecture, with the vq-reg variant (vqgan-vae), make it completely optional and compatible with cascading ddpms
- [x] for decoder, allow ability to customize objective (predict epsilon vs x0), in case latent diffusion does better with prediction of x0
- [x] use attention-based upsampling https://arxiv.org/abs/2112.11435
- [ ] abstract interface for CLIP adapter class, so other CLIPs can be brought in - use inheritance just this once for sharing logic between decoder and prior network ddpms
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
- [ ] train on a toy task, offer in colab
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
- [ ] bring in tools to train vqgan-vae
- [ ] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
## Citations
@@ -446,23 +682,12 @@ Offer training wrappers
```bibtex
@inproceedings{Liu2022ACF,
title = {A ConvNet for the 2020s},
title = {A ConvNet for the 2020https://arxiv.org/abs/2112.11435s},
author = {Zhuang Liu and Hanzi Mao and Chaozheng Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie},
year = {2022}
}
```
```bibtex
@misc{zhang2019root,
title = {Root Mean Square Layer Normalization},
author = {Biao Zhang and Rico Sennrich},
year = {2019},
eprint = {1910.07467},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
```
```bibtex
@inproceedings{Tu2022MaxViTMV,
title = {MaxViT: Multi-Axis Vision Transformer},

View File

@@ -1,2 +1,4 @@
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
from dalle2_pytorch.vqgan_vae import VQGanVAE
from x_clip import CLIP

125
dalle2_pytorch/attention.py Normal file
View File

@@ -0,0 +1,125 @@
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
class LayerNormChan(nn.Module):
def __init__(
self,
dim,
eps = 1e-5
):
super().__init__()
self.eps = eps
self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))
def forward(self, x):
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.gamma
# attention-based upsampling
# from https://arxiv.org/abs/2112.11435
class QueryAndAttend(nn.Module):
def __init__(
self,
*,
dim,
num_queries = 1,
dim_head = 32,
heads = 8,
window_size = 3
):
super().__init__()
self.scale = dim_head ** -0.5
inner_dim = dim_head * heads
self.heads = heads
self.dim_head = dim_head
self.window_size = window_size
self.num_queries = num_queries
self.rel_pos_bias = nn.Parameter(torch.randn(heads, num_queries, window_size * window_size, 1, 1))
self.queries = nn.Parameter(torch.randn(heads, num_queries, dim_head))
self.to_kv = nn.Conv2d(dim, dim_head * 2, 1, bias = False)
self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)
def forward(self, x):
"""
einstein notation
b - batch
h - heads
l - num queries
d - head dimension
x - height
y - width
j - source sequence for attending to (kernel size squared in this case)
"""
wsz, heads, dim_head, num_queries = self.window_size, self.heads, self.dim_head, self.num_queries
batch, _, height, width = x.shape
is_one_query = self.num_queries == 1
# queries, keys, values
q = self.queries * self.scale
k, v = self.to_kv(x).chunk(2, dim = 1)
# similarities
sim = einsum('h l d, b d x y -> b h l x y', q, k)
sim = rearrange(sim, 'b ... x y -> b (...) x y')
# unfold the similarity scores, with float(-inf) as padding value
mask_value = -torch.finfo(sim.dtype).max
sim = F.pad(sim, ((wsz // 2,) * 4), value = mask_value)
sim = F.unfold(sim, kernel_size = wsz)
sim = rearrange(sim, 'b (h l j) (x y) -> b h l j x y', h = heads, l = num_queries, x = height, y = width)
# rel pos bias
sim = sim + self.rel_pos_bias
# numerically stable attention
sim = sim - sim.amax(dim = -3, keepdim = True).detach()
attn = sim.softmax(dim = -3)
# unfold values
v = F.pad(v, ((wsz // 2,) * 4), value = 0.)
v = F.unfold(v, kernel_size = wsz)
v = rearrange(v, 'b (d j) (x y) -> b d j x y', d = dim_head, x = height, y = width)
# aggregate values
out = einsum('b h l j x y, b d j x y -> b l h d x y', attn, v)
# combine heads
out = rearrange(out, 'b l h d x y -> (b l) (h d) x y')
out = self.to_out(out)
out = rearrange(out, '(b l) d x y -> b l d x y', b = batch)
# return original input if one query
if is_one_query:
out = rearrange(out, 'b 1 ... -> b ...')
return out
class QueryAttnUpsample(nn.Module):
def __init__(self, dim, **kwargs):
super().__init__()
self.norm = LayerNormChan(dim)
self.qna = QueryAndAttend(dim = dim, num_queries = 4, **kwargs)
def forward(self, x):
x = self.norm(x)
out = self.qna(x)
out = rearrange(out, 'b (w1 w2) c h w -> b c (h w1) (w w2)', w1 = 2, w2 = 2)
return out

View File

@@ -1,9 +1,51 @@
import click
import torch
import torchvision.transforms as T
from pathlib import Path
from dalle2_pytorch import DALLE2, Decoder, DiffusionPrior
def safeget(dictionary, keys, default = None):
return reduce(lambda d, key: d.get(key, default) if isinstance(d, dict) else default, keys.split('.'), dictionary)
def simple_slugify(text, max_length = 255):
return text.replace("-", "_").replace(",", "").replace(" ", "_").replace("|", "--").strip('-_')[:max_length]
def get_pkg_version():
from pkg_resources import get_distribution
return get_distribution('dalle2_pytorch').version
def main():
pass
@click.command()
@click.option('--model', default = './dalle2.pt', help = 'path to trained DALL-E2 model')
@click.option('--cond_scale', default = 2, help = 'conditioning scale (classifier free guidance) in decoder')
@click.argument('text')
def dream(text):
return image
def dream(
model,
cond_scale,
text
):
model_path = Path(model)
full_model_path = str(model_path.resolve())
assert model_path.exists(), f'model not found at {full_model_path}'
loaded = torch.load(str(model_path))
version = safeget(loaded, 'version')
print(f'loading DALL-E2 from {full_model_path}, saved at version {version} - current package version is {get_pkg_version()}')
prior_init_params = safeget(loaded, 'init_params.prior')
decoder_init_params = safeget(loaded, 'init_params.decoder')
model_params = safeget(loaded, 'model_params')
prior = DiffusionPrior(**prior_init_params)
decoder = Decoder(**decoder_init_params)
dalle2 = DALLE2(prior, decoder)
dalle2.load_state_dict(model_params)
image = dalle2(text, cond_scale = cond_scale)
pil_image = T.ToPILImage()(image)
return pil_image.save(f'./{simple_slugify(text)}.png')

View File

@@ -2,6 +2,7 @@ import math
from tqdm import tqdm
from inspect import isfunction
from functools import partial
from contextlib import contextmanager
import torch
import torch.nn.functional as F
@@ -15,6 +16,8 @@ from einops_exts.torch import EinopsToAndFrom
from kornia.filters import gaussian_blur2d
from dalle2_pytorch.tokenizer import tokenizer
from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE
from dalle2_pytorch.attention import QueryAttnUpsample
# use x-clip
@@ -47,6 +50,12 @@ def is_list_str(x):
return False
return all([type(el) == str for el in x])
def pad_tuple_to_length(t, length, fillvalue = None):
remain_length = length - len(t)
if remain_length <= 0:
return t
return (*t, *((fillvalue,) * remain_length))
# for controlling freezing of CLIP
def set_module_requires_grad_(module, requires_grad):
@@ -136,23 +145,27 @@ def sigmoid_beta_schedule(timesteps):
# diffusion prior
class RMSNorm(nn.Module):
class LayerNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.gamma = nn.Parameter(torch.ones(dim))
self.register_buffer("beta", torch.zeros(dim))
def forward(self, x):
return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
class ChanLayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.eps = eps
self.scale = dim ** 0.5
self.gamma = nn.Parameter(torch.ones(dim))
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
def forward(self, x):
squared_sum = (x ** 2).sum(dim = -1, keepdim = True)
inv_norm = torch.rsqrt(squared_sum + self.eps)
return x * inv_norm * self.gamma * self.scale
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g
class ChanRMSNorm(RMSNorm):
def forward(self, x):
squared_sum = (x ** 2).sum(dim = 1, keepdim = True)
inv_norm = torch.rsqrt(squared_sum + self.eps)
return x * inv_norm * rearrange(self.gamma, 'c -> 1 c 1 1') * self.scale
class Residual(nn.Module):
def __init__(self, fn):
@@ -248,10 +261,10 @@ def FeedForward(dim, mult = 4, dropout = 0., post_activation_norm = False):
inner_dim = int(mult * dim)
return nn.Sequential(
RMSNorm(dim),
LayerNorm(dim),
nn.Linear(dim, inner_dim * 2, bias = False),
SwiGLU(),
RMSNorm(inner_dim) if post_activation_norm else nn.Identity(),
LayerNorm(inner_dim) if post_activation_norm else nn.Identity(),
nn.Dropout(dropout),
nn.Linear(inner_dim, dim, bias = False)
)
@@ -274,7 +287,8 @@ class Attention(nn.Module):
inner_dim = dim_head * heads
self.causal = causal
self.norm = RMSNorm(dim)
self.norm = LayerNorm(dim)
self.post_norm = LayerNorm(dim) # sandwich norm from Coqview paper + Normformer
self.dropout = nn.Dropout(dropout)
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
@@ -330,7 +344,8 @@ class Attention(nn.Module):
out = einsum('b h i j, b j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
out = self.to_out(out)
return self.post_norm(out)
class CausalTransformer(nn.Module):
def __init__(
@@ -343,7 +358,8 @@ class CausalTransformer(nn.Module):
ff_mult = 4,
norm_out = False,
attn_dropout = 0.,
ff_dropout = 0.
ff_dropout = 0.,
final_proj = True
):
super().__init__()
self.rel_pos_bias = RelPosBias(heads = heads)
@@ -355,7 +371,8 @@ class CausalTransformer(nn.Module):
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
]))
self.norm = RMSNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity()
def forward(
self,
@@ -370,7 +387,8 @@ class CausalTransformer(nn.Module):
x = attn(x, mask = mask, attn_bias = attn_bias) + x
x = ff(x) + x
return self.norm(x)
out = self.norm(x)
return self.project_out(out)
class DiffusionPriorNetwork(nn.Module):
def __init__(
@@ -403,25 +421,41 @@ class DiffusionPriorNetwork(nn.Module):
image_embed,
diffusion_timesteps,
*,
text_encodings,
text_embed,
text_encodings = None,
mask = None,
cond_drop_prob = 0.2
):
batch, text_enc_len, device = image_embed.shape[0], text_encodings.shape[-2], image_embed.device
batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
# in section 2.2, last paragraph
# "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
text_embed, image_embed = rearrange_many((text_embed, image_embed), 'b d -> b 1 d')
# make text encodings optional
# although the paper seems to suggest it is present <--
if not exists(text_encodings):
text_encodings = torch.empty((batch, 0, dim), device = device, dtype = dtype)
if not exists(mask):
mask = torch.ones((batch, text_encodings.shape[-2]), device = device, dtype = torch.bool)
# classifier free guidance
cond_prob_mask = prob_mask_like((batch,), cond_drop_prob, device = device)
cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1')
mask &= cond_prob_mask
# whether text embedding is masked or not depends on the classifier free guidance conditional masking
mask = torch.cat((mask, cond_prob_mask), dim = 1)
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
# but let's just do it right
if exists(mask):
not_all_masked_out = mask.any(dim = -1)
mask = torch.cat((mask, rearrange(not_all_masked_out, 'b -> b 1')), dim = 1)
if exists(mask):
mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
@@ -437,16 +471,6 @@ class DiffusionPriorNetwork(nn.Module):
learned_queries
), dim = -2)
# mask if it doesn't exist
if not exists(mask):
mask = torch.ones((batch, text_enc_len), device = device, dtype = torch.bool)
# classifier free guidance
cond_prob_mask = prob_mask_like((batch,), cond_drop_prob, device = device)
mask &= rearrange(cond_prob_mask, 'b -> b 1')
# attend
tokens = self.causal_transformer(tokens, mask = mask)
@@ -462,25 +486,35 @@ class DiffusionPrior(nn.Module):
self,
net,
*,
clip,
timesteps=1000,
cond_drop_prob=0.2,
loss_type="l1",
predict_x0=True,
beta_schedule="cosine",
clip = None,
image_embed_dim = None,
image_size = None,
image_channels = 3,
timesteps = 1000,
cond_drop_prob = 0.2,
loss_type = "l1",
predict_x_start = True,
beta_schedule = "cosine",
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
):
super().__init__()
assert isinstance(clip, CLIP)
freeze_model_and_make_eval_(clip)
self.clip = clip
if exists(clip):
assert isinstance(clip, CLIP)
freeze_model_and_make_eval_(clip)
self.clip = clip
else:
assert exists(image_embed_dim), 'latent dimension must be given, if training prior network without CLIP given'
self.clip = None
self.net = net
self.image_embed_dim = clip.dim_latent
self.channels = clip.image_channels
self.image_size = clip.image_size
self.cond_drop_prob = cond_drop_prob
self.image_embed_dim = default(image_embed_dim, lambda: clip.dim_latent)
self.channels = default(image_channels, lambda: clip.image_channels)
self.predict_x0 = predict_x0
self.cond_drop_prob = cond_drop_prob
self.condition_on_text_encodings = condition_on_text_encodings
self.predict_x_start = predict_x_start
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
if beta_schedule == "cosine":
@@ -497,7 +531,7 @@ class DiffusionPrior(nn.Module):
raise NotImplementedError()
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod = torch.cumprod(alphas, axis = 0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
timesteps, = betas.shape
@@ -530,17 +564,27 @@ class DiffusionPrior(nn.Module):
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
@torch.no_grad()
def get_image_embed(self, image):
assert exists(self.clip)
image_encoding = self.clip.visual_transformer(image)
image_cls = image_encoding[:, 0]
image_embed = self.clip.to_visual_latent(image_cls)
return l2norm(image_embed)
@torch.no_grad()
def get_text_cond(self, text):
assert exists(self.clip)
text_encodings = self.clip.text_transformer(text)
text_cls, text_encodings = text_encodings[:, 0], text_encodings[:, 1:]
text_embed = self.clip.to_text_latent(text_cls)
text_embed = l2norm(text_embed)
if not self.condition_on_text_encodings:
return dict(text_embed = text_embed)
return dict(text_encodings = text_encodings, text_embed = text_embed, mask = text != 0)
def q_mean_variance(self, x_start, t):
@@ -565,14 +609,16 @@ class DiffusionPrior(nn.Module):
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
if self.predict_x0:
x_recon = self.net(x, t, **text_cond)
pred = self.net(x, t, **text_cond)
if self.predict_x_start:
x_recon = pred
# not 100% sure of this above line - for any spectators, let me know in the github issues (or through a pull request) if you know how to correctly do this
# i'll be rereading https://arxiv.org/abs/2111.14822, where i think a similar approach is taken
else:
x_recon = self.predict_start_from_noise(x, t = t, noise = self.net(x, t, **text_cond))
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
if clip_denoised:
if clip_denoised and not self.predict_x_start:
x_recon.clamp_(-1., 1.)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
@@ -618,7 +664,7 @@ class DiffusionPrior(nn.Module):
**text_cond
)
to_predict = noise if not self.predict_x0 else image_embed
to_predict = noise if not self.predict_x_start else image_embed
if self.loss_type == 'l1':
loss = F.l1_loss(to_predict, x_recon)
@@ -657,13 +703,41 @@ class DiffusionPrior(nn.Module):
top_image_embeds = image_embeds.gather(1, top_sim_indices)
return rearrange(top_image_embeds, 'b 1 d -> b d')
def forward(self, text, image, *args, **kwargs):
b, device, img_size, = image.shape[0], image.device, self.image_size
check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)
def forward(
self,
text = None,
image = None,
text_embed = None, # allow for training on preprocessed CLIP text and image embeddings
image_embed = None,
text_encodings = None, # as well as CLIP text encodings
text_mask = None, # text mask <- may eventually opt for the learned padding tokens technique from DALL-E1 to reduce complexity
*args,
**kwargs
):
assert exists(text) ^ exists(text_embed), 'either text or text embedding must be supplied'
assert exists(image) ^ exists(image_embed), 'either text or text embedding must be supplied'
assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization'
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
image_embed = self.get_image_embed(image)
text_cond = self.get_text_cond(text)
if exists(image):
image_embed = self.get_image_embed(image)
# calculate text conditionings, based on what is passed in
if exists(text):
text_cond = self.get_text_cond(text)
else:
text_cond = dict(
text_embed = text_embed,
text_encodings = text_encodings,
mask = text_mask
)
# timestep conditioning from ddpm
batch, device = image_embed.shape[0], image_embed.device
times = torch.randint(0, self.num_timesteps, (batch,), device = device, dtype = torch.long)
# calculate forward loss
loss = self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)
return loss
@@ -719,7 +793,7 @@ class ConvNextBlock(nn.Module):
inner_dim = int(dim_out * mult)
self.net = nn.Sequential(
ChanRMSNorm(dim) if norm else nn.Identity(),
ChanLayerNorm(dim) if norm else nn.Identity(),
nn.Conv2d(dim, inner_dim, 3, padding = 1),
nn.GELU(),
nn.Conv2d(inner_dim, dim_out, 3, padding = 1)
@@ -755,8 +829,8 @@ class CrossAttention(nn.Module):
context_dim = default(context_dim, dim)
self.norm = RMSNorm(dim)
self.norm_context = RMSNorm(context_dim)
self.norm = LayerNorm(dim)
self.norm_context = LayerNorm(context_dim)
self.dropout = nn.Dropout(dropout)
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
@@ -820,16 +894,19 @@ class Unet(nn.Module):
image_embed_dim,
cond_dim = None,
num_image_tokens = 4,
num_time_tokens = 2,
out_dim = None,
dim_mults=(1, 2, 4, 8),
channels = 3,
attn_dim_head = 32,
attn_heads = 8,
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
lowres_cond_upsample_mode = 'bilinear',
blur_sigma = 0.1,
blur_kernel_size = 3,
sparse_attn = False,
sparse_attn_window = 8, # window size for sparse attention
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
cond_on_text_encodings = False,
max_text_len = 256,
cond_on_image_embeds = False,
):
super().__init__()
# save locals to take care of some hyperparameters for cascading DDPM
@@ -841,9 +918,6 @@ class Unet(nn.Module):
# for eventual cascading diffusion
self.lowres_cond = lowres_cond
self.lowres_cond_upsample_mode = lowres_cond_upsample_mode
self.lowres_blur_kernel_size = blur_kernel_size
self.lowres_blur_sigma = blur_sigma
# determine dimensions
@@ -862,8 +936,8 @@ class Unet(nn.Module):
SinusoidalPosEmb(dim),
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Linear(dim * 4, cond_dim),
Rearrange('b d -> b 1 d')
nn.Linear(dim * 4, cond_dim * num_time_tokens),
Rearrange('b (r d) -> b r d', r = num_time_tokens)
)
self.image_to_cond = nn.Sequential(
@@ -873,10 +947,20 @@ class Unet(nn.Module):
self.text_to_cond = nn.LazyLinear(cond_dim)
# finer control over whether to condition on image embeddings and text encodings
# so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
self.cond_on_text_encodings = cond_on_text_encodings
self.cond_on_image_embeds = cond_on_image_embeds
# for classifier free guidance
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
self.null_text_embed = nn.Parameter(torch.randn(1, 1, cond_dim))
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
# attention related params
attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
# layers
@@ -891,7 +975,7 @@ class Unet(nn.Module):
self.downs.append(nn.ModuleList([
ConvNextBlock(dim_in, dim_out, norm = ind != 0),
Residual(GridAttention(dim_out, window_size = sparse_attn_window)) if sparse_attn else nn.Identity(),
Residual(GridAttention(dim_out, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim),
Downsample(dim_out) if not is_last else nn.Identity()
]))
@@ -899,7 +983,7 @@ class Unet(nn.Module):
mid_dim = dims[-1]
self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim))) if attend_at_middle else None
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
@@ -908,7 +992,7 @@ class Unet(nn.Module):
self.ups.append(nn.ModuleList([
ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim),
Residual(GridAttention(dim_in, window_size = sparse_attn_window)) if sparse_attn else nn.Identity(),
Residual(GridAttention(dim_in, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim),
Upsample(dim_in)
]))
@@ -921,11 +1005,16 @@ class Unet(nn.Module):
# if the current settings for the unet are not correct
# for cascading DDPM, then reinit the unet with the right settings
def force_lowres_cond(self, lowres_cond):
if lowres_cond == self.lowres_cond:
def cast_model_parameters(
self,
*,
lowres_cond,
channels
):
if lowres_cond == self.lowres_cond and channels == self.channels:
return self
updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond}
updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond, 'channels': channels}
return self.__class__(**updated_kwargs)
def forward_with_cond_scale(
@@ -961,13 +1050,6 @@ class Unet(nn.Module):
assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'
if exists(lowres_cond_img):
if self.training:
# when training, blur the low resolution conditional image
blur_sigma = default(blur_sigma, self.lowres_blur_sigma)
blur_kernel_size = default(blur_kernel_size, self.lowres_blur_kernel_size)
lowres_cond_img = gaussian_blur2d(lowres_cond_img, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
lowres_cond_img = resize_image_to(lowres_cond_img, x.shape[-2:], mode = self.lowres_cond_upsample_mode)
x = torch.cat((x, lowres_cond_img), dim = 1)
# time conditioning
@@ -982,32 +1064,40 @@ class Unet(nn.Module):
# mask out image embedding depending on condition dropout
# for classifier free guidance
image_tokens = self.image_to_cond(image_embed)
image_tokens = None
image_tokens = torch.where(
cond_prob_mask,
image_tokens,
self.null_image_embed
)
if self.cond_on_image_embeds:
image_tokens = self.image_to_cond(image_embed)
image_tokens = torch.where(
cond_prob_mask,
image_tokens,
self.null_image_embed
)
# take care of text encodings (optional)
if exists(text_encodings):
text_tokens = None
if exists(text_encodings) and self.cond_on_text_encodings:
text_tokens = self.text_to_cond(text_encodings)
text_tokens = torch.where(
cond_prob_mask,
text_tokens,
self.null_text_embed
self.null_text_embed[:, :text_tokens.shape[1]]
)
# main conditioning tokens (c)
c = torch.cat((time_tokens, image_tokens), dim = -2)
c = time_tokens
if exists(image_tokens):
c = torch.cat((c, image_tokens), dim = -2)
# text and image conditioning tokens (mid_c)
# to save on compute, only do cross attention based conditioning on the inner most layers of the Unet
mid_c = c if not exists(text_encodings) else torch.cat((c, text_tokens), dim = -2)
mid_c = c if not exists(text_tokens) else torch.cat((c, text_tokens), dim = -2)
# go through the layers of the unet, down and up
@@ -1036,17 +1126,63 @@ class Unet(nn.Module):
return self.final_conv(x)
class LowresConditioner(nn.Module):
def __init__(
self,
cond_upsample_mode = 'bilinear',
downsample_first = True,
blur_sigma = 0.1,
blur_kernel_size = 3,
):
super().__init__()
self.cond_upsample_mode = cond_upsample_mode
self.downsample_first = downsample_first
self.blur_sigma = blur_sigma
self.blur_kernel_size = blur_kernel_size
def forward(
self,
cond_fmap,
*,
target_image_size,
downsample_image_size = None,
blur_sigma = None,
blur_kernel_size = None
):
target_image_size = cast_tuple(target_image_size, 2)
if self.training and self.downsample_first and exists(downsample_image_size):
cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode)
if self.training:
# when training, blur the low resolution conditional image
blur_sigma = default(blur_sigma, self.blur_sigma)
blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)
cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode)
return cond_fmap
class Decoder(nn.Module):
def __init__(
self,
unet,
*,
clip,
vae = tuple(),
timesteps = 1000,
cond_drop_prob = 0.2,
loss_type = 'l1',
beta_schedule = 'cosine',
image_sizes = None # for cascading ddpm, image size at each stage
predict_x_start = False,
predict_x_start_for_latent_diffusion = False,
image_sizes = None, # for cascading ddpm, image size at each stage
lowres_cond_upsample_mode = 'bilinear', # cascading ddpm - low resolution upsample mode
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
blur_sigma = 0.1, # cascading ddpm - blur sigma
blur_kernel_size = 3, # cascading ddpm - blur kernel size
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
):
super().__init__()
assert isinstance(clip, CLIP)
@@ -1055,14 +1191,33 @@ class Decoder(nn.Module):
self.clip_image_size = clip.image_size
self.channels = clip.image_channels
self.condition_on_text_encodings = condition_on_text_encodings
# automatically take care of ensuring that first unet is unconditional
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
unets = cast_tuple(unet)
vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))
self.unets = nn.ModuleList([])
for ind, one_unet in enumerate(cast_tuple(unet)):
self.vaes = nn.ModuleList([])
for ind, (one_unet, one_vae) in enumerate(zip(unets, vaes)):
assert isinstance(one_unet, Unet)
assert isinstance(one_vae, (VQGanVAE, NullVQGanVAE))
is_first = ind == 0
one_unet = one_unet.force_lowres_cond(not is_first)
latent_dim = one_vae.encoded_dim if exists(one_vae) else None
unet_channels = default(latent_dim, self.channels)
one_unet = one_unet.cast_model_parameters(
lowres_cond = not is_first,
channels = unet_channels
)
self.unets.append(one_unet)
self.vaes.append(one_vae.copy_for_eval())
# unet image sizes
@@ -1071,12 +1226,30 @@ class Decoder(nn.Module):
assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
self.image_sizes = image_sizes
self.sample_channels = cast_tuple(self.channels, len(image_sizes))
# predict x0 config
self.predict_x_start = cast_tuple(predict_x_start, len(unets)) if not predict_x_start_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes))
# cascading ddpm related stuff
lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
assert lowres_conditions == (False, *((True,) * (len(self.unets) - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
self.to_lowres_cond = LowresConditioner(
cond_upsample_mode = lowres_cond_upsample_mode,
downsample_first = lowres_downsample_first,
blur_sigma = blur_sigma,
blur_kernel_size = blur_kernel_size,
)
# classifier free guidance
self.cond_drop_prob = cond_drop_prob
# noise schedule
if beta_schedule == "cosine":
betas = cosine_beta_schedule(timesteps)
elif beta_schedule == "linear":
@@ -1091,7 +1264,7 @@ class Decoder(nn.Module):
raise NotImplementedError()
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod = torch.cumprod(alphas, axis = 0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
timesteps, = betas.shape
@@ -1124,10 +1297,31 @@ class Decoder(nn.Module):
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
def get_unet(self, unet_number):
assert 0 < unet_number <= len(self.unets)
index = unet_number - 1
return self.unets[index]
@contextmanager
def one_unet_in_gpu(self, unet_number = None, unet = None):
assert exists(unet_number) ^ exists(unet)
if exists(unet_number):
unet = self.get_unet(unet_number)
self.cuda()
self.unets.cpu()
unet.cuda()
yield
unet.cpu()
@torch.no_grad()
def get_text_encodings(self, text):
text_encodings = self.clip.text_transformer(text)
return text_encodings[:, 1:]
@torch.no_grad()
def get_image_embed(self, image):
image = resize_image_to(image, self.clip_image_size)
image_encoding = self.clip.visual_transformer(image)
@@ -1156,34 +1350,48 @@ class Decoder(nn.Module):
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, cond_scale = 1.):
pred_noise = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
x_recon = self.predict_start_from_noise(x, t = t, noise = pred_noise)
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
if clip_denoised:
if predict_x_start:
x_recon = pred
else:
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
if clip_denoised and not predict_x_start:
x_recon.clamp_(-1., 1.)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, unet, x, t, image_embed, text_encodings = None, cond_scale = 1., lowres_cond_img = None, clip_denoised = True, repeat_noise = False):
def p_sample(self, unet, x, t, image_embed, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, clip_denoised = True, repeat_noise = False):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised)
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start)
noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop(self, unet, shape, image_embed, lowres_cond_img = None, text_encodings = None, cond_scale = 1):
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, lowres_cond_img = None, text_encodings = None, cond_scale = 1):
device = self.betas.device
b = shape[0]
img = torch.randn(shape, device = device)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
img = self.p_sample(unet, img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
img = self.p_sample(
unet,
img,
torch.full((b,), i, device = device, dtype = torch.long),
image_embed = image_embed,
text_encodings = text_encodings,
cond_scale = cond_scale,
lowres_cond_img = lowres_cond_img,
predict_x_start = predict_x_start
)
return img
def q_sample(self, x_start, t, noise=None):
@@ -1194,7 +1402,7 @@ class Decoder(nn.Module):
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, noise = None):
def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None):
noise = default(noise, lambda: torch.randn_like(x_start))
x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)
@@ -1208,12 +1416,14 @@ class Decoder(nn.Module):
cond_drop_prob = self.cond_drop_prob
)
target = noise if not predict_x_start else x_start
if self.loss_type == 'l1':
loss = F.l1_loss(noise, x_recon)
loss = F.l1_loss(target, x_recon)
elif self.loss_type == 'l2':
loss = F.mse_loss(noise, x_recon)
loss = F.mse_loss(target, x_recon)
elif self.loss_type == "huber":
loss = F.smooth_l1_loss(noise, x_recon)
loss = F.smooth_l1_loss(target, x_recon)
else:
raise NotImplementedError()
@@ -1223,25 +1433,58 @@ class Decoder(nn.Module):
@eval_decorator
def sample(self, image_embed, text = None, cond_scale = 1.):
batch_size = image_embed.shape[0]
channels = self.channels
text_encodings = self.get_text_encodings(text) if exists(text) else None
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
img = None
for unet, image_size in tqdm(zip(self.unets, self.image_sizes)):
shape = (batch_size, channels, image_size, image_size)
img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img)
for unet, vae, channel, image_size, predict_x_start in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
with self.one_unet_in_gpu(unet = unet):
lowres_cond_img = None
shape = (batch_size, channel, image_size, image_size)
if unet.lowres_cond:
lowres_cond_img = self.to_lowres_cond(img, target_image_size = image_size)
image_size = vae.get_encoded_fmap_size(image_size)
shape = (batch_size, vae.encoded_dim, image_size, image_size)
if exists(lowres_cond_img):
lowres_cond_img = vae.encode(lowres_cond_img)
img = self.p_sample_loop(
unet,
shape,
image_embed = image_embed,
text_encodings = text_encodings,
cond_scale = cond_scale,
predict_x_start = predict_x_start,
lowres_cond_img = lowres_cond_img
)
img = vae.decode(img)
return img
def forward(self, image, text = None, image_embed = None, text_encodings = None, unet_number = None):
def forward(
self,
image,
text = None,
image_embed = None,
text_encodings = None,
unet_number = None
):
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
unet_number = default(unet_number, 1)
assert 1 <= unet_number <= len(self.unets)
unet_index = unet_number - 1
index = unet_number - 1
unet = self.unets[index]
target_image_size = self.image_sizes[index]
unet = self.get_unet(unet_number)
target_image_size = self.image_sizes[unet_index]
vae = self.vaes[unet_index]
predict_x_start = self.predict_x_start[unet_index]
b, c, h, w, device, = *image.shape, image.device
@@ -1255,9 +1498,19 @@ class Decoder(nn.Module):
text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None
lowres_cond_img = image if index > 0 else None
ddpm_image = resize_image_to(image, target_image_size)
return self.p_losses(unet, ddpm_image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img)
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
image = resize_image_to(image, target_image_size)
vae.eval()
with torch.no_grad():
image = vae.encode(image)
if exists(lowres_cond_img):
lowres_cond_img = vae.encode(lowres_cond_img)
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start)
# main class
@@ -1274,7 +1527,9 @@ class DALLE2(nn.Module):
assert isinstance(decoder, Decoder)
self.prior = prior
self.decoder = decoder
self.prior_num_samples = prior_num_samples
self.decoder_need_text_cond = self.decoder.condition_on_text_encodings
@torch.no_grad()
@eval_decorator
@@ -1284,11 +1539,18 @@ class DALLE2(nn.Module):
cond_scale = 1.
):
device = next(self.parameters()).device
one_text = isinstance(text, str) or (not is_list_str(text) and text.shape[0] == 1)
if isinstance(text, str) or is_list_str(text):
text = [text] if not isinstance(text, (list, tuple)) else text
text = tokenizer.tokenize(text).to(device)
image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
images = self.decoder.sample(image_embed, cond_scale = cond_scale)
text_cond = text if self.decoder_need_text_cond else None
images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale)
if one_text:
return images[0]
return images

View File

@@ -1,12 +0,0 @@
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange
class LatentDiffusion(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x

View File

@@ -0,0 +1,53 @@
import copy
import torch
from torch import nn
# exponential moving average wrapper
class EMA(nn.Module):
def __init__(
self,
model,
beta = 0.99,
ema_update_after_step = 1000,
ema_update_every = 10,
):
super().__init__()
self.beta = beta
self.online_model = model
self.ema_model = copy.deepcopy(model)
self.ema_update_after_step = ema_update_after_step # only start EMA after this step number, starting at 0
self.ema_update_every = ema_update_every
self.register_buffer('initted', torch.Tensor([False]))
self.register_buffer('step', torch.tensor([0.]))
def update(self):
self.step += 1
if self.step <= self.ema_update_after_step or (self.step % self.ema_update_every) != 0:
return
if not self.initted:
self.ema_model.state_dict(self.online_model.state_dict())
self.initted.data.copy_(torch.Tensor([True]))
self.update_moving_average(self.ema_model, self.online_model)
def update_moving_average(ma_model, current_model):
def calculate_ema(beta, old, new):
if not exists(old):
return new
return old * beta + (1 - beta) * new
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
old_weight, up_weight = ma_params.data, current_params.data
ma_params.data = calculate_ema(self.beta, old_weight, up_weight)
for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer)
ma_buffer.copy_(new_buffer_value)
def __call__(self, *args, **kwargs):
return self.ema_model(*args, **kwargs)

566
dalle2_pytorch/vqgan_vae.py Normal file
View File

@@ -0,0 +1,566 @@
import copy
import math
from math import sqrt
from functools import partial, wraps
from vector_quantize_pytorch import VectorQuantize as VQ
import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.autograd import grad as torch_grad
import torchvision
from einops import rearrange, reduce, repeat
from dalle2_pytorch.attention import QueryAttnUpsample
# constants
MList = nn.ModuleList
# helper functions
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
# decorators
def eval_decorator(fn):
def inner(model, *args, **kwargs):
was_training = model.training
model.eval()
out = fn(model, *args, **kwargs)
model.train(was_training)
return out
return inner
def remove_vgg(fn):
@wraps(fn)
def inner(self, *args, **kwargs):
has_vgg = hasattr(self, 'vgg')
if has_vgg:
vgg = self.vgg
delattr(self, 'vgg')
out = fn(self, *args, **kwargs)
if has_vgg:
self.vgg = vgg
return out
return inner
# keyword argument helpers
def pick_and_pop(keys, d):
values = list(map(lambda key: d.pop(key), keys))
return dict(zip(keys, values))
def group_dict_by_key(cond, d):
return_val = [dict(),dict()]
for key in d.keys():
match = bool(cond(key))
ind = int(not match)
return_val[ind][key] = d[key]
return (*return_val,)
def string_begins_with(prefix, str):
return str.startswith(prefix)
def group_by_key_prefix(prefix, d):
return group_dict_by_key(partial(string_begins_with, prefix), d)
def groupby_prefix_and_trim(prefix, d):
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs
# tensor helper functions
def log(t, eps = 1e-10):
return torch.log(t + eps)
def gradient_penalty(images, output, weight = 10):
batch_size = images.shape[0]
gradients = torch_grad(outputs = output, inputs = images,
grad_outputs = torch.ones(output.size(), device = images.device),
create_graph = True, retain_graph = True, only_inputs = True)[0]
gradients = rearrange(gradients, 'b ... -> b (...)')
return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean()
def l2norm(t):
return F.normalize(t, dim = -1)
def leaky_relu(p = 0.1):
return nn.LeakyReLU(0.1)
def stable_softmax(t, dim = -1, alpha = 32 ** 2):
t = t / alpha
t = t - torch.amax(t, dim = dim, keepdim = True).detach()
return (t * alpha).softmax(dim = dim)
def safe_div(numer, denom, eps = 1e-8):
return numer / (denom + eps)
# gan losses
def hinge_discr_loss(fake, real):
return (F.relu(1 + fake) + F.relu(1 - real)).mean()
def hinge_gen_loss(fake):
return -fake.mean()
def bce_discr_loss(fake, real):
return (-log(1 - torch.sigmoid(fake)) - log(torch.sigmoid(real))).mean()
def bce_gen_loss(fake):
return -log(torch.sigmoid(fake)).mean()
def grad_layer_wrt_loss(loss, layer):
return torch_grad(
outputs = loss,
inputs = layer,
grad_outputs = torch.ones_like(loss),
retain_graph = True
)[0].detach()
# vqgan vae
class LayerNormChan(nn.Module):
def __init__(
self,
dim,
eps = 1e-5
):
super().__init__()
self.eps = eps
self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))
def forward(self, x):
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.gamma
class Discriminator(nn.Module):
def __init__(
self,
dims,
channels = 3,
groups = 16,
init_kernel_size = 5
):
super().__init__()
dim_pairs = zip(dims[:-1], dims[1:])
self.layers = MList([nn.Sequential(nn.Conv2d(channels, dims[0], init_kernel_size, padding = init_kernel_size // 2), leaky_relu())])
for dim_in, dim_out in dim_pairs:
self.layers.append(nn.Sequential(
nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1),
nn.GroupNorm(groups, dim_out),
leaky_relu()
))
dim = dims[-1]
self.to_logits = nn.Sequential( # return 5 x 5, for PatchGAN-esque training
nn.Conv2d(dim, dim, 1),
leaky_relu(),
nn.Conv2d(dim, 1, 4)
)
def forward(self, x):
for net in self.layers:
x = net(x)
return self.to_logits(x)
class ContinuousPositionBias(nn.Module):
""" from https://arxiv.org/abs/2111.09883 """
def __init__(self, *, dim, heads, layers = 2):
super().__init__()
self.net = MList([])
self.net.append(nn.Sequential(nn.Linear(2, dim), leaky_relu()))
for _ in range(layers - 1):
self.net.append(nn.Sequential(nn.Linear(dim, dim), leaky_relu()))
self.net.append(nn.Linear(dim, heads))
self.register_buffer('rel_pos', None, persistent = False)
def forward(self, x):
n, device = x.shape[-1], x.device
fmap_size = int(sqrt(n))
if not exists(self.rel_pos):
pos = torch.arange(fmap_size, device = device)
grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
grid = rearrange(grid, 'c i j -> (i j) c')
rel_pos = rearrange(grid, 'i c -> i 1 c') - rearrange(grid, 'j c -> 1 j c')
rel_pos = torch.sign(rel_pos) * torch.log(rel_pos.abs() + 1)
self.register_buffer('rel_pos', rel_pos, persistent = False)
rel_pos = self.rel_pos.float()
for layer in self.net:
rel_pos = layer(rel_pos)
bias = rearrange(rel_pos, 'i j h -> h i j')
return x + bias
class GLUResBlock(nn.Module):
def __init__(self, chan, groups = 16):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(chan, chan * 2, 3, padding = 1),
nn.GLU(dim = 1),
nn.GroupNorm(groups, chan),
nn.Conv2d(chan, chan * 2, 3, padding = 1),
nn.GLU(dim = 1),
nn.GroupNorm(groups, chan),
nn.Conv2d(chan, chan, 1)
)
def forward(self, x):
return self.net(x) + x
class ResBlock(nn.Module):
def __init__(self, chan, groups = 16):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(chan, chan, 3, padding = 1),
nn.GroupNorm(groups, chan),
leaky_relu(),
nn.Conv2d(chan, chan, 3, padding = 1),
nn.GroupNorm(groups, chan),
leaky_relu(),
nn.Conv2d(chan, chan, 1)
)
def forward(self, x):
return self.net(x) + x
# vqgan attention layer
class VQGanAttention(nn.Module):
def __init__(
self,
*,
dim,
dim_head = 64,
heads = 8,
dropout = 0.
):
super().__init__()
self.heads = heads
self.scale = dim_head ** -0.5
inner_dim = heads * dim_head
self.dropout = nn.Dropout(dropout)
self.pre_norm = LayerNormChan(dim)
self.cpb = ContinuousPositionBias(dim = dim // 4, heads = heads)
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)
def forward(self, x):
h = self.heads
height, width, residual = *x.shape[-2:], x.clone()
x = self.pre_norm(x)
q, k, v = self.to_qkv(x).chunk(3, dim = 1)
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = h), (q, k, v))
sim = einsum('b h c i, b h c j -> b h i j', q, k) * self.scale
sim = self.cpb(sim)
attn = stable_softmax(sim, dim = -1)
attn = self.dropout(attn)
out = einsum('b h i j, b h c j -> b h c i', attn, v)
out = rearrange(out, 'b h c (x y) -> b (h c) x y', x = height, y = width)
out = self.to_out(out)
return out + residual
class NullVQGanVAE(nn.Module):
def __init__(
self,
*,
channels
):
super().__init__()
self.encoded_dim = channels
self.layers = 0
def get_encoded_fmap_size(self, size):
return size
def copy_for_eval(self):
return self
def encode(self, x):
return x
def decode(self, x):
return x
class VQGanVAE(nn.Module):
def __init__(
self,
*,
dim,
image_size,
channels = 3,
layers = 4,
layer_mults = None,
l2_recon_loss = False,
use_hinge_loss = True,
num_resnet_blocks = 1,
vgg = None,
vq_codebook_size = 512,
vq_decay = 0.8,
vq_commitment_weight = 1.,
vq_kmeans_init = True,
vq_use_cosine_sim = True,
use_attn = True,
attn_dim_head = 64,
attn_heads = 8,
resnet_groups = 16,
attn_dropout = 0.,
first_conv_kernel_size = 5,
use_vgg_and_gan = True,
**kwargs
):
super().__init__()
assert dim % resnet_groups == 0, f'dimension {dim} must be divisible by {resnet_groups} (groups for the groupnorm)'
vq_kwargs, kwargs = groupby_prefix_and_trim('vq_', kwargs)
self.image_size = image_size
self.channels = channels
self.layers = layers
self.fmap_size = image_size // (layers ** 2)
self.codebook_size = vq_codebook_size
self.encoders = MList([])
self.decoders = MList([])
layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'
layer_dims = [dim * mult for mult in layer_mults]
dims = (dim, *layer_dims)
codebook_dim = layer_dims[-1]
self.encoded_dim = dims[-1]
dim_pairs = zip(dims[:-1], dims[1:])
append = lambda arr, t: arr.append(t)
prepend = lambda arr, t: arr.insert(0, t)
if not isinstance(num_resnet_blocks, tuple):
num_resnet_blocks = (*((0,) * (layers - 1)), num_resnet_blocks)
if not isinstance(use_attn, tuple):
use_attn = (*((False,) * (layers - 1)), use_attn)
assert len(num_resnet_blocks) == layers, 'number of resnet blocks config must be equal to number of layers'
assert len(use_attn) == layers
for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_resnet_blocks, use_attn):
append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
prepend(self.decoders, nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu()))
if layer_use_attn:
prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
for _ in range(layer_num_resnet_blocks):
append(self.encoders, ResBlock(dim_out, groups = resnet_groups))
prepend(self.decoders, GLUResBlock(dim_out, groups = resnet_groups))
if layer_use_attn:
append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
append(self.decoders, nn.Conv2d(dim, channels, 1))
self.vq = VQ(
dim = codebook_dim,
codebook_size = vq_codebook_size,
decay = vq_decay,
commitment_weight = vq_commitment_weight,
accept_image_fmap = True,
kmeans_init = vq_kmeans_init,
use_cosine_sim = vq_use_cosine_sim,
**vq_kwargs
)
# reconstruction loss
self.recon_loss_fn = F.mse_loss if l2_recon_loss else F.l1_loss
# turn off GAN and perceptual loss if grayscale
self.vgg = None
self.discr = None
self.use_vgg_and_gan = use_vgg_and_gan
if not use_vgg_and_gan:
return
# preceptual loss
if exists(vgg):
self.vgg = vgg
else:
self.vgg = torchvision.models.vgg16(pretrained = True)
self.vgg.classifier = nn.Sequential(*self.vgg.classifier[:-2])
# gan related losses
self.discr = Discriminator(dims = dims, channels = channels)
self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss
def get_encoded_fmap_size(self, image_size):
return image_size // (2 ** self.layers)
def copy_for_eval(self):
device = next(self.parameters()).device
vae_copy = copy.deepcopy(self.cpu())
if vae_copy.use_vgg_and_gan:
del vae_copy.discr
del vae_copy.vgg
vae_copy.eval()
return vae_copy.to(device)
@remove_vgg
def state_dict(self, *args, **kwargs):
return super().state_dict(*args, **kwargs)
@remove_vgg
def load_state_dict(self, *args, **kwargs):
return super().load_state_dict(*args, **kwargs)
@property
def codebook(self):
return self.vq.codebook
def encode(self, fmap):
for enc in self.encoders:
fmap = enc(fmap)
return fmap
def decode(self, fmap, return_indices_and_loss = False):
fmap, indices, commit_loss = self.vq(fmap)
for dec in self.decoders:
fmap = dec(fmap)
if not return_indices_and_loss:
return fmap
return fmap, indices, commit_loss
def forward(
self,
img,
return_loss = False,
return_discr_loss = False,
return_recons = False,
add_gradient_penalty = True
):
batch, channels, height, width, device = *img.shape, img.device
assert height == self.image_size and width == self.image_size, 'height and width of input image must be equal to {self.image_size}'
assert channels == self.channels, 'number of channels on image or sketch is not equal to the channels set on this VQGanVAE'
fmap = self.encode(img)
fmap, indices, commit_loss = self.decode(fmap, return_indices_and_loss = True)
if not return_loss and not return_discr_loss:
return fmap
assert return_loss ^ return_discr_loss, 'you should either return autoencoder loss or discriminator loss, but not both'
# whether to return discriminator loss
if return_discr_loss:
assert exists(self.discr), 'discriminator must exist to train it'
fmap.detach_()
img.requires_grad_()
fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))
discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits)
if add_gradient_penalty:
gp = gradient_penalty(img, img_discr_logits)
loss = discr_loss + gp
if return_recons:
return loss, fmap
return loss
# reconstruction loss
recon_loss = self.recon_loss_fn(fmap, img)
# early return if training on grayscale
if not self.use_vgg_and_gan:
if return_recons:
return recon_loss, fmap
return recon_loss
# perceptual loss
img_vgg_input = img
fmap_vgg_input = fmap
if img.shape[1] == 1:
# handle grayscale for vgg
img_vgg_input, fmap_vgg_input = map(lambda t: repeat(t, 'b 1 ... -> b c ...', c = 3), (img_vgg_input, fmap_vgg_input))
img_vgg_feats = self.vgg(img_vgg_input)
recon_vgg_feats = self.vgg(fmap_vgg_input)
perceptual_loss = F.mse_loss(img_vgg_feats, recon_vgg_feats)
# generator loss
gen_loss = self.gen_loss(self.discr(fmap))
# calculate adaptive weight
last_dec_layer = self.decoders[-1].weight
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2)
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)
adaptive_weight = safe_div(norm_grad_wrt_perceptual_loss, norm_grad_wrt_gen_loss)
adaptive_weight.clamp_(max = 1e4)
# combine losses
loss = recon_loss + perceptual_loss + commit_loss + adaptive_weight * gen_loss
if return_recons:
return loss, fmap
return loss

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.26',
version = '0.0.50',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
@@ -30,6 +30,7 @@ setup(
'torch>=1.10',
'torchvision',
'tqdm',
'vector-quantize-pytorch',
'x-clip>=0.4.4',
'youtokentome'
],