mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
42 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
76e063e8b7 | ||
|
|
4d25976f33 | ||
|
|
0b28ee0d01 | ||
|
|
45262a4bb7 | ||
|
|
13a58a78c4 | ||
|
|
f75d49c781 | ||
|
|
3b520dfa85 | ||
|
|
79198c6ae4 | ||
|
|
77a246b1b9 | ||
|
|
f93a3f6ed8 | ||
|
|
8f2a0c7e00 | ||
|
|
863f4ef243 | ||
|
|
fb8a66a2de | ||
|
|
579d4b42dd | ||
|
|
473808850a | ||
|
|
d5318aef4f | ||
|
|
f82917e1fd | ||
|
|
05b74be69a | ||
|
|
a8b5d5d753 | ||
|
|
976ef7f87c | ||
|
|
fd175bcc0e | ||
|
|
76b32f18b3 | ||
|
|
f2d5b87677 | ||
|
|
461347c171 | ||
|
|
46cef31c86 | ||
|
|
59b1a77d4d | ||
|
|
7f338319fd | ||
|
|
2c6c91829d | ||
|
|
ad17c69ab6 | ||
|
|
0b4ec34efb | ||
|
|
f027b82e38 | ||
|
|
8cc9016cb0 | ||
|
|
1d8f37befe | ||
|
|
faebf4c8b8 | ||
|
|
b8e8d3c164 | ||
|
|
8e2416b49b | ||
|
|
f37c26e856 | ||
|
|
27a33e1b20 | ||
|
|
6f941a219a | ||
|
|
ddde8ca1bf | ||
|
|
c26b77ad20 | ||
|
|
c5b4aab8e5 |
165
README.md
165
README.md
@@ -1,6 +1,6 @@
|
||||
<img src="./dalle2.png" width="450px"></img>
|
||||
|
||||
## DALL-E 2 - Pytorch (wip)
|
||||
## DALL-E 2 - Pytorch
|
||||
|
||||
Implementation of <a href="https://openai.com/dall-e-2/">DALL-E 2</a>, OpenAI's updated text-to-image synthesis neural network, in Pytorch.
|
||||
|
||||
@@ -10,11 +10,9 @@ The main novelty seems to be an extra layer of indirection with the prior networ
|
||||
|
||||
This model is SOTA for text-to-image for now.
|
||||
|
||||
It may also explore an extension of using <a href="https://huggingface.co/spaces/multimodalart/latentdiffusion">latent diffusion</a> in the decoder from Rombach et al.
|
||||
|
||||
Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication
|
||||
|
||||
There was enough interest for a Jax version. It will be completed after the Pytorch version shows signs of life on my toy tasks. <a href="https://github.com/lucidrains/dalle2-jax">Placeholder repository</a>
|
||||
There was enough interest for a <a href="https://github.com/lucidrains/dalle2-jax">Jax version</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
|
||||
|
||||
## Install
|
||||
|
||||
@@ -248,13 +246,6 @@ loss = decoder(images, unet_number = 2)
|
||||
loss.backward()
|
||||
|
||||
# do the above for many steps for both unets
|
||||
|
||||
# then it will learn to generate images based on the CLIP image embeddings
|
||||
|
||||
# chaining the unets from lowest resolution to highest resolution (thus cascading)
|
||||
|
||||
mock_image_embed = torch.randn(1, 512).cuda()
|
||||
images = decoder.sample(mock_image_embed) # (1, 3, 512, 512)
|
||||
```
|
||||
|
||||
Finally, to generate the DALL-E2 images from text. Insert the trained `DiffusionPrior` as well as the `Decoder` (which wraps `CLIP`, the causal transformer, and unet(s))
|
||||
@@ -385,7 +376,127 @@ You can also train the decoder on images of greater than the size (say 512x512)
|
||||
|
||||
For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.
|
||||
|
||||
## CLI Usage (work in progress)
|
||||
## Experimental
|
||||
|
||||
### DALL-E2 with Latent Diffusion
|
||||
|
||||
This repository decides to take the next step and offer DALL-E2 combined with <a href="https://huggingface.co/spaces/multimodalart/latentdiffusion">latent diffusion</a>, from Rombach et al.
|
||||
|
||||
You can use it as follows. Latent diffusion can be limited to just the first U-Net in the cascade, or to any number you wish.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from dalle2_pytorch import Unet, Decoder, CLIP, VQGanVAE
|
||||
|
||||
# trained clip from step 1
|
||||
|
||||
clip = CLIP(
|
||||
dim_text = 512,
|
||||
dim_image = 512,
|
||||
dim_latent = 512,
|
||||
num_text_tokens = 49408,
|
||||
text_enc_depth = 1,
|
||||
text_seq_len = 256,
|
||||
text_heads = 8,
|
||||
visual_enc_depth = 1,
|
||||
visual_image_size = 256,
|
||||
visual_patch_size = 32,
|
||||
visual_heads = 8
|
||||
)
|
||||
|
||||
# 3 unets for the decoder (a la cascading DDPM)
|
||||
|
||||
# first two unets are doing latent diffusion
|
||||
# vqgan-vae must be trained before hand
|
||||
|
||||
vae1 = VQGanVAE(
|
||||
dim = 32,
|
||||
image_size = 256,
|
||||
layers = 3,
|
||||
layer_mults = (1, 2, 4)
|
||||
)
|
||||
|
||||
vae2 = VQGanVAE(
|
||||
dim = 32,
|
||||
image_size = 512,
|
||||
layers = 3,
|
||||
layer_mults = (1, 2, 4)
|
||||
)
|
||||
|
||||
unet1 = Unet(
|
||||
dim = 32,
|
||||
image_embed_dim = 512,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
sparse_attn = True,
|
||||
sparse_attn_window = 2,
|
||||
dim_mults = (1, 2, 4, 8)
|
||||
)
|
||||
|
||||
unet2 = Unet(
|
||||
dim = 32,
|
||||
image_embed_dim = 512,
|
||||
channels = 3,
|
||||
dim_mults = (1, 2, 4, 8, 16),
|
||||
cond_on_image_embeds = True,
|
||||
cond_on_text_encodings = False
|
||||
)
|
||||
|
||||
unet3 = Unet(
|
||||
dim = 32,
|
||||
image_embed_dim = 512,
|
||||
channels = 3,
|
||||
dim_mults = (1, 2, 4, 8, 16),
|
||||
cond_on_image_embeds = True,
|
||||
cond_on_text_encodings = False,
|
||||
attend_at_middle = False
|
||||
)
|
||||
|
||||
# decoder, which contains the unet(s) and clip
|
||||
|
||||
decoder = Decoder(
|
||||
clip = clip,
|
||||
vae = (vae1, vae2), # latent diffusion for unet1 (vae1) and unet2 (vae2), but not for the last unet3
|
||||
unet = (unet1, unet2, unet3), # insert unets in order of low resolution to highest resolution (you can have as many stages as you want here)
|
||||
image_sizes = (256, 512, 1024), # resolutions, 256 for first unet, 512 for second, 1024 for third
|
||||
timesteps = 100,
|
||||
cond_drop_prob = 0.2
|
||||
).cuda()
|
||||
|
||||
# mock images (get a lot of this)
|
||||
|
||||
images = torch.randn(1, 3, 1024, 1024).cuda()
|
||||
|
||||
# feed images into decoder, specifying which unet you want to train
|
||||
# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme
|
||||
|
||||
with decoder.one_unet_in_gpu(1):
|
||||
loss = decoder(images, unet_number = 1)
|
||||
loss.backward()
|
||||
|
||||
with decoder.one_unet_in_gpu(2):
|
||||
loss = decoder(images, unet_number = 2)
|
||||
loss.backward()
|
||||
|
||||
with decoder.one_unet_in_gpu(3):
|
||||
loss = decoder(images, unet_number = 3)
|
||||
loss.backward()
|
||||
|
||||
# do the above for many steps for both unets
|
||||
|
||||
# then it will learn to generate images based on the CLIP image embeddings
|
||||
|
||||
# chaining the unets from lowest resolution to highest resolution (thus cascading)
|
||||
|
||||
mock_image_embed = torch.randn(1, 512).cuda()
|
||||
images = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)
|
||||
```
|
||||
|
||||
## Training wrapper (wip)
|
||||
|
||||
Offer training wrappers
|
||||
|
||||
## CLI (wip)
|
||||
|
||||
```bash
|
||||
$ dream 'sharing a sunset at the summit of mount everest with my dog'
|
||||
@@ -393,9 +504,7 @@ $ dream 'sharing a sunset at the summit of mount everest with my dog'
|
||||
|
||||
Once built, images will be saved to the same directory the command is invoked
|
||||
|
||||
## Training wrapper (wip)
|
||||
|
||||
Offer training wrappers
|
||||
<a href="https://github.com/lucidrains/big-sleep">template</a>
|
||||
|
||||
## Training CLI (wip)
|
||||
|
||||
@@ -410,10 +519,19 @@ Offer training wrappers
|
||||
- [x] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)
|
||||
- [x] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions
|
||||
- [x] add efficient attention in unet
|
||||
- [ ] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
|
||||
- [ ] build out latent diffusion architecture in separate file, as it is not faithful to dalle-2 (but offer it as as setting)
|
||||
- [x] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning)
|
||||
- [x] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
|
||||
- [x] build out latent diffusion architecture, with the vq-reg variant (vqgan-vae), make it completely optional and compatible with cascading ddpms
|
||||
- [x] for decoder, allow ability to customize objective (predict epsilon vs x0), in case latent diffusion does better with prediction of x0
|
||||
- [x] use attention-based upsampling https://arxiv.org/abs/2112.11435
|
||||
- [ ] spend one day cleaning up tech debt in decoder
|
||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
|
||||
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
|
||||
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
|
||||
- [ ] train on a toy task, offer in colab
|
||||
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
|
||||
- [ ] bring in tools to train vqgan-vae
|
||||
- [ ] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
|
||||
|
||||
## Citations
|
||||
|
||||
@@ -445,23 +563,12 @@ Offer training wrappers
|
||||
|
||||
```bibtex
|
||||
@inproceedings{Liu2022ACF,
|
||||
title = {A ConvNet for the 2020s},
|
||||
title = {A ConvNet for the 2020https://arxiv.org/abs/2112.11435s},
|
||||
author = {Zhuang Liu and Hanzi Mao and Chaozheng Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie},
|
||||
year = {2022}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{zhang2019root,
|
||||
title = {Root Mean Square Layer Normalization},
|
||||
author = {Biao Zhang and Rico Sennrich},
|
||||
year = {2019},
|
||||
eprint = {1910.07467},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.LG}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@inproceedings{Tu2022MaxViTMV,
|
||||
title = {MaxViT: Multi-Axis Vision Transformer},
|
||||
|
||||
@@ -1,2 +1,4 @@
|
||||
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
|
||||
|
||||
from dalle2_pytorch.vqgan_vae import VQGanVAE
|
||||
from x_clip import CLIP
|
||||
|
||||
125
dalle2_pytorch/attention.py
Normal file
125
dalle2_pytorch/attention.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
|
||||
class LayerNormChan(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
eps = 1e-5
|
||||
):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (var + self.eps).sqrt() * self.gamma
|
||||
|
||||
# attention-based upsampling
|
||||
# from https://arxiv.org/abs/2112.11435
|
||||
|
||||
class QueryAndAttend(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dim,
|
||||
num_queries = 1,
|
||||
dim_head = 32,
|
||||
heads = 8,
|
||||
window_size = 3
|
||||
):
|
||||
super().__init__()
|
||||
self.scale = dim_head ** -0.5
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
self.window_size = window_size
|
||||
self.num_queries = num_queries
|
||||
|
||||
self.rel_pos_bias = nn.Parameter(torch.randn(heads, num_queries, window_size * window_size, 1, 1))
|
||||
|
||||
self.queries = nn.Parameter(torch.randn(heads, num_queries, dim_head))
|
||||
self.to_kv = nn.Conv2d(dim, dim_head * 2, 1, bias = False)
|
||||
self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
einstein notation
|
||||
b - batch
|
||||
h - heads
|
||||
l - num queries
|
||||
d - head dimension
|
||||
x - height
|
||||
y - width
|
||||
j - source sequence for attending to (kernel size squared in this case)
|
||||
"""
|
||||
|
||||
wsz, heads, dim_head, num_queries = self.window_size, self.heads, self.dim_head, self.num_queries
|
||||
batch, _, height, width = x.shape
|
||||
|
||||
is_one_query = self.num_queries == 1
|
||||
|
||||
# queries, keys, values
|
||||
|
||||
q = self.queries * self.scale
|
||||
k, v = self.to_kv(x).chunk(2, dim = 1)
|
||||
|
||||
# similarities
|
||||
|
||||
sim = einsum('h l d, b d x y -> b h l x y', q, k)
|
||||
sim = rearrange(sim, 'b ... x y -> b (...) x y')
|
||||
|
||||
# unfold the similarity scores, with float(-inf) as padding value
|
||||
|
||||
mask_value = -torch.finfo(sim.dtype).max
|
||||
sim = F.pad(sim, ((wsz // 2,) * 4), value = mask_value)
|
||||
sim = F.unfold(sim, kernel_size = wsz)
|
||||
sim = rearrange(sim, 'b (h l j) (x y) -> b h l j x y', h = heads, l = num_queries, x = height, y = width)
|
||||
|
||||
# rel pos bias
|
||||
|
||||
sim = sim + self.rel_pos_bias
|
||||
|
||||
# numerically stable attention
|
||||
|
||||
sim = sim - sim.amax(dim = -3, keepdim = True).detach()
|
||||
attn = sim.softmax(dim = -3)
|
||||
|
||||
# unfold values
|
||||
|
||||
v = F.pad(v, ((wsz // 2,) * 4), value = 0.)
|
||||
v = F.unfold(v, kernel_size = wsz)
|
||||
v = rearrange(v, 'b (d j) (x y) -> b d j x y', d = dim_head, x = height, y = width)
|
||||
|
||||
# aggregate values
|
||||
|
||||
out = einsum('b h l j x y, b d j x y -> b l h d x y', attn, v)
|
||||
|
||||
# combine heads
|
||||
|
||||
out = rearrange(out, 'b l h d x y -> (b l) (h d) x y')
|
||||
out = self.to_out(out)
|
||||
out = rearrange(out, '(b l) d x y -> b l d x y', b = batch)
|
||||
|
||||
# return original input if one query
|
||||
|
||||
if is_one_query:
|
||||
out = rearrange(out, 'b 1 ... -> b ...')
|
||||
|
||||
return out
|
||||
|
||||
class QueryAttnUpsample(nn.Module):
|
||||
def __init__(self, dim, **kwargs):
|
||||
super().__init__()
|
||||
self.norm = LayerNormChan(dim)
|
||||
self.qna = QueryAndAttend(dim = dim, num_queries = 4, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
out = self.qna(x)
|
||||
out = rearrange(out, 'b (w1 w2) c h w -> b c (h w1) (w w2)', w1 = 2, w2 = 2)
|
||||
return out
|
||||
@@ -1,9 +1,51 @@
|
||||
import click
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from pathlib import Path
|
||||
|
||||
from dalle2_pytorch import DALLE2, Decoder, DiffusionPrior
|
||||
|
||||
def safeget(dictionary, keys, default = None):
|
||||
return reduce(lambda d, key: d.get(key, default) if isinstance(d, dict) else default, keys.split('.'), dictionary)
|
||||
|
||||
def simple_slugify(text, max_length = 255):
|
||||
return text.replace("-", "_").replace(",", "").replace(" ", "_").replace("|", "--").strip('-_')[:max_length]
|
||||
|
||||
def get_pkg_version():
|
||||
from pkg_resources import get_distribution
|
||||
return get_distribution('dalle2_pytorch').version
|
||||
|
||||
def main():
|
||||
pass
|
||||
|
||||
@click.command()
|
||||
@click.option('--model', default = './dalle2.pt', help = 'path to trained DALL-E2 model')
|
||||
@click.option('--cond_scale', default = 2, help = 'conditioning scale (classifier free guidance) in decoder')
|
||||
@click.argument('text')
|
||||
def dream(text):
|
||||
return image
|
||||
def dream(
|
||||
model,
|
||||
cond_scale,
|
||||
text
|
||||
):
|
||||
model_path = Path(model)
|
||||
full_model_path = str(model_path.resolve())
|
||||
assert model_path.exists(), f'model not found at {full_model_path}'
|
||||
loaded = torch.load(str(model_path))
|
||||
|
||||
version = safeget(loaded, 'version')
|
||||
print(f'loading DALL-E2 from {full_model_path}, saved at version {version} - current package version is {get_pkg_version()}')
|
||||
|
||||
prior_init_params = safeget(loaded, 'init_params.prior')
|
||||
decoder_init_params = safeget(loaded, 'init_params.decoder')
|
||||
model_params = safeget(loaded, 'model_params')
|
||||
|
||||
prior = DiffusionPrior(**prior_init_params)
|
||||
decoder = Decoder(**decoder_init_params)
|
||||
|
||||
dalle2 = DALLE2(prior, decoder)
|
||||
dalle2.load_state_dict(model_params)
|
||||
|
||||
image = dalle2(text, cond_scale = cond_scale)
|
||||
|
||||
pil_image = T.ToPILImage()(image)
|
||||
return pil_image.save(f'./{simple_slugify(text)}.png')
|
||||
|
||||
@@ -2,6 +2,7 @@ import math
|
||||
from tqdm import tqdm
|
||||
from inspect import isfunction
|
||||
from functools import partial
|
||||
from contextlib import contextmanager
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -15,6 +16,8 @@ from einops_exts.torch import EinopsToAndFrom
|
||||
from kornia.filters import gaussian_blur2d
|
||||
|
||||
from dalle2_pytorch.tokenizer import tokenizer
|
||||
from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE
|
||||
from dalle2_pytorch.attention import QueryAttnUpsample
|
||||
|
||||
# use x-clip
|
||||
|
||||
@@ -47,6 +50,12 @@ def is_list_str(x):
|
||||
return False
|
||||
return all([type(el) == str for el in x])
|
||||
|
||||
def pad_tuple_to_length(t, length, fillvalue = None):
|
||||
remain_length = length - len(t)
|
||||
if remain_length <= 0:
|
||||
return t
|
||||
return (*t, *((fillvalue,) * remain_length))
|
||||
|
||||
# for controlling freezing of CLIP
|
||||
|
||||
def set_module_requires_grad_(module, requires_grad):
|
||||
@@ -105,8 +114,8 @@ def cosine_beta_schedule(timesteps, s = 0.008):
|
||||
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
|
||||
"""
|
||||
steps = timesteps + 1
|
||||
x = torch.linspace(0, steps, steps)
|
||||
alphas_cumprod = torch.cos(((x / steps) + s) / (1 + s) * torch.pi * 0.5) ** 2
|
||||
x = torch.linspace(0, timesteps, steps)
|
||||
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
|
||||
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
||||
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
||||
return torch.clip(betas, 0, 0.999)
|
||||
@@ -136,23 +145,27 @@ def sigmoid_beta_schedule(timesteps):
|
||||
|
||||
# diffusion prior
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.gamma = nn.Parameter(torch.ones(dim))
|
||||
self.register_buffer("beta", torch.zeros(dim))
|
||||
|
||||
def forward(self, x):
|
||||
return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
|
||||
|
||||
|
||||
class ChanLayerNorm(nn.Module):
|
||||
def __init__(self, dim, eps = 1e-5):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.scale = dim ** 0.5
|
||||
self.gamma = nn.Parameter(torch.ones(dim))
|
||||
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
squared_sum = (x ** 2).sum(dim = -1, keepdim = True)
|
||||
inv_norm = torch.rsqrt(squared_sum + self.eps)
|
||||
return x * inv_norm * self.gamma * self.scale
|
||||
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (var + self.eps).sqrt() * self.g
|
||||
|
||||
class ChanRMSNorm(RMSNorm):
|
||||
def forward(self, x):
|
||||
squared_sum = (x ** 2).sum(dim = 1, keepdim = True)
|
||||
inv_norm = torch.rsqrt(squared_sum + self.eps)
|
||||
return x * inv_norm * rearrange(self.gamma, 'c -> 1 c 1 1') * self.scale
|
||||
|
||||
class Residual(nn.Module):
|
||||
def __init__(self, fn):
|
||||
@@ -248,10 +261,10 @@ def FeedForward(dim, mult = 4, dropout = 0., post_activation_norm = False):
|
||||
|
||||
inner_dim = int(mult * dim)
|
||||
return nn.Sequential(
|
||||
RMSNorm(dim),
|
||||
LayerNorm(dim),
|
||||
nn.Linear(dim, inner_dim * 2, bias = False),
|
||||
SwiGLU(),
|
||||
RMSNorm(inner_dim) if post_activation_norm else nn.Identity(),
|
||||
LayerNorm(inner_dim) if post_activation_norm else nn.Identity(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(inner_dim, dim, bias = False)
|
||||
)
|
||||
@@ -274,7 +287,8 @@ class Attention(nn.Module):
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
self.causal = causal
|
||||
self.norm = RMSNorm(dim)
|
||||
self.norm = LayerNorm(dim)
|
||||
self.post_norm = LayerNorm(dim) # sandwich norm from Coqview paper + Normformer
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
|
||||
@@ -330,7 +344,8 @@ class Attention(nn.Module):
|
||||
out = einsum('b h i j, b j d -> b h i d', attn, v)
|
||||
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
out = self.to_out(out)
|
||||
return self.post_norm(out)
|
||||
|
||||
class CausalTransformer(nn.Module):
|
||||
def __init__(
|
||||
@@ -343,7 +358,8 @@ class CausalTransformer(nn.Module):
|
||||
ff_mult = 4,
|
||||
norm_out = False,
|
||||
attn_dropout = 0.,
|
||||
ff_dropout = 0.
|
||||
ff_dropout = 0.,
|
||||
final_proj = True
|
||||
):
|
||||
super().__init__()
|
||||
self.rel_pos_bias = RelPosBias(heads = heads)
|
||||
@@ -355,7 +371,8 @@ class CausalTransformer(nn.Module):
|
||||
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
|
||||
]))
|
||||
|
||||
self.norm = RMSNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
|
||||
self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
|
||||
self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -370,7 +387,8 @@ class CausalTransformer(nn.Module):
|
||||
x = attn(x, mask = mask, attn_bias = attn_bias) + x
|
||||
x = ff(x) + x
|
||||
|
||||
return self.norm(x)
|
||||
out = self.norm(x)
|
||||
return self.project_out(out)
|
||||
|
||||
class DiffusionPriorNetwork(nn.Module):
|
||||
def __init__(
|
||||
@@ -403,25 +421,41 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
image_embed,
|
||||
diffusion_timesteps,
|
||||
*,
|
||||
text_encodings,
|
||||
text_embed,
|
||||
text_encodings = None,
|
||||
mask = None,
|
||||
cond_drop_prob = 0.2
|
||||
):
|
||||
batch, text_enc_len, device = image_embed.shape[0], text_encodings.shape[-2], image_embed.device
|
||||
batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
|
||||
|
||||
# in section 2.2, last paragraph
|
||||
# "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
|
||||
|
||||
text_embed, image_embed = rearrange_many((text_embed, image_embed), 'b d -> b 1 d')
|
||||
|
||||
# make text encodings optional
|
||||
# although the paper seems to suggest it is present <--
|
||||
|
||||
if not exists(text_encodings):
|
||||
text_encodings = torch.empty((batch, 0, dim), device = device, dtype = dtype)
|
||||
|
||||
if not exists(mask):
|
||||
mask = torch.ones((batch, text_encodings.shape[-2]), device = device, dtype = torch.bool)
|
||||
|
||||
# classifier free guidance
|
||||
|
||||
cond_prob_mask = prob_mask_like((batch,), cond_drop_prob, device = device)
|
||||
cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1')
|
||||
|
||||
mask &= cond_prob_mask
|
||||
|
||||
# whether text embedding is masked or not depends on the classifier free guidance conditional masking
|
||||
|
||||
mask = torch.cat((mask, cond_prob_mask), dim = 1)
|
||||
|
||||
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
|
||||
# but let's just do it right
|
||||
|
||||
if exists(mask):
|
||||
not_all_masked_out = mask.any(dim = -1)
|
||||
mask = torch.cat((mask, rearrange(not_all_masked_out, 'b -> b 1')), dim = 1)
|
||||
|
||||
if exists(mask):
|
||||
mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
|
||||
|
||||
@@ -437,16 +471,6 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
learned_queries
|
||||
), dim = -2)
|
||||
|
||||
# mask if it doesn't exist
|
||||
|
||||
if not exists(mask):
|
||||
mask = torch.ones((batch, text_enc_len), device = device, dtype = torch.bool)
|
||||
|
||||
# classifier free guidance
|
||||
|
||||
cond_prob_mask = prob_mask_like((batch,), cond_drop_prob, device = device)
|
||||
mask &= rearrange(cond_prob_mask, 'b -> b 1')
|
||||
|
||||
# attend
|
||||
|
||||
tokens = self.causal_transformer(tokens, mask = mask)
|
||||
@@ -463,11 +487,12 @@ class DiffusionPrior(nn.Module):
|
||||
net,
|
||||
*,
|
||||
clip,
|
||||
timesteps=1000,
|
||||
cond_drop_prob=0.2,
|
||||
loss_type="l1",
|
||||
predict_x0=True,
|
||||
beta_schedule="cosine",
|
||||
timesteps = 1000,
|
||||
cond_drop_prob = 0.2,
|
||||
loss_type = "l1",
|
||||
predict_x_start = True,
|
||||
beta_schedule = "cosine",
|
||||
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(clip, CLIP)
|
||||
@@ -478,9 +503,11 @@ class DiffusionPrior(nn.Module):
|
||||
self.image_embed_dim = clip.dim_latent
|
||||
self.channels = clip.image_channels
|
||||
self.image_size = clip.image_size
|
||||
self.cond_drop_prob = cond_drop_prob
|
||||
|
||||
self.predict_x0 = predict_x0
|
||||
self.cond_drop_prob = cond_drop_prob
|
||||
self.condition_on_text_encodings = condition_on_text_encodings
|
||||
|
||||
self.predict_x_start = predict_x_start
|
||||
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
|
||||
|
||||
if beta_schedule == "cosine":
|
||||
@@ -497,7 +524,7 @@ class DiffusionPrior(nn.Module):
|
||||
raise NotImplementedError()
|
||||
|
||||
alphas = 1. - betas
|
||||
alphas_cumprod = torch.cumprod(alphas, axis=0)
|
||||
alphas_cumprod = torch.cumprod(alphas, axis = 0)
|
||||
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
|
||||
|
||||
timesteps, = betas.shape
|
||||
@@ -530,17 +557,23 @@ class DiffusionPrior(nn.Module):
|
||||
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
||||
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
|
||||
|
||||
@torch.no_grad()
|
||||
def get_image_embed(self, image):
|
||||
image_encoding = self.clip.visual_transformer(image)
|
||||
image_cls = image_encoding[:, 0]
|
||||
image_embed = self.clip.to_visual_latent(image_cls)
|
||||
return l2norm(image_embed)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_text_cond(self, text):
|
||||
text_encodings = self.clip.text_transformer(text)
|
||||
text_cls, text_encodings = text_encodings[:, 0], text_encodings[:, 1:]
|
||||
text_embed = self.clip.to_text_latent(text_cls)
|
||||
text_embed = l2norm(text_embed)
|
||||
|
||||
if not self.condition_on_text_encodings:
|
||||
return dict(text_embed = text_embed)
|
||||
|
||||
return dict(text_encodings = text_encodings, text_embed = text_embed, mask = text != 0)
|
||||
|
||||
def q_mean_variance(self, x_start, t):
|
||||
@@ -565,14 +598,16 @@ class DiffusionPrior(nn.Module):
|
||||
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
||||
|
||||
def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
|
||||
if self.predict_x0:
|
||||
x_recon = self.net(x, t, **text_cond)
|
||||
pred = self.net(x, t, **text_cond)
|
||||
|
||||
if self.predict_x_start:
|
||||
x_recon = pred
|
||||
# not 100% sure of this above line - for any spectators, let me know in the github issues (or through a pull request) if you know how to correctly do this
|
||||
# i'll be rereading https://arxiv.org/abs/2111.14822, where i think a similar approach is taken
|
||||
else:
|
||||
x_recon = self.predict_start_from_noise(x, t = t, noise = self.net(x, t, **text_cond))
|
||||
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
|
||||
|
||||
if clip_denoised:
|
||||
if clip_denoised and not self.predict_x_start:
|
||||
x_recon.clamp_(-1., 1.)
|
||||
|
||||
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
||||
@@ -618,7 +653,7 @@ class DiffusionPrior(nn.Module):
|
||||
**text_cond
|
||||
)
|
||||
|
||||
to_predict = noise if not self.predict_x0 else image_embed
|
||||
to_predict = noise if not self.predict_x_start else image_embed
|
||||
|
||||
if self.loss_type == 'l1':
|
||||
loss = F.l1_loss(to_predict, x_recon)
|
||||
@@ -719,7 +754,7 @@ class ConvNextBlock(nn.Module):
|
||||
|
||||
inner_dim = int(dim_out * mult)
|
||||
self.net = nn.Sequential(
|
||||
ChanRMSNorm(dim) if norm else nn.Identity(),
|
||||
ChanLayerNorm(dim) if norm else nn.Identity(),
|
||||
nn.Conv2d(dim, inner_dim, 3, padding = 1),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(inner_dim, dim_out, 3, padding = 1)
|
||||
@@ -755,8 +790,8 @@ class CrossAttention(nn.Module):
|
||||
|
||||
context_dim = default(context_dim, dim)
|
||||
|
||||
self.norm = RMSNorm(dim)
|
||||
self.norm_context = RMSNorm(context_dim)
|
||||
self.norm = LayerNorm(dim)
|
||||
self.norm_context = LayerNorm(context_dim)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
|
||||
@@ -820,16 +855,18 @@ class Unet(nn.Module):
|
||||
image_embed_dim,
|
||||
cond_dim = None,
|
||||
num_image_tokens = 4,
|
||||
num_time_tokens = 2,
|
||||
out_dim = None,
|
||||
dim_mults=(1, 2, 4, 8),
|
||||
channels = 3,
|
||||
attn_dim_head = 32,
|
||||
attn_heads = 8,
|
||||
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
|
||||
lowres_cond_upsample_mode = 'bilinear',
|
||||
blur_sigma = 0.1,
|
||||
blur_kernel_size = 3,
|
||||
sparse_attn = False,
|
||||
sparse_attn_window = 8, # window size for sparse attention
|
||||
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
|
||||
cond_on_text_encodings = False,
|
||||
cond_on_image_embeds = False,
|
||||
):
|
||||
super().__init__()
|
||||
# save locals to take care of some hyperparameters for cascading DDPM
|
||||
@@ -841,9 +878,6 @@ class Unet(nn.Module):
|
||||
# for eventual cascading diffusion
|
||||
|
||||
self.lowres_cond = lowres_cond
|
||||
self.lowres_cond_upsample_mode = lowres_cond_upsample_mode
|
||||
self.lowres_blur_kernel_size = blur_kernel_size
|
||||
self.lowres_blur_sigma = blur_sigma
|
||||
|
||||
# determine dimensions
|
||||
|
||||
@@ -862,8 +896,8 @@ class Unet(nn.Module):
|
||||
SinusoidalPosEmb(dim),
|
||||
nn.Linear(dim, dim * 4),
|
||||
nn.GELU(),
|
||||
nn.Linear(dim * 4, cond_dim),
|
||||
Rearrange('b d -> b 1 d')
|
||||
nn.Linear(dim * 4, cond_dim * num_time_tokens),
|
||||
Rearrange('b (r d) -> b r d', r = num_time_tokens)
|
||||
)
|
||||
|
||||
self.image_to_cond = nn.Sequential(
|
||||
@@ -873,11 +907,21 @@ class Unet(nn.Module):
|
||||
|
||||
self.text_to_cond = nn.LazyLinear(cond_dim)
|
||||
|
||||
# finer control over whether to condition on image embeddings and text encodings
|
||||
# so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
|
||||
|
||||
self.cond_on_text_encodings = cond_on_text_encodings
|
||||
self.cond_on_image_embeds = cond_on_image_embeds
|
||||
|
||||
# for classifier free guidance
|
||||
|
||||
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
|
||||
self.null_text_embed = nn.Parameter(torch.randn(1, 1, cond_dim))
|
||||
|
||||
# attention related params
|
||||
|
||||
attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
|
||||
|
||||
# layers
|
||||
|
||||
self.downs = nn.ModuleList([])
|
||||
@@ -891,7 +935,7 @@ class Unet(nn.Module):
|
||||
|
||||
self.downs.append(nn.ModuleList([
|
||||
ConvNextBlock(dim_in, dim_out, norm = ind != 0),
|
||||
Residual(GridAttention(dim_out, window_size = sparse_attn_window)) if sparse_attn else nn.Identity(),
|
||||
Residual(GridAttention(dim_out, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||
ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim),
|
||||
Downsample(dim_out) if not is_last else nn.Identity()
|
||||
]))
|
||||
@@ -899,7 +943,7 @@ class Unet(nn.Module):
|
||||
mid_dim = dims[-1]
|
||||
|
||||
self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
|
||||
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim))) if attend_at_middle else None
|
||||
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
|
||||
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
|
||||
|
||||
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
|
||||
@@ -908,7 +952,7 @@ class Unet(nn.Module):
|
||||
|
||||
self.ups.append(nn.ModuleList([
|
||||
ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim),
|
||||
Residual(GridAttention(dim_in, window_size = sparse_attn_window)) if sparse_attn else nn.Identity(),
|
||||
Residual(GridAttention(dim_in, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||
ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim),
|
||||
Upsample(dim_in)
|
||||
]))
|
||||
@@ -921,11 +965,16 @@ class Unet(nn.Module):
|
||||
|
||||
# if the current settings for the unet are not correct
|
||||
# for cascading DDPM, then reinit the unet with the right settings
|
||||
def force_lowres_cond(self, lowres_cond):
|
||||
if lowres_cond == self.lowres_cond:
|
||||
def cast_model_parameters(
|
||||
self,
|
||||
*,
|
||||
lowres_cond,
|
||||
channels
|
||||
):
|
||||
if lowres_cond == self.lowres_cond and channels == self.channels:
|
||||
return self
|
||||
|
||||
updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond}
|
||||
updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond, 'channels': channels}
|
||||
return self.__class__(**updated_kwargs)
|
||||
|
||||
def forward_with_cond_scale(
|
||||
@@ -961,13 +1010,6 @@ class Unet(nn.Module):
|
||||
assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'
|
||||
|
||||
if exists(lowres_cond_img):
|
||||
if self.training:
|
||||
# when training, blur the low resolution conditional image
|
||||
blur_sigma = default(blur_sigma, self.lowres_blur_sigma)
|
||||
blur_kernel_size = default(blur_kernel_size, self.lowres_blur_kernel_size)
|
||||
lowres_cond_img = gaussian_blur2d(lowres_cond_img, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
|
||||
|
||||
lowres_cond_img = resize_image_to(lowres_cond_img, x.shape[-2:], mode = self.lowres_cond_upsample_mode)
|
||||
x = torch.cat((x, lowres_cond_img), dim = 1)
|
||||
|
||||
# time conditioning
|
||||
@@ -982,17 +1024,22 @@ class Unet(nn.Module):
|
||||
# mask out image embedding depending on condition dropout
|
||||
# for classifier free guidance
|
||||
|
||||
image_tokens = self.image_to_cond(image_embed)
|
||||
image_tokens = None
|
||||
|
||||
image_tokens = torch.where(
|
||||
cond_prob_mask,
|
||||
image_tokens,
|
||||
self.null_image_embed
|
||||
)
|
||||
if self.cond_on_image_embeds:
|
||||
image_tokens = self.image_to_cond(image_embed)
|
||||
|
||||
image_tokens = torch.where(
|
||||
cond_prob_mask,
|
||||
image_tokens,
|
||||
self.null_image_embed
|
||||
)
|
||||
|
||||
# take care of text encodings (optional)
|
||||
|
||||
if exists(text_encodings):
|
||||
text_tokens = None
|
||||
|
||||
if exists(text_encodings) and self.cond_on_text_encodings:
|
||||
text_tokens = self.text_to_cond(text_encodings)
|
||||
text_tokens = torch.where(
|
||||
cond_prob_mask,
|
||||
@@ -1002,12 +1049,15 @@ class Unet(nn.Module):
|
||||
|
||||
# main conditioning tokens (c)
|
||||
|
||||
c = torch.cat((time_tokens, image_tokens), dim = -2)
|
||||
c = time_tokens
|
||||
|
||||
if exists(image_tokens):
|
||||
c = torch.cat((c, image_tokens), dim = -2)
|
||||
|
||||
# text and image conditioning tokens (mid_c)
|
||||
# to save on compute, only do cross attention based conditioning on the inner most layers of the Unet
|
||||
|
||||
mid_c = c if not exists(text_encodings) else torch.cat((c, text_tokens), dim = -2)
|
||||
mid_c = c if not exists(text_tokens) else torch.cat((c, text_tokens), dim = -2)
|
||||
|
||||
# go through the layers of the unet, down and up
|
||||
|
||||
@@ -1036,17 +1086,62 @@ class Unet(nn.Module):
|
||||
|
||||
return self.final_conv(x)
|
||||
|
||||
class LowresConditioner(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
cond_upsample_mode = 'bilinear',
|
||||
downsample_first = True,
|
||||
blur_sigma = 0.1,
|
||||
blur_kernel_size = 3,
|
||||
):
|
||||
super().__init__()
|
||||
self.cond_upsample_mode = cond_upsample_mode
|
||||
self.downsample_first = downsample_first
|
||||
self.blur_sigma = blur_sigma
|
||||
self.blur_kernel_size = blur_kernel_size
|
||||
|
||||
def forward(
|
||||
self,
|
||||
cond_fmap,
|
||||
*,
|
||||
target_image_size,
|
||||
downsample_image_size = None,
|
||||
blur_sigma = None,
|
||||
blur_kernel_size = None
|
||||
):
|
||||
target_image_size = cast_tuple(target_image_size, 2)
|
||||
|
||||
if self.training and self.downsample_first and exists(downsample_image_size):
|
||||
cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode)
|
||||
|
||||
if self.training:
|
||||
# when training, blur the low resolution conditional image
|
||||
blur_sigma = default(blur_sigma, self.blur_sigma)
|
||||
blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)
|
||||
cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
|
||||
|
||||
cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode)
|
||||
|
||||
return cond_fmap
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
unet,
|
||||
*,
|
||||
clip,
|
||||
vae = tuple(),
|
||||
timesteps = 1000,
|
||||
cond_drop_prob = 0.2,
|
||||
loss_type = 'l1',
|
||||
beta_schedule = 'cosine',
|
||||
image_sizes = None # for cascading ddpm, image size at each stage
|
||||
predict_x_start = False,
|
||||
predict_x_start_for_latent_diffusion = False,
|
||||
image_sizes = None, # for cascading ddpm, image size at each stage
|
||||
lowres_cond_upsample_mode = 'bilinear', # cascading ddpm - low resolution upsample mode
|
||||
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
|
||||
blur_sigma = 0.1, # cascading ddpm - blur sigma
|
||||
blur_kernel_size = 3, # cascading ddpm - blur kernel size
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(clip, CLIP)
|
||||
@@ -1058,11 +1153,28 @@ class Decoder(nn.Module):
|
||||
# automatically take care of ensuring that first unet is unconditional
|
||||
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
|
||||
|
||||
unets = cast_tuple(unet)
|
||||
vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))
|
||||
|
||||
self.unets = nn.ModuleList([])
|
||||
for ind, one_unet in enumerate(cast_tuple(unet)):
|
||||
self.vaes = nn.ModuleList([])
|
||||
|
||||
for ind, (one_unet, one_vae) in enumerate(zip(unets, vaes)):
|
||||
assert isinstance(one_unet, Unet)
|
||||
assert isinstance(one_vae, (VQGanVAE, NullVQGanVAE))
|
||||
|
||||
is_first = ind == 0
|
||||
one_unet = one_unet.force_lowres_cond(not is_first)
|
||||
latent_dim = one_vae.encoded_dim if exists(one_vae) else None
|
||||
|
||||
unet_channels = default(latent_dim, self.channels)
|
||||
|
||||
one_unet = one_unet.cast_model_parameters(
|
||||
lowres_cond = not is_first,
|
||||
channels = unet_channels
|
||||
)
|
||||
|
||||
self.unets.append(one_unet)
|
||||
self.vaes.append(one_vae.copy_for_eval())
|
||||
|
||||
# unet image sizes
|
||||
|
||||
@@ -1071,12 +1183,30 @@ class Decoder(nn.Module):
|
||||
|
||||
assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
|
||||
self.image_sizes = image_sizes
|
||||
self.sample_channels = cast_tuple(self.channels, len(image_sizes))
|
||||
|
||||
# predict x0 config
|
||||
|
||||
self.predict_x_start = cast_tuple(predict_x_start, len(unets)) if not predict_x_start_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes))
|
||||
|
||||
# cascading ddpm related stuff
|
||||
|
||||
lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
|
||||
assert lowres_conditions == (False, *((True,) * (len(self.unets) - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
|
||||
|
||||
self.to_lowres_cond = LowresConditioner(
|
||||
cond_upsample_mode = lowres_cond_upsample_mode,
|
||||
downsample_first = lowres_downsample_first,
|
||||
blur_sigma = blur_sigma,
|
||||
blur_kernel_size = blur_kernel_size,
|
||||
)
|
||||
|
||||
# classifier free guidance
|
||||
|
||||
self.cond_drop_prob = cond_drop_prob
|
||||
|
||||
# noise schedule
|
||||
|
||||
if beta_schedule == "cosine":
|
||||
betas = cosine_beta_schedule(timesteps)
|
||||
elif beta_schedule == "linear":
|
||||
@@ -1091,7 +1221,7 @@ class Decoder(nn.Module):
|
||||
raise NotImplementedError()
|
||||
|
||||
alphas = 1. - betas
|
||||
alphas_cumprod = torch.cumprod(alphas, axis=0)
|
||||
alphas_cumprod = torch.cumprod(alphas, axis = 0)
|
||||
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
|
||||
|
||||
timesteps, = betas.shape
|
||||
@@ -1124,10 +1254,31 @@ class Decoder(nn.Module):
|
||||
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
||||
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
|
||||
|
||||
def get_unet(self, unet_number):
|
||||
assert 0 < unet_number <= len(self.unets)
|
||||
index = unet_number - 1
|
||||
return self.unets[index]
|
||||
|
||||
@contextmanager
|
||||
def one_unet_in_gpu(self, unet_number = None, unet = None):
|
||||
assert exists(unet_number) ^ exists(unet)
|
||||
|
||||
if exists(unet_number):
|
||||
unet = self.get_unet(unet_number)
|
||||
|
||||
self.cuda()
|
||||
self.unets.cpu()
|
||||
|
||||
unet.cuda()
|
||||
yield
|
||||
unet.cpu()
|
||||
|
||||
@torch.no_grad()
|
||||
def get_text_encodings(self, text):
|
||||
text_encodings = self.clip.text_transformer(text)
|
||||
return text_encodings[:, 1:]
|
||||
|
||||
@torch.no_grad()
|
||||
def get_image_embed(self, image):
|
||||
image = resize_image_to(image, self.clip_image_size)
|
||||
image_encoding = self.clip.visual_transformer(image)
|
||||
@@ -1156,34 +1307,48 @@ class Decoder(nn.Module):
|
||||
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
|
||||
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
||||
|
||||
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, cond_scale = 1.):
|
||||
pred_noise = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
|
||||
x_recon = self.predict_start_from_noise(x, t = t, noise = pred_noise)
|
||||
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
|
||||
pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
|
||||
|
||||
if clip_denoised:
|
||||
if predict_x_start:
|
||||
x_recon = pred
|
||||
else:
|
||||
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
|
||||
|
||||
if clip_denoised and not predict_x_start:
|
||||
x_recon.clamp_(-1., 1.)
|
||||
|
||||
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
||||
return model_mean, posterior_variance, posterior_log_variance
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample(self, unet, x, t, image_embed, text_encodings = None, cond_scale = 1., lowres_cond_img = None, clip_denoised = True, repeat_noise = False):
|
||||
def p_sample(self, unet, x, t, image_embed, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, clip_denoised = True, repeat_noise = False):
|
||||
b, *_, device = *x.shape, x.device
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised)
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start)
|
||||
noise = noise_like(x.shape, device, repeat_noise)
|
||||
# no noise when t == 0
|
||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_loop(self, unet, shape, image_embed, lowres_cond_img = None, text_encodings = None, cond_scale = 1):
|
||||
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, lowres_cond_img = None, text_encodings = None, cond_scale = 1):
|
||||
device = self.betas.device
|
||||
|
||||
b = shape[0]
|
||||
img = torch.randn(shape, device = device)
|
||||
|
||||
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
|
||||
img = self.p_sample(unet, img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
|
||||
img = self.p_sample(
|
||||
unet,
|
||||
img,
|
||||
torch.full((b,), i, device = device, dtype = torch.long),
|
||||
image_embed = image_embed,
|
||||
text_encodings = text_encodings,
|
||||
cond_scale = cond_scale,
|
||||
lowres_cond_img = lowres_cond_img,
|
||||
predict_x_start = predict_x_start
|
||||
)
|
||||
|
||||
return img
|
||||
|
||||
def q_sample(self, x_start, t, noise=None):
|
||||
@@ -1194,7 +1359,7 @@ class Decoder(nn.Module):
|
||||
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
||||
)
|
||||
|
||||
def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, noise = None):
|
||||
def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None):
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
|
||||
x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)
|
||||
@@ -1208,12 +1373,14 @@ class Decoder(nn.Module):
|
||||
cond_drop_prob = self.cond_drop_prob
|
||||
)
|
||||
|
||||
target = noise if not predict_x_start else x_start
|
||||
|
||||
if self.loss_type == 'l1':
|
||||
loss = F.l1_loss(noise, x_recon)
|
||||
loss = F.l1_loss(target, x_recon)
|
||||
elif self.loss_type == 'l2':
|
||||
loss = F.mse_loss(noise, x_recon)
|
||||
loss = F.mse_loss(target, x_recon)
|
||||
elif self.loss_type == "huber":
|
||||
loss = F.smooth_l1_loss(noise, x_recon)
|
||||
loss = F.smooth_l1_loss(target, x_recon)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -1223,25 +1390,56 @@ class Decoder(nn.Module):
|
||||
@eval_decorator
|
||||
def sample(self, image_embed, text = None, cond_scale = 1.):
|
||||
batch_size = image_embed.shape[0]
|
||||
channels = self.channels
|
||||
|
||||
text_encodings = self.get_text_encodings(text) if exists(text) else None
|
||||
|
||||
img = None
|
||||
for unet, image_size in tqdm(zip(self.unets, self.image_sizes)):
|
||||
shape = (batch_size, channels, image_size, image_size)
|
||||
img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img)
|
||||
|
||||
for unet, vae, channel, image_size, predict_x_start in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
|
||||
with self.one_unet_in_gpu(unet = unet):
|
||||
lowres_cond_img = None
|
||||
shape = (batch_size, channel, image_size, image_size)
|
||||
|
||||
if unet.lowres_cond:
|
||||
lowres_cond_img = self.to_lowres_cond(img, target_image_size = image_size)
|
||||
|
||||
image_size = vae.get_encoded_fmap_size(image_size)
|
||||
shape = (batch_size, vae.encoded_dim, image_size, image_size)
|
||||
|
||||
if exists(lowres_cond_img):
|
||||
lowres_cond_img = vae.encode(lowres_cond_img)
|
||||
|
||||
img = self.p_sample_loop(
|
||||
unet,
|
||||
shape,
|
||||
image_embed = image_embed,
|
||||
text_encodings = text_encodings,
|
||||
cond_scale = cond_scale,
|
||||
predict_x_start = predict_x_start,
|
||||
lowres_cond_img = lowres_cond_img
|
||||
)
|
||||
|
||||
img = vae.decode(img)
|
||||
|
||||
return img
|
||||
|
||||
def forward(self, image, text = None, image_embed = None, text_encodings = None, unet_number = None):
|
||||
def forward(
|
||||
self,
|
||||
image,
|
||||
text = None,
|
||||
image_embed = None,
|
||||
text_encodings = None,
|
||||
unet_number = None
|
||||
):
|
||||
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
|
||||
unet_number = default(unet_number, 1)
|
||||
assert 1 <= unet_number <= len(self.unets)
|
||||
unet_index = unet_number - 1
|
||||
|
||||
index = unet_number - 1
|
||||
unet = self.unets[index]
|
||||
target_image_size = self.image_sizes[index]
|
||||
unet = self.get_unet(unet_number)
|
||||
|
||||
target_image_size = self.image_sizes[unet_index]
|
||||
vae = self.vaes[unet_index]
|
||||
predict_x_start = self.predict_x_start[unet_index]
|
||||
|
||||
b, c, h, w, device, = *image.shape, image.device
|
||||
|
||||
@@ -1255,9 +1453,17 @@ class Decoder(nn.Module):
|
||||
|
||||
text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None
|
||||
|
||||
lowres_cond_img = image if index > 0 else None
|
||||
ddpm_image = resize_image_to(image, target_image_size)
|
||||
return self.p_losses(unet, ddpm_image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img)
|
||||
lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
|
||||
image = resize_image_to(image, target_image_size)
|
||||
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
image = vae.encode(image)
|
||||
|
||||
if exists(lowres_cond_img):
|
||||
lowres_cond_img = vae.encode(lowres_cond_img)
|
||||
|
||||
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start)
|
||||
|
||||
# main class
|
||||
|
||||
@@ -1284,6 +1490,7 @@ class DALLE2(nn.Module):
|
||||
cond_scale = 1.
|
||||
):
|
||||
device = next(self.parameters()).device
|
||||
one_text = isinstance(text, str) or (not is_list_str(text) and text.shape[0] == 1)
|
||||
|
||||
if isinstance(text, str) or is_list_str(text):
|
||||
text = [text] if not isinstance(text, (list, tuple)) else text
|
||||
@@ -1291,4 +1498,8 @@ class DALLE2(nn.Module):
|
||||
|
||||
image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
|
||||
images = self.decoder.sample(image_embed, cond_scale = cond_scale)
|
||||
|
||||
if one_text:
|
||||
return images[0]
|
||||
|
||||
return images
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, einsum
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
class LatentDiffusion(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
@@ -0,0 +1,53 @@
|
||||
import copy
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
# exponential moving average wrapper
|
||||
|
||||
class EMA(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
beta = 0.99,
|
||||
ema_update_after_step = 1000,
|
||||
ema_update_every = 10,
|
||||
):
|
||||
super().__init__()
|
||||
self.beta = beta
|
||||
self.online_model = model
|
||||
self.ema_model = copy.deepcopy(model)
|
||||
|
||||
self.ema_update_after_step = ema_update_after_step # only start EMA after this step number, starting at 0
|
||||
self.ema_update_every = ema_update_every
|
||||
|
||||
self.register_buffer('initted', torch.Tensor([False]))
|
||||
self.register_buffer('step', torch.tensor([0.]))
|
||||
|
||||
def update(self):
|
||||
self.step += 1
|
||||
|
||||
if self.step <= self.ema_update_after_step or (self.step % self.ema_update_every) != 0:
|
||||
return
|
||||
|
||||
if not self.initted:
|
||||
self.ema_model.state_dict(self.online_model.state_dict())
|
||||
self.initted.data.copy_(torch.Tensor([True]))
|
||||
|
||||
self.update_moving_average(self.ema_model, self.online_model)
|
||||
|
||||
def update_moving_average(ma_model, current_model):
|
||||
def calculate_ema(beta, old, new):
|
||||
if not exists(old):
|
||||
return new
|
||||
return old * beta + (1 - beta) * new
|
||||
|
||||
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
|
||||
old_weight, up_weight = ma_params.data, current_params.data
|
||||
ma_params.data = calculate_ema(self.beta, old_weight, up_weight)
|
||||
|
||||
for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
|
||||
new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer)
|
||||
ma_buffer.copy_(new_buffer_value)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.ema_model(*args, **kwargs)
|
||||
|
||||
566
dalle2_pytorch/vqgan_vae.py
Normal file
566
dalle2_pytorch/vqgan_vae.py
Normal file
@@ -0,0 +1,566 @@
|
||||
import copy
|
||||
import math
|
||||
from math import sqrt
|
||||
from functools import partial, wraps
|
||||
|
||||
from vector_quantize_pytorch import VectorQuantize as VQ
|
||||
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import grad as torch_grad
|
||||
import torchvision
|
||||
|
||||
from einops import rearrange, reduce, repeat
|
||||
|
||||
from dalle2_pytorch.attention import QueryAttnUpsample
|
||||
|
||||
# constants
|
||||
|
||||
MList = nn.ModuleList
|
||||
|
||||
# helper functions
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
# decorators
|
||||
|
||||
def eval_decorator(fn):
|
||||
def inner(model, *args, **kwargs):
|
||||
was_training = model.training
|
||||
model.eval()
|
||||
out = fn(model, *args, **kwargs)
|
||||
model.train(was_training)
|
||||
return out
|
||||
return inner
|
||||
|
||||
def remove_vgg(fn):
|
||||
@wraps(fn)
|
||||
def inner(self, *args, **kwargs):
|
||||
has_vgg = hasattr(self, 'vgg')
|
||||
if has_vgg:
|
||||
vgg = self.vgg
|
||||
delattr(self, 'vgg')
|
||||
|
||||
out = fn(self, *args, **kwargs)
|
||||
|
||||
if has_vgg:
|
||||
self.vgg = vgg
|
||||
|
||||
return out
|
||||
return inner
|
||||
|
||||
# keyword argument helpers
|
||||
|
||||
def pick_and_pop(keys, d):
|
||||
values = list(map(lambda key: d.pop(key), keys))
|
||||
return dict(zip(keys, values))
|
||||
|
||||
def group_dict_by_key(cond, d):
|
||||
return_val = [dict(),dict()]
|
||||
for key in d.keys():
|
||||
match = bool(cond(key))
|
||||
ind = int(not match)
|
||||
return_val[ind][key] = d[key]
|
||||
return (*return_val,)
|
||||
|
||||
def string_begins_with(prefix, str):
|
||||
return str.startswith(prefix)
|
||||
|
||||
def group_by_key_prefix(prefix, d):
|
||||
return group_dict_by_key(partial(string_begins_with, prefix), d)
|
||||
|
||||
def groupby_prefix_and_trim(prefix, d):
|
||||
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
|
||||
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
|
||||
return kwargs_without_prefix, kwargs
|
||||
|
||||
# tensor helper functions
|
||||
|
||||
def log(t, eps = 1e-10):
|
||||
return torch.log(t + eps)
|
||||
|
||||
def gradient_penalty(images, output, weight = 10):
|
||||
batch_size = images.shape[0]
|
||||
gradients = torch_grad(outputs = output, inputs = images,
|
||||
grad_outputs = torch.ones(output.size(), device = images.device),
|
||||
create_graph = True, retain_graph = True, only_inputs = True)[0]
|
||||
|
||||
gradients = rearrange(gradients, 'b ... -> b (...)')
|
||||
return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean()
|
||||
|
||||
def l2norm(t):
|
||||
return F.normalize(t, dim = -1)
|
||||
|
||||
def leaky_relu(p = 0.1):
|
||||
return nn.LeakyReLU(0.1)
|
||||
|
||||
def stable_softmax(t, dim = -1, alpha = 32 ** 2):
|
||||
t = t / alpha
|
||||
t = t - torch.amax(t, dim = dim, keepdim = True).detach()
|
||||
return (t * alpha).softmax(dim = dim)
|
||||
|
||||
def safe_div(numer, denom, eps = 1e-8):
|
||||
return numer / (denom + eps)
|
||||
|
||||
# gan losses
|
||||
|
||||
def hinge_discr_loss(fake, real):
|
||||
return (F.relu(1 + fake) + F.relu(1 - real)).mean()
|
||||
|
||||
def hinge_gen_loss(fake):
|
||||
return -fake.mean()
|
||||
|
||||
def bce_discr_loss(fake, real):
|
||||
return (-log(1 - torch.sigmoid(fake)) - log(torch.sigmoid(real))).mean()
|
||||
|
||||
def bce_gen_loss(fake):
|
||||
return -log(torch.sigmoid(fake)).mean()
|
||||
|
||||
def grad_layer_wrt_loss(loss, layer):
|
||||
return torch_grad(
|
||||
outputs = loss,
|
||||
inputs = layer,
|
||||
grad_outputs = torch.ones_like(loss),
|
||||
retain_graph = True
|
||||
)[0].detach()
|
||||
|
||||
# vqgan vae
|
||||
|
||||
class LayerNormChan(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
eps = 1e-5
|
||||
):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (var + self.eps).sqrt() * self.gamma
|
||||
|
||||
class Discriminator(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dims,
|
||||
channels = 3,
|
||||
groups = 16,
|
||||
init_kernel_size = 5
|
||||
):
|
||||
super().__init__()
|
||||
dim_pairs = zip(dims[:-1], dims[1:])
|
||||
|
||||
self.layers = MList([nn.Sequential(nn.Conv2d(channels, dims[0], init_kernel_size, padding = init_kernel_size // 2), leaky_relu())])
|
||||
|
||||
for dim_in, dim_out in dim_pairs:
|
||||
self.layers.append(nn.Sequential(
|
||||
nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1),
|
||||
nn.GroupNorm(groups, dim_out),
|
||||
leaky_relu()
|
||||
))
|
||||
|
||||
dim = dims[-1]
|
||||
self.to_logits = nn.Sequential( # return 5 x 5, for PatchGAN-esque training
|
||||
nn.Conv2d(dim, dim, 1),
|
||||
leaky_relu(),
|
||||
nn.Conv2d(dim, 1, 4)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for net in self.layers:
|
||||
x = net(x)
|
||||
|
||||
return self.to_logits(x)
|
||||
|
||||
class ContinuousPositionBias(nn.Module):
|
||||
""" from https://arxiv.org/abs/2111.09883 """
|
||||
|
||||
def __init__(self, *, dim, heads, layers = 2):
|
||||
super().__init__()
|
||||
self.net = MList([])
|
||||
self.net.append(nn.Sequential(nn.Linear(2, dim), leaky_relu()))
|
||||
|
||||
for _ in range(layers - 1):
|
||||
self.net.append(nn.Sequential(nn.Linear(dim, dim), leaky_relu()))
|
||||
|
||||
self.net.append(nn.Linear(dim, heads))
|
||||
self.register_buffer('rel_pos', None, persistent = False)
|
||||
|
||||
def forward(self, x):
|
||||
n, device = x.shape[-1], x.device
|
||||
fmap_size = int(sqrt(n))
|
||||
|
||||
if not exists(self.rel_pos):
|
||||
pos = torch.arange(fmap_size, device = device)
|
||||
grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
|
||||
grid = rearrange(grid, 'c i j -> (i j) c')
|
||||
rel_pos = rearrange(grid, 'i c -> i 1 c') - rearrange(grid, 'j c -> 1 j c')
|
||||
rel_pos = torch.sign(rel_pos) * torch.log(rel_pos.abs() + 1)
|
||||
self.register_buffer('rel_pos', rel_pos, persistent = False)
|
||||
|
||||
rel_pos = self.rel_pos.float()
|
||||
|
||||
for layer in self.net:
|
||||
rel_pos = layer(rel_pos)
|
||||
|
||||
bias = rearrange(rel_pos, 'i j h -> h i j')
|
||||
return x + bias
|
||||
|
||||
class GLUResBlock(nn.Module):
|
||||
def __init__(self, chan, groups = 16):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(chan, chan * 2, 3, padding = 1),
|
||||
nn.GLU(dim = 1),
|
||||
nn.GroupNorm(groups, chan),
|
||||
nn.Conv2d(chan, chan * 2, 3, padding = 1),
|
||||
nn.GLU(dim = 1),
|
||||
nn.GroupNorm(groups, chan),
|
||||
nn.Conv2d(chan, chan, 1)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x) + x
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(self, chan, groups = 16):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(chan, chan, 3, padding = 1),
|
||||
nn.GroupNorm(groups, chan),
|
||||
leaky_relu(),
|
||||
nn.Conv2d(chan, chan, 3, padding = 1),
|
||||
nn.GroupNorm(groups, chan),
|
||||
leaky_relu(),
|
||||
nn.Conv2d(chan, chan, 1)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x) + x
|
||||
|
||||
# vqgan attention layer
|
||||
class VQGanAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dim,
|
||||
dim_head = 64,
|
||||
heads = 8,
|
||||
dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
inner_dim = heads * dim_head
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.pre_norm = LayerNormChan(dim)
|
||||
|
||||
self.cpb = ContinuousPositionBias(dim = dim // 4, heads = heads)
|
||||
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
|
||||
self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)
|
||||
|
||||
def forward(self, x):
|
||||
h = self.heads
|
||||
height, width, residual = *x.shape[-2:], x.clone()
|
||||
|
||||
x = self.pre_norm(x)
|
||||
|
||||
q, k, v = self.to_qkv(x).chunk(3, dim = 1)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = h), (q, k, v))
|
||||
|
||||
sim = einsum('b h c i, b h c j -> b h i j', q, k) * self.scale
|
||||
|
||||
sim = self.cpb(sim)
|
||||
|
||||
attn = stable_softmax(sim, dim = -1)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = einsum('b h i j, b h c j -> b h c i', attn, v)
|
||||
out = rearrange(out, 'b h c (x y) -> b (h c) x y', x = height, y = width)
|
||||
out = self.to_out(out)
|
||||
|
||||
return out + residual
|
||||
|
||||
class NullVQGanVAE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
channels
|
||||
):
|
||||
super().__init__()
|
||||
self.encoded_dim = channels
|
||||
self.layers = 0
|
||||
|
||||
def get_encoded_fmap_size(self, size):
|
||||
return size
|
||||
|
||||
def copy_for_eval(self):
|
||||
return self
|
||||
|
||||
def encode(self, x):
|
||||
return x
|
||||
|
||||
def decode(self, x):
|
||||
return x
|
||||
|
||||
class VQGanVAE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dim,
|
||||
image_size,
|
||||
channels = 3,
|
||||
layers = 4,
|
||||
layer_mults = None,
|
||||
l2_recon_loss = False,
|
||||
use_hinge_loss = True,
|
||||
num_resnet_blocks = 1,
|
||||
vgg = None,
|
||||
vq_codebook_size = 512,
|
||||
vq_decay = 0.8,
|
||||
vq_commitment_weight = 1.,
|
||||
vq_kmeans_init = True,
|
||||
vq_use_cosine_sim = True,
|
||||
use_attn = True,
|
||||
attn_dim_head = 64,
|
||||
attn_heads = 8,
|
||||
resnet_groups = 16,
|
||||
attn_dropout = 0.,
|
||||
first_conv_kernel_size = 5,
|
||||
use_vgg_and_gan = True,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
assert dim % resnet_groups == 0, f'dimension {dim} must be divisible by {resnet_groups} (groups for the groupnorm)'
|
||||
|
||||
vq_kwargs, kwargs = groupby_prefix_and_trim('vq_', kwargs)
|
||||
|
||||
self.image_size = image_size
|
||||
self.channels = channels
|
||||
self.layers = layers
|
||||
self.fmap_size = image_size // (layers ** 2)
|
||||
self.codebook_size = vq_codebook_size
|
||||
|
||||
self.encoders = MList([])
|
||||
self.decoders = MList([])
|
||||
|
||||
layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
|
||||
assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'
|
||||
|
||||
layer_dims = [dim * mult for mult in layer_mults]
|
||||
dims = (dim, *layer_dims)
|
||||
codebook_dim = layer_dims[-1]
|
||||
|
||||
self.encoded_dim = dims[-1]
|
||||
|
||||
dim_pairs = zip(dims[:-1], dims[1:])
|
||||
|
||||
append = lambda arr, t: arr.append(t)
|
||||
prepend = lambda arr, t: arr.insert(0, t)
|
||||
|
||||
if not isinstance(num_resnet_blocks, tuple):
|
||||
num_resnet_blocks = (*((0,) * (layers - 1)), num_resnet_blocks)
|
||||
|
||||
if not isinstance(use_attn, tuple):
|
||||
use_attn = (*((False,) * (layers - 1)), use_attn)
|
||||
|
||||
assert len(num_resnet_blocks) == layers, 'number of resnet blocks config must be equal to number of layers'
|
||||
assert len(use_attn) == layers
|
||||
|
||||
for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_resnet_blocks, use_attn):
|
||||
append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
|
||||
prepend(self.decoders, nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu()))
|
||||
|
||||
if layer_use_attn:
|
||||
prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
|
||||
|
||||
for _ in range(layer_num_resnet_blocks):
|
||||
append(self.encoders, ResBlock(dim_out, groups = resnet_groups))
|
||||
prepend(self.decoders, GLUResBlock(dim_out, groups = resnet_groups))
|
||||
|
||||
if layer_use_attn:
|
||||
append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
|
||||
|
||||
prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
|
||||
append(self.decoders, nn.Conv2d(dim, channels, 1))
|
||||
|
||||
self.vq = VQ(
|
||||
dim = codebook_dim,
|
||||
codebook_size = vq_codebook_size,
|
||||
decay = vq_decay,
|
||||
commitment_weight = vq_commitment_weight,
|
||||
accept_image_fmap = True,
|
||||
kmeans_init = vq_kmeans_init,
|
||||
use_cosine_sim = vq_use_cosine_sim,
|
||||
**vq_kwargs
|
||||
)
|
||||
|
||||
# reconstruction loss
|
||||
|
||||
self.recon_loss_fn = F.mse_loss if l2_recon_loss else F.l1_loss
|
||||
|
||||
# turn off GAN and perceptual loss if grayscale
|
||||
|
||||
self.vgg = None
|
||||
self.discr = None
|
||||
self.use_vgg_and_gan = use_vgg_and_gan
|
||||
|
||||
if not use_vgg_and_gan:
|
||||
return
|
||||
|
||||
# preceptual loss
|
||||
|
||||
if exists(vgg):
|
||||
self.vgg = vgg
|
||||
else:
|
||||
self.vgg = torchvision.models.vgg16(pretrained = True)
|
||||
self.vgg.classifier = nn.Sequential(*self.vgg.classifier[:-2])
|
||||
|
||||
# gan related losses
|
||||
|
||||
self.discr = Discriminator(dims = dims, channels = channels)
|
||||
|
||||
self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
|
||||
self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss
|
||||
|
||||
def get_encoded_fmap_size(self, image_size):
|
||||
return image_size // (2 ** self.layers)
|
||||
|
||||
def copy_for_eval(self):
|
||||
device = next(self.parameters()).device
|
||||
vae_copy = copy.deepcopy(self.cpu())
|
||||
|
||||
if vae_copy.use_vgg_and_gan:
|
||||
del vae_copy.discr
|
||||
del vae_copy.vgg
|
||||
|
||||
vae_copy.eval()
|
||||
return vae_copy.to(device)
|
||||
|
||||
@remove_vgg
|
||||
def state_dict(self, *args, **kwargs):
|
||||
return super().state_dict(*args, **kwargs)
|
||||
|
||||
@remove_vgg
|
||||
def load_state_dict(self, *args, **kwargs):
|
||||
return super().load_state_dict(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def codebook(self):
|
||||
return self.vq.codebook
|
||||
|
||||
def encode(self, fmap):
|
||||
for enc in self.encoders:
|
||||
fmap = enc(fmap)
|
||||
|
||||
return fmap
|
||||
|
||||
def decode(self, fmap, return_indices_and_loss = False):
|
||||
fmap, indices, commit_loss = self.vq(fmap)
|
||||
|
||||
for dec in self.decoders:
|
||||
fmap = dec(fmap)
|
||||
|
||||
if not return_indices_and_loss:
|
||||
return fmap
|
||||
|
||||
return fmap, indices, commit_loss
|
||||
|
||||
def forward(
|
||||
self,
|
||||
img,
|
||||
return_loss = False,
|
||||
return_discr_loss = False,
|
||||
return_recons = False,
|
||||
add_gradient_penalty = True
|
||||
):
|
||||
batch, channels, height, width, device = *img.shape, img.device
|
||||
assert height == self.image_size and width == self.image_size, 'height and width of input image must be equal to {self.image_size}'
|
||||
assert channels == self.channels, 'number of channels on image or sketch is not equal to the channels set on this VQGanVAE'
|
||||
|
||||
fmap = self.encode(img)
|
||||
|
||||
fmap, indices, commit_loss = self.decode(fmap, return_indices_and_loss = True)
|
||||
|
||||
if not return_loss and not return_discr_loss:
|
||||
return fmap
|
||||
|
||||
assert return_loss ^ return_discr_loss, 'you should either return autoencoder loss or discriminator loss, but not both'
|
||||
|
||||
# whether to return discriminator loss
|
||||
|
||||
if return_discr_loss:
|
||||
assert exists(self.discr), 'discriminator must exist to train it'
|
||||
|
||||
fmap.detach_()
|
||||
img.requires_grad_()
|
||||
|
||||
fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))
|
||||
|
||||
discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits)
|
||||
|
||||
if add_gradient_penalty:
|
||||
gp = gradient_penalty(img, img_discr_logits)
|
||||
loss = discr_loss + gp
|
||||
|
||||
if return_recons:
|
||||
return loss, fmap
|
||||
|
||||
return loss
|
||||
|
||||
# reconstruction loss
|
||||
|
||||
recon_loss = self.recon_loss_fn(fmap, img)
|
||||
|
||||
# early return if training on grayscale
|
||||
|
||||
if not self.use_vgg_and_gan:
|
||||
if return_recons:
|
||||
return recon_loss, fmap
|
||||
|
||||
return recon_loss
|
||||
|
||||
# perceptual loss
|
||||
|
||||
img_vgg_input = img
|
||||
fmap_vgg_input = fmap
|
||||
|
||||
if img.shape[1] == 1:
|
||||
# handle grayscale for vgg
|
||||
img_vgg_input, fmap_vgg_input = map(lambda t: repeat(t, 'b 1 ... -> b c ...', c = 3), (img_vgg_input, fmap_vgg_input))
|
||||
|
||||
img_vgg_feats = self.vgg(img_vgg_input)
|
||||
recon_vgg_feats = self.vgg(fmap_vgg_input)
|
||||
perceptual_loss = F.mse_loss(img_vgg_feats, recon_vgg_feats)
|
||||
|
||||
# generator loss
|
||||
|
||||
gen_loss = self.gen_loss(self.discr(fmap))
|
||||
|
||||
# calculate adaptive weight
|
||||
|
||||
last_dec_layer = self.decoders[-1].weight
|
||||
|
||||
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2)
|
||||
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)
|
||||
|
||||
adaptive_weight = safe_div(norm_grad_wrt_perceptual_loss, norm_grad_wrt_gen_loss)
|
||||
adaptive_weight.clamp_(max = 1e4)
|
||||
|
||||
# combine losses
|
||||
|
||||
loss = recon_loss + perceptual_loss + commit_loss + adaptive_weight * gen_loss
|
||||
|
||||
if return_recons:
|
||||
return loss, fmap
|
||||
|
||||
return loss
|
||||
3
setup.py
3
setup.py
@@ -10,7 +10,7 @@ setup(
|
||||
'dream = dalle2_pytorch.cli:dream'
|
||||
],
|
||||
},
|
||||
version = '0.0.25',
|
||||
version = '0.0.47',
|
||||
license='MIT',
|
||||
description = 'DALL-E 2',
|
||||
author = 'Phil Wang',
|
||||
@@ -30,6 +30,7 @@ setup(
|
||||
'torch>=1.10',
|
||||
'torchvision',
|
||||
'tqdm',
|
||||
'vector-quantize-pytorch',
|
||||
'x-clip>=0.4.4',
|
||||
'youtokentome'
|
||||
],
|
||||
|
||||
Reference in New Issue
Block a user