Compare commits

..

1 Commits

8 changed files with 162 additions and 1428 deletions

265
README.md
View File

@@ -1,6 +1,6 @@
<img src="./dalle2.png" width="450px"></img>
## DALL-E 2 - Pytorch
## DALL-E 2 - Pytorch (wip)
Implementation of <a href="https://openai.com/dall-e-2/">DALL-E 2</a>, OpenAI's updated text-to-image synthesis neural network, in Pytorch.
@@ -10,9 +10,11 @@ The main novelty seems to be an extra layer of indirection with the prior networ
This model is SOTA for text-to-image for now.
It may also explore an extension of using <a href="https://huggingface.co/spaces/multimodalart/latentdiffusion">latent diffusion</a> in the decoder from Rombach et al.
Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication
There was enough interest for a <a href="https://github.com/lucidrains/dalle2-jax">Jax version</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
There was enough interest for a Jax version. It will be completed after the Pytorch version shows signs of life on my toy tasks. <a href="https://github.com/lucidrains/dalle2-jax">Placeholder repository</a>
## Install
@@ -195,10 +197,10 @@ clip = CLIP(
dim_image = 512,
dim_latent = 512,
num_text_tokens = 49408,
text_enc_depth = 6,
text_enc_depth = 1,
text_seq_len = 256,
text_heads = 8,
visual_enc_depth = 6,
visual_enc_depth = 1,
visual_image_size = 256,
visual_patch_size = 32,
visual_heads = 8
@@ -207,28 +209,28 @@ clip = CLIP(
# 2 unets for the decoder (a la cascading DDPM)
unet1 = Unet(
dim = 32,
dim = 16,
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults = (1, 2, 4, 8)
).cuda()
unet2 = Unet(
dim = 32,
dim = 16,
image_embed_dim = 512,
lowres_cond = True, # subsequent unets must have this turned on (and first unet must have this turned off)
cond_dim = 128,
channels = 3,
dim_mults = (1, 2, 4, 8, 16)
).cuda()
# decoder, which contains the unet(s) and clip
# decoder, which contains the unet and clip
decoder = Decoder(
clip = clip,
unet = (unet1, unet2), # insert both unets in order of low resolution to highest resolution (you can have as many stages as you want here)
image_sizes = (256, 512), # resolutions, 256 for first unet, 512 for second. these must be unique and in ascending order (matches with the unets passed in)
timesteps = 1000,
image_sizes = (256, 512), # resolutions, 256 for first unet, 512 for second
timesteps = 100,
cond_drop_prob = 0.2
).cuda()
@@ -246,9 +248,16 @@ loss = decoder(images, unet_number = 2)
loss.backward()
# do the above for many steps for both unets
# then it will learn to generate images based on the CLIP image embeddings
# chaining the unets from lowest resolution to highest resolution (thus cascading)
mock_image_embed = torch.randn(1, 512).cuda()
images = decoder.sample(mock_image_embed) # (1, 3, 512, 512)
```
Finally, to generate the DALL-E2 images from text. Insert the trained `DiffusionPrior` as well as the `Decoder` (which wraps `CLIP`, the causal transformer, and unet(s))
Finally, to generate the DALL-E2 images from text. Insert the trained `DiffusionPrior` as well as the `Decoder` (which both contains `CLIP`, a unet, and a causal transformer)
```python
from dalle2_pytorch import DALLE2
@@ -340,7 +349,8 @@ unet2 = Unet(
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults = (1, 2, 4, 8, 16)
dim_mults = (1, 2, 4, 8, 16),
lowres_cond = True
).cuda()
decoder = Decoder(
@@ -348,8 +358,7 @@ decoder = Decoder(
image_sizes = (128, 256),
clip = clip,
timesteps = 100,
cond_drop_prob = 0.2,
condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
cond_drop_prob = 0.2
).cuda()
for unet_number in (1, 2):
@@ -377,196 +386,7 @@ You can also train the decoder on images of greater than the size (say 512x512)
For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.
## Training on Preprocessed CLIP Embeddings
It is likely, when scaling up, that you would first preprocess your images and text into corresponding embeddings before training the prior network. You can do so easily by simply passing in `image_embed`, `text_embed`, and optionally `text_encodings` and `text_mask`
Working example below
```python
import torch
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, CLIP
# get trained CLIP from step one
clip = CLIP(
dim_text = 512,
dim_image = 512,
dim_latent = 512,
num_text_tokens = 49408,
text_enc_depth = 6,
text_seq_len = 256,
text_heads = 8,
visual_enc_depth = 6,
visual_image_size = 256,
visual_patch_size = 32,
visual_heads = 8,
).cuda()
# setup prior network, which contains an autoregressive transformer
prior_network = DiffusionPriorNetwork(
dim = 512,
depth = 6,
dim_head = 64,
heads = 8
).cuda()
# diffusion prior network, which contains the CLIP and network (with transformer) above
diffusion_prior = DiffusionPrior(
net = prior_network,
clip = clip,
timesteps = 100,
cond_drop_prob = 0.2,
condition_on_text_encodings = False # this probably should be true, but just to get Laion started
).cuda()
# mock data
text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
# precompute the text and image embeddings
# here using the diffusion prior class, but could be done with CLIP alone
clip_image_embeds = diffusion_prior.get_image_embed(images)
clip_text_embeds = diffusion_prior.get_text_cond(text).get('text_embed')
# feed text and images into diffusion prior network
loss = diffusion_prior(
text_embed = clip_text_embeds,
image_embed = clip_image_embeds
)
loss.backward()
# do the above for many many many steps
# now the diffusion prior can generate image embeddings from the text embeddings
```
## Experimental
### DALL-E2 with Latent Diffusion
This repository decides to take the next step and offer DALL-E2 combined with <a href="https://huggingface.co/spaces/multimodalart/latentdiffusion">latent diffusion</a>, from Rombach et al.
You can use it as follows. Latent diffusion can be limited to just the first U-Net in the cascade, or to any number you wish.
```python
import torch
from dalle2_pytorch import Unet, Decoder, CLIP, VQGanVAE
# trained clip from step 1
clip = CLIP(
dim_text = 512,
dim_image = 512,
dim_latent = 512,
num_text_tokens = 49408,
text_enc_depth = 1,
text_seq_len = 256,
text_heads = 8,
visual_enc_depth = 1,
visual_image_size = 256,
visual_patch_size = 32,
visual_heads = 8
)
# 3 unets for the decoder (a la cascading DDPM)
# first two unets are doing latent diffusion
# vqgan-vae must be trained before hand
vae1 = VQGanVAE(
dim = 32,
image_size = 256,
layers = 3,
layer_mults = (1, 2, 4)
)
vae2 = VQGanVAE(
dim = 32,
image_size = 512,
layers = 3,
layer_mults = (1, 2, 4)
)
unet1 = Unet(
dim = 32,
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
sparse_attn = True,
sparse_attn_window = 2,
dim_mults = (1, 2, 4, 8)
)
unet2 = Unet(
dim = 32,
image_embed_dim = 512,
channels = 3,
dim_mults = (1, 2, 4, 8, 16),
cond_on_image_embeds = True,
cond_on_text_encodings = False
)
unet3 = Unet(
dim = 32,
image_embed_dim = 512,
channels = 3,
dim_mults = (1, 2, 4, 8, 16),
cond_on_image_embeds = True,
cond_on_text_encodings = False,
attend_at_middle = False
)
# decoder, which contains the unet(s) and clip
decoder = Decoder(
clip = clip,
vae = (vae1, vae2), # latent diffusion for unet1 (vae1) and unet2 (vae2), but not for the last unet3
unet = (unet1, unet2, unet3), # insert unets in order of low resolution to highest resolution (you can have as many stages as you want here)
image_sizes = (256, 512, 1024), # resolutions, 256 for first unet, 512 for second, 1024 for third
timesteps = 100,
cond_drop_prob = 0.2
).cuda()
# mock images (get a lot of this)
images = torch.randn(1, 3, 1024, 1024).cuda()
# feed images into decoder, specifying which unet you want to train
# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme
with decoder.one_unet_in_gpu(1):
loss = decoder(images, unet_number = 1)
loss.backward()
with decoder.one_unet_in_gpu(2):
loss = decoder(images, unet_number = 2)
loss.backward()
with decoder.one_unet_in_gpu(3):
loss = decoder(images, unet_number = 3)
loss.backward()
# do the above for many steps for both unets
# then it will learn to generate images based on the CLIP image embeddings
# chaining the unets from lowest resolution to highest resolution (thus cascading)
mock_image_embed = torch.randn(1, 512).cuda()
images = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)
```
## Training wrapper (wip)
Offer training wrappers
## CLI (wip)
## CLI Usage (work in progress)
```bash
$ dream 'sharing a sunset at the summit of mount everest with my dog'
@@ -574,7 +394,9 @@ $ dream 'sharing a sunset at the summit of mount everest with my dog'
Once built, images will be saved to the same directory the command is invoked
<a href="https://github.com/lucidrains/big-sleep">template</a>
## Training wrapper (wip)
Offer training wrappers
## Training CLI (wip)
@@ -588,20 +410,14 @@ Once built, images will be saved to the same directory the command is invoked
- [x] augment unet so that it can also be conditioned on text encodings (although in paper they hinted this didn't make much a difference)
- [x] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)
- [x] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions
- [x] add efficient attention in unet
- [x] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning)
- [x] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
- [x] build out latent diffusion architecture, with the vq-reg variant (vqgan-vae), make it completely optional and compatible with cascading ddpms
- [x] for decoder, allow ability to customize objective (predict epsilon vs x0), in case latent diffusion does better with prediction of x0
- [x] use attention-based upsampling https://arxiv.org/abs/2112.11435
- [ ] spend one day cleaning up tech debt in decoder
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
- [ ] use an image resolution cutoff and do cross attention conditioning only if resources allow, and MLP + sum conditioning on rest
- [ ] make unet more configurable
- [ ] figure out some factory methods to make cascading unet instantiations less error-prone
- [ ] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
- [ ] train on a toy task, offer in colab
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
- [ ] bring in tools to train vqgan-vae
- [ ] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
- [ ] add attention to unet - apply some personal tricks with efficient attention - use the sparse attention mechanism from https://github.com/lucidrains/vit-pytorch#maxvit
- [ ] build out latent diffusion architecture in separate file, as it is not faithful to dalle-2 (but offer it as as setting)
- [ ] consider U2-net for decoder https://arxiv.org/abs/2005.09007 (also in separate file as experimental) build out https://github.com/lucidrains/x-unet
## Citations
@@ -633,17 +449,20 @@ Once built, images will be saved to the same directory the command is invoked
```bibtex
@inproceedings{Liu2022ACF,
title = {A ConvNet for the 2020https://arxiv.org/abs/2112.11435s},
title = {A ConvNet for the 2020s},
author = {Zhuang Liu and Hanzi Mao and Chaozheng Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie},
year = {2022}
}
```
```bibtex
@inproceedings{Tu2022MaxViTMV,
title = {MaxViT: Multi-Axis Vision Transformer},
author = {Zhe-Wei Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
year = {2022}
@misc{zhang2019root,
title = {Root Mean Square Layer Normalization},
author = {Biao Zhang and Rico Sennrich},
year = {2019},
eprint = {1910.07467},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
```

View File

@@ -1,4 +1,2 @@
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
from dalle2_pytorch.vqgan_vae import VQGanVAE
from x_clip import CLIP

View File

@@ -1,125 +0,0 @@
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
class LayerNormChan(nn.Module):
def __init__(
self,
dim,
eps = 1e-5
):
super().__init__()
self.eps = eps
self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))
def forward(self, x):
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.gamma
# attention-based upsampling
# from https://arxiv.org/abs/2112.11435
class QueryAndAttend(nn.Module):
def __init__(
self,
*,
dim,
num_queries = 1,
dim_head = 32,
heads = 8,
window_size = 3
):
super().__init__()
self.scale = dim_head ** -0.5
inner_dim = dim_head * heads
self.heads = heads
self.dim_head = dim_head
self.window_size = window_size
self.num_queries = num_queries
self.rel_pos_bias = nn.Parameter(torch.randn(heads, num_queries, window_size * window_size, 1, 1))
self.queries = nn.Parameter(torch.randn(heads, num_queries, dim_head))
self.to_kv = nn.Conv2d(dim, dim_head * 2, 1, bias = False)
self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)
def forward(self, x):
"""
einstein notation
b - batch
h - heads
l - num queries
d - head dimension
x - height
y - width
j - source sequence for attending to (kernel size squared in this case)
"""
wsz, heads, dim_head, num_queries = self.window_size, self.heads, self.dim_head, self.num_queries
batch, _, height, width = x.shape
is_one_query = self.num_queries == 1
# queries, keys, values
q = self.queries * self.scale
k, v = self.to_kv(x).chunk(2, dim = 1)
# similarities
sim = einsum('h l d, b d x y -> b h l x y', q, k)
sim = rearrange(sim, 'b ... x y -> b (...) x y')
# unfold the similarity scores, with float(-inf) as padding value
mask_value = -torch.finfo(sim.dtype).max
sim = F.pad(sim, ((wsz // 2,) * 4), value = mask_value)
sim = F.unfold(sim, kernel_size = wsz)
sim = rearrange(sim, 'b (h l j) (x y) -> b h l j x y', h = heads, l = num_queries, x = height, y = width)
# rel pos bias
sim = sim + self.rel_pos_bias
# numerically stable attention
sim = sim - sim.amax(dim = -3, keepdim = True).detach()
attn = sim.softmax(dim = -3)
# unfold values
v = F.pad(v, ((wsz // 2,) * 4), value = 0.)
v = F.unfold(v, kernel_size = wsz)
v = rearrange(v, 'b (d j) (x y) -> b d j x y', d = dim_head, x = height, y = width)
# aggregate values
out = einsum('b h l j x y, b d j x y -> b l h d x y', attn, v)
# combine heads
out = rearrange(out, 'b l h d x y -> (b l) (h d) x y')
out = self.to_out(out)
out = rearrange(out, '(b l) d x y -> b l d x y', b = batch)
# return original input if one query
if is_one_query:
out = rearrange(out, 'b 1 ... -> b ...')
return out
class QueryAttnUpsample(nn.Module):
def __init__(self, dim, **kwargs):
super().__init__()
self.norm = LayerNormChan(dim)
self.qna = QueryAndAttend(dim = dim, num_queries = 4, **kwargs)
def forward(self, x):
x = self.norm(x)
out = self.qna(x)
out = rearrange(out, 'b (w1 w2) c h w -> b c (h w1) (w w2)', w1 = 2, w2 = 2)
return out

View File

@@ -1,51 +1,9 @@
import click
import torch
import torchvision.transforms as T
from pathlib import Path
from dalle2_pytorch import DALLE2, Decoder, DiffusionPrior
def safeget(dictionary, keys, default = None):
return reduce(lambda d, key: d.get(key, default) if isinstance(d, dict) else default, keys.split('.'), dictionary)
def simple_slugify(text, max_length = 255):
return text.replace("-", "_").replace(",", "").replace(" ", "_").replace("|", "--").strip('-_')[:max_length]
def get_pkg_version():
from pkg_resources import get_distribution
return get_distribution('dalle2_pytorch').version
def main():
pass
@click.command()
@click.option('--model', default = './dalle2.pt', help = 'path to trained DALL-E2 model')
@click.option('--cond_scale', default = 2, help = 'conditioning scale (classifier free guidance) in decoder')
@click.argument('text')
def dream(
model,
cond_scale,
text
):
model_path = Path(model)
full_model_path = str(model_path.resolve())
assert model_path.exists(), f'model not found at {full_model_path}'
loaded = torch.load(str(model_path))
version = safeget(loaded, 'version')
print(f'loading DALL-E2 from {full_model_path}, saved at version {version} - current package version is {get_pkg_version()}')
prior_init_params = safeget(loaded, 'init_params.prior')
decoder_init_params = safeget(loaded, 'init_params.decoder')
model_params = safeget(loaded, 'model_params')
prior = DiffusionPrior(**prior_init_params)
decoder = Decoder(**decoder_init_params)
dalle2 = DALLE2(prior, decoder)
dalle2.load_state_dict(model_params)
image = dalle2(text, cond_scale = cond_scale)
pil_image = T.ToPILImage()(image)
return pil_image.save(f'./{simple_slugify(text)}.png')
def dream(text):
return image

View File

@@ -2,7 +2,6 @@ import math
from tqdm import tqdm
from inspect import isfunction
from functools import partial
from contextlib import contextmanager
import torch
import torch.nn.functional as F
@@ -13,11 +12,10 @@ from einops.layers.torch import Rearrange
from einops_exts import rearrange_many, repeat_many, check_shape
from einops_exts.torch import EinopsToAndFrom
from kornia.filters.gaussian import GaussianBlur2d
from kornia.filters import gaussian_blur2d
from dalle2_pytorch.tokenizer import tokenizer
from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE
from dalle2_pytorch.attention import QueryAttnUpsample
# use x-clip
@@ -50,12 +48,6 @@ def is_list_str(x):
return False
return all([type(el) == str for el in x])
def pad_tuple_to_length(t, length, fillvalue = None):
remain_length = length - len(t)
if remain_length <= 0:
return t
return (*t, *((fillvalue,) * remain_length))
# for controlling freezing of CLIP
def set_module_requires_grad_(module, requires_grad):
@@ -114,8 +106,8 @@ def cosine_beta_schedule(timesteps, s = 0.008):
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
"""
steps = timesteps + 1
x = torch.linspace(0, timesteps, steps)
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
x = torch.linspace(0, steps, steps)
alphas_cumprod = torch.cos(((x / steps) + s) / (1 + s) * torch.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return torch.clip(betas, 0, 0.999)
@@ -145,27 +137,23 @@ def sigmoid_beta_schedule(timesteps):
# diffusion prior
class LayerNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.gamma = nn.Parameter(torch.ones(dim))
self.register_buffer("beta", torch.zeros(dim))
def forward(self, x):
return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
class ChanLayerNorm(nn.Module):
class RMSNorm(nn.Module):
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.eps = eps
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
self.scale = dim ** 0.5
self.gamma = nn.Parameter(torch.ones(dim))
def forward(self, x):
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g
squared_sum = (x ** 2).sum(dim = -1, keepdim = True)
inv_norm = torch.rsqrt(squared_sum + self.eps)
return x * inv_norm * self.gamma * self.scale
class ChanRMSNorm(RMSNorm):
def forward(self, x):
squared_sum = (x ** 2).sum(dim = 1, keepdim = True)
inv_norm = torch.rsqrt(squared_sum + self.eps)
return x * inv_norm * rearrange(self.gamma, 'c -> 1 c 1 1') * self.scale
class Residual(nn.Module):
def __init__(self, fn):
@@ -261,10 +249,10 @@ def FeedForward(dim, mult = 4, dropout = 0., post_activation_norm = False):
inner_dim = int(mult * dim)
return nn.Sequential(
LayerNorm(dim),
RMSNorm(dim),
nn.Linear(dim, inner_dim * 2, bias = False),
SwiGLU(),
LayerNorm(inner_dim) if post_activation_norm else nn.Identity(),
RMSNorm(inner_dim) if post_activation_norm else nn.Identity(),
nn.Dropout(dropout),
nn.Linear(inner_dim, dim, bias = False)
)
@@ -287,8 +275,7 @@ class Attention(nn.Module):
inner_dim = dim_head * heads
self.causal = causal
self.norm = LayerNorm(dim)
self.post_norm = LayerNorm(dim) # sandwich norm from Coqview paper + Normformer
self.norm = RMSNorm(dim)
self.dropout = nn.Dropout(dropout)
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
@@ -344,8 +331,7 @@ class Attention(nn.Module):
out = einsum('b h i j, b j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
out = self.to_out(out)
return self.post_norm(out)
return self.to_out(out)
class CausalTransformer(nn.Module):
def __init__(
@@ -358,8 +344,7 @@ class CausalTransformer(nn.Module):
ff_mult = 4,
norm_out = False,
attn_dropout = 0.,
ff_dropout = 0.,
final_proj = True
ff_dropout = 0.
):
super().__init__()
self.rel_pos_bias = RelPosBias(heads = heads)
@@ -371,8 +356,7 @@ class CausalTransformer(nn.Module):
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
]))
self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity()
self.norm = RMSNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
def forward(
self,
@@ -387,8 +371,7 @@ class CausalTransformer(nn.Module):
x = attn(x, mask = mask, attn_bias = attn_bias) + x
x = ff(x) + x
out = self.norm(x)
return self.project_out(out)
return self.norm(x)
class DiffusionPriorNetwork(nn.Module):
def __init__(
@@ -421,41 +404,25 @@ class DiffusionPriorNetwork(nn.Module):
image_embed,
diffusion_timesteps,
*,
text_encodings,
text_embed,
text_encodings = None,
mask = None,
cond_drop_prob = 0.2
):
batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
batch, text_enc_len, device = image_embed.shape[0], text_encodings.shape[-2], image_embed.device
# in section 2.2, last paragraph
# "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
text_embed, image_embed = rearrange_many((text_embed, image_embed), 'b d -> b 1 d')
# make text encodings optional
# although the paper seems to suggest it is present <--
if not exists(text_encodings):
text_encodings = torch.empty((batch, 0, dim), device = device, dtype = dtype)
if not exists(mask):
mask = torch.ones((batch, text_encodings.shape[-2]), device = device, dtype = torch.bool)
# classifier free guidance
cond_prob_mask = prob_mask_like((batch,), cond_drop_prob, device = device)
cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1')
mask &= cond_prob_mask
# whether text embedding is masked or not depends on the classifier free guidance conditional masking
mask = torch.cat((mask, cond_prob_mask), dim = 1)
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
# but let's just do it right
if exists(mask):
not_all_masked_out = mask.any(dim = -1)
mask = torch.cat((mask, rearrange(not_all_masked_out, 'b -> b 1')), dim = 1)
if exists(mask):
mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
@@ -471,6 +438,16 @@ class DiffusionPriorNetwork(nn.Module):
learned_queries
), dim = -2)
# mask if it doesn't exist
if not exists(mask):
mask = torch.ones((batch, text_enc_len), device = device, dtype = torch.bool)
# classifier free guidance
cond_prob_mask = prob_mask_like((batch,), cond_drop_prob, device = device)
mask &= rearrange(cond_prob_mask, 'b -> b 1')
# attend
tokens = self.causal_transformer(tokens, mask = mask)
@@ -487,12 +464,11 @@ class DiffusionPrior(nn.Module):
net,
*,
clip,
timesteps = 1000,
cond_drop_prob = 0.2,
loss_type = "l1",
predict_x_start = True,
beta_schedule = "cosine",
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
timesteps=1000,
cond_drop_prob=0.2,
loss_type="l1",
predict_x0=True,
beta_schedule="cosine",
):
super().__init__()
assert isinstance(clip, CLIP)
@@ -503,11 +479,9 @@ class DiffusionPrior(nn.Module):
self.image_embed_dim = clip.dim_latent
self.channels = clip.image_channels
self.image_size = clip.image_size
self.cond_drop_prob = cond_drop_prob
self.condition_on_text_encodings = condition_on_text_encodings
self.predict_x_start = predict_x_start
self.predict_x0 = predict_x0
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
if beta_schedule == "cosine":
@@ -524,7 +498,7 @@ class DiffusionPrior(nn.Module):
raise NotImplementedError()
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis = 0)
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
timesteps, = betas.shape
@@ -557,23 +531,17 @@ class DiffusionPrior(nn.Module):
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
@torch.no_grad()
def get_image_embed(self, image):
image_encoding = self.clip.visual_transformer(image)
image_cls = image_encoding[:, 0]
image_embed = self.clip.to_visual_latent(image_cls)
return l2norm(image_embed)
@torch.no_grad()
def get_text_cond(self, text):
text_encodings = self.clip.text_transformer(text)
text_cls, text_encodings = text_encodings[:, 0], text_encodings[:, 1:]
text_embed = self.clip.to_text_latent(text_cls)
text_embed = l2norm(text_embed)
if not self.condition_on_text_encodings:
return dict(text_embed = text_embed)
return dict(text_encodings = text_encodings, text_embed = text_embed, mask = text != 0)
def q_mean_variance(self, x_start, t):
@@ -598,16 +566,14 @@ class DiffusionPrior(nn.Module):
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
pred = self.net(x, t, **text_cond)
if self.predict_x_start:
x_recon = pred
if self.predict_x0:
x_recon = self.net(x, t, **text_cond)
# not 100% sure of this above line - for any spectators, let me know in the github issues (or through a pull request) if you know how to correctly do this
# i'll be rereading https://arxiv.org/abs/2111.14822, where i think a similar approach is taken
else:
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
x_recon = self.predict_start_from_noise(x, t = t, noise = self.net(x, t, **text_cond))
if clip_denoised and not self.predict_x_start:
if clip_denoised:
x_recon.clamp_(-1., 1.)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
@@ -653,7 +619,7 @@ class DiffusionPrior(nn.Module):
**text_cond
)
to_predict = noise if not self.predict_x_start else image_embed
to_predict = noise if not self.predict_x0 else image_embed
if self.loss_type == 'l1':
loss = F.l1_loss(to_predict, x_recon)
@@ -692,41 +658,13 @@ class DiffusionPrior(nn.Module):
top_image_embeds = image_embeds.gather(1, top_sim_indices)
return rearrange(top_image_embeds, 'b 1 d -> b d')
def forward(
self,
text = None,
image = None,
text_embed = None, # allow for training on preprocessed CLIP text and image embeddings
image_embed = None,
text_encodings = None, # as well as CLIP text encodings
text_mask = None, # text mask <- may eventually opt for the learned padding tokens technique from DALL-E1 to reduce complexity
*args,
**kwargs
):
assert exists(text) ^ exists(text_embed), 'either text or text embedding must be supplied'
assert exists(image) ^ exists(image_embed), 'either text or text embedding must be supplied'
assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization'
def forward(self, text, image, *args, **kwargs):
b, device, img_size, = image.shape[0], image.device, self.image_size
check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)
if exists(image):
image_embed = self.get_image_embed(image)
# calculate text conditionings, based on what is passed in
if exists(text):
text_cond = self.get_text_cond(text)
else:
text_cond = dict(
text_embed = text_embed,
text_encodings = text_encodings,
mask = text_mask
)
# timestep conditioning from ddpm
batch, device = image_embed.shape[0], image_embed.device
times = torch.randint(0, self.num_timesteps, (batch,), device = device, dtype = torch.long)
# calculate forward loss
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
image_embed = self.get_image_embed(image)
text_cond = self.get_text_cond(text)
loss = self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)
return loss
@@ -782,7 +720,7 @@ class ConvNextBlock(nn.Module):
inner_dim = int(dim_out * mult)
self.net = nn.Sequential(
ChanLayerNorm(dim) if norm else nn.Identity(),
ChanRMSNorm(dim) if norm else nn.Identity(),
nn.Conv2d(dim, inner_dim, 3, padding = 1),
nn.GELU(),
nn.Conv2d(inner_dim, dim_out, 3, padding = 1)
@@ -818,8 +756,8 @@ class CrossAttention(nn.Module):
context_dim = default(context_dim, dim)
self.norm = LayerNorm(dim)
self.norm_context = LayerNorm(context_dim)
self.norm = RMSNorm(dim)
self.norm_context = RMSNorm(context_dim)
self.dropout = nn.Dropout(dropout)
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
@@ -861,20 +799,6 @@ class CrossAttention(nn.Module):
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class GridAttention(nn.Module):
def __init__(self, *args, window_size = 8, **kwargs):
super().__init__()
self.window_size = window_size
self.attn = Attention(*args, **kwargs)
def forward(self, x):
h, w = x.shape[-2:]
wsz = self.window_size
x = rearrange(x, 'b c (w1 h) (w2 w) -> (b h w) (w1 w2) c', w1 = wsz, w2 = wsz)
out = self.attn(x)
out = rearrange(out, '(b h w) (w1 w2) c -> b c (w1 h) (w2 w)', w1 = wsz, w2 = wsz, h = h // wsz, w = w // wsz)
return out
class Unet(nn.Module):
def __init__(
self,
@@ -883,30 +807,23 @@ class Unet(nn.Module):
image_embed_dim,
cond_dim = None,
num_image_tokens = 4,
num_time_tokens = 2,
out_dim = None,
dim_mults=(1, 2, 4, 8),
channels = 3,
attn_dim_head = 32,
attn_heads = 8,
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
sparse_attn = False,
sparse_attn_window = 8, # window size for sparse attention
lowres_cond_upsample_mode = 'bilinear',
blur_sigma = 0.1,
blur_kernel_size = 3,
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
cond_on_text_encodings = False,
max_text_len = 256,
cond_on_image_embeds = False,
):
super().__init__()
# save locals to take care of some hyperparameters for cascading DDPM
self._locals = locals()
del self._locals['self']
del self._locals['__class__']
# for eventual cascading diffusion
self.lowres_cond = lowres_cond
self.lowres_cond_upsample_mode = lowres_cond_upsample_mode
self.lowres_blur_kernel_size = blur_kernel_size
self.lowres_blur_sigma = blur_sigma
# determine dimensions
@@ -925,8 +842,8 @@ class Unet(nn.Module):
SinusoidalPosEmb(dim),
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Linear(dim * 4, cond_dim * num_time_tokens),
Rearrange('b (r d) -> b r d', r = num_time_tokens)
nn.Linear(dim * 4, cond_dim),
Rearrange('b d -> b 1 d')
)
self.image_to_cond = nn.Sequential(
@@ -936,20 +853,10 @@ class Unet(nn.Module):
self.text_to_cond = nn.LazyLinear(cond_dim)
# finer control over whether to condition on image embeddings and text encodings
# so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
self.cond_on_text_encodings = cond_on_text_encodings
self.cond_on_image_embeds = cond_on_image_embeds
# for classifier free guidance
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
# attention related params
attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
self.null_text_embed = nn.Parameter(torch.randn(1, 1, cond_dim))
# layers
@@ -964,7 +871,6 @@ class Unet(nn.Module):
self.downs.append(nn.ModuleList([
ConvNextBlock(dim_in, dim_out, norm = ind != 0),
Residual(GridAttention(dim_out, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim),
Downsample(dim_out) if not is_last else nn.Identity()
]))
@@ -972,7 +878,7 @@ class Unet(nn.Module):
mid_dim = dims[-1]
self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim))) if attend_at_middle else None
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
@@ -981,7 +887,6 @@ class Unet(nn.Module):
self.ups.append(nn.ModuleList([
ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim),
Residual(GridAttention(dim_in, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim),
Upsample(dim_in)
]))
@@ -992,20 +897,6 @@ class Unet(nn.Module):
nn.Conv2d(dim, out_dim, 1)
)
# if the current settings for the unet are not correct
# for cascading DDPM, then reinit the unet with the right settings
def cast_model_parameters(
self,
*,
lowres_cond,
channels
):
if lowres_cond == self.lowres_cond and channels == self.channels:
return self
updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond, 'channels': channels}
return self.__class__(**updated_kwargs)
def forward_with_cond_scale(
self,
*args,
@@ -1039,6 +930,13 @@ class Unet(nn.Module):
assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'
if exists(lowres_cond_img):
if self.training:
# when training, blur the low resolution conditional image
blur_sigma = default(blur_sigma, self.lowres_blur_sigma)
blur_kernel_size = default(blur_kernel_size, self.lowres_blur_kernel_size)
lowres_cond_img = gaussian_blur2d(lowres_cond_img, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
lowres_cond_img = resize_image_to(lowres_cond_img, x.shape[-2:], mode = self.lowres_cond_upsample_mode)
x = torch.cat((x, lowres_cond_img), dim = 1)
# time conditioning
@@ -1053,48 +951,39 @@ class Unet(nn.Module):
# mask out image embedding depending on condition dropout
# for classifier free guidance
image_tokens = None
image_tokens = self.image_to_cond(image_embed)
if self.cond_on_image_embeds:
image_tokens = self.image_to_cond(image_embed)
image_tokens = torch.where(
cond_prob_mask,
image_tokens,
self.null_image_embed
)
image_tokens = torch.where(
cond_prob_mask,
image_tokens,
self.null_image_embed
)
# take care of text encodings (optional)
text_tokens = None
if exists(text_encodings) and self.cond_on_text_encodings:
if exists(text_encodings):
text_tokens = self.text_to_cond(text_encodings)
text_tokens = torch.where(
cond_prob_mask,
text_tokens,
self.null_text_embed[:, :text_tokens.shape[1]]
self.null_text_embed
)
# main conditioning tokens (c)
c = time_tokens
if exists(image_tokens):
c = torch.cat((c, image_tokens), dim = -2)
c = torch.cat((time_tokens, image_tokens), dim = -2)
# text and image conditioning tokens (mid_c)
# to save on compute, only do cross attention based conditioning on the inner most layers of the Unet
mid_c = c if not exists(text_tokens) else torch.cat((c, text_tokens), dim = -2)
mid_c = c if not exists(text_encodings) else torch.cat((c, text_tokens), dim = -2)
# go through the layers of the unet, down and up
hiddens = []
for convnext, sparse_attn, convnext2, downsample in self.downs:
for convnext, convnext2, downsample in self.downs:
x = convnext(x, c)
x = sparse_attn(x)
x = convnext2(x, c)
hiddens.append(x)
x = downsample(x)
@@ -1106,72 +995,25 @@ class Unet(nn.Module):
x = self.mid_block2(x, mid_c)
for convnext, sparse_attn, convnext2, upsample in self.ups:
for convnext, convnext2, upsample in self.ups:
x = torch.cat((x, hiddens.pop()), dim=1)
x = convnext(x, c)
x = sparse_attn(x)
x = convnext2(x, c)
x = upsample(x)
return self.final_conv(x)
class LowresConditioner(nn.Module):
def __init__(
self,
cond_upsample_mode = 'bilinear',
downsample_first = True,
blur_sigma = 0.1,
blur_kernel_size = 3,
):
super().__init__()
self.cond_upsample_mode = cond_upsample_mode
self.downsample_first = downsample_first
self.blur_sigma = blur_sigma
self.blur_kernel_size = blur_kernel_size
def forward(
self,
cond_fmap,
*,
target_image_size,
downsample_image_size = None,
blur_sigma = None,
blur_kernel_size = None
):
target_image_size = cast_tuple(target_image_size, 2)
if self.training and self.downsample_first and exists(downsample_image_size):
cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode)
if self.training:
# when training, blur the low resolution conditional image
blur_sigma = default(blur_sigma, self.blur_sigma)
blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)
cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode)
return cond_fmap
class Decoder(nn.Module):
def __init__(
self,
unet,
*,
clip,
vae = tuple(),
timesteps = 1000,
cond_drop_prob = 0.2,
loss_type = 'l1',
beta_schedule = 'cosine',
predict_x_start = False,
predict_x_start_for_latent_diffusion = False,
image_sizes = None, # for cascading ddpm, image size at each stage
lowres_cond_upsample_mode = 'bilinear', # cascading ddpm - low resolution upsample mode
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
blur_sigma = 0.1, # cascading ddpm - blur sigma
blur_kernel_size = 3, # cascading ddpm - blur kernel size
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
image_sizes = None # for cascading ddpm, image size at each stage
):
super().__init__()
assert isinstance(clip, CLIP)
@@ -1180,65 +1022,18 @@ class Decoder(nn.Module):
self.clip_image_size = clip.image_size
self.channels = clip.image_channels
self.condition_on_text_encodings = condition_on_text_encodings
# automatically take care of ensuring that first unet is unconditional
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
unets = cast_tuple(unet)
vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))
self.unets = nn.ModuleList([])
self.vaes = nn.ModuleList([])
for ind, (one_unet, one_vae) in enumerate(zip(unets, vaes)):
assert isinstance(one_unet, Unet)
assert isinstance(one_vae, (VQGanVAE, NullVQGanVAE))
is_first = ind == 0
latent_dim = one_vae.encoded_dim if exists(one_vae) else None
unet_channels = default(latent_dim, self.channels)
one_unet = one_unet.cast_model_parameters(
lowres_cond = not is_first,
channels = unet_channels
)
self.unets.append(one_unet)
self.vaes.append(one_vae.copy_for_eval())
# unet image sizes
self.unets = nn.ModuleList(unet)
image_sizes = default(image_sizes, (clip.image_size,))
image_sizes = tuple(sorted(set(image_sizes)))
assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
self.image_sizes = image_sizes
self.sample_channels = cast_tuple(self.channels, len(image_sizes))
# predict x0 config
self.predict_x_start = cast_tuple(predict_x_start, len(unets)) if not predict_x_start_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes))
# cascading ddpm related stuff
lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
assert lowres_conditions == (False, *((True,) * (len(self.unets) - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
self.to_lowres_cond = LowresConditioner(
cond_upsample_mode = lowres_cond_upsample_mode,
downsample_first = lowres_downsample_first,
blur_sigma = blur_sigma,
blur_kernel_size = blur_kernel_size,
)
# classifier free guidance
self.cond_drop_prob = cond_drop_prob
# noise schedule
if beta_schedule == "cosine":
betas = cosine_beta_schedule(timesteps)
elif beta_schedule == "linear":
@@ -1253,7 +1048,7 @@ class Decoder(nn.Module):
raise NotImplementedError()
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis = 0)
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
timesteps, = betas.shape
@@ -1286,31 +1081,10 @@ class Decoder(nn.Module):
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
def get_unet(self, unet_number):
assert 0 < unet_number <= len(self.unets)
index = unet_number - 1
return self.unets[index]
@contextmanager
def one_unet_in_gpu(self, unet_number = None, unet = None):
assert exists(unet_number) ^ exists(unet)
if exists(unet_number):
unet = self.get_unet(unet_number)
self.cuda()
self.unets.cpu()
unet.cuda()
yield
unet.cpu()
@torch.no_grad()
def get_text_encodings(self, text):
text_encodings = self.clip.text_transformer(text)
return text_encodings[:, 1:]
@torch.no_grad()
def get_image_embed(self, image):
image = resize_image_to(image, self.clip_image_size)
image_encoding = self.clip.visual_transformer(image)
@@ -1339,48 +1113,34 @@ class Decoder(nn.Module):
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, cond_scale = 1.):
pred_noise = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
x_recon = self.predict_start_from_noise(x, t = t, noise = pred_noise)
if predict_x_start:
x_recon = pred
else:
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
if clip_denoised and not predict_x_start:
if clip_denoised:
x_recon.clamp_(-1., 1.)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, unet, x, t, image_embed, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, clip_denoised = True, repeat_noise = False):
def p_sample(self, unet, x, t, image_embed, text_encodings = None, cond_scale = 1., lowres_cond_img = None, clip_denoised = True, repeat_noise = False):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start)
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised)
noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, lowres_cond_img = None, text_encodings = None, cond_scale = 1):
def p_sample_loop(self, unet, shape, image_embed, lowres_cond_img = None, text_encodings = None, cond_scale = 1):
device = self.betas.device
b = shape[0]
img = torch.randn(shape, device = device)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
img = self.p_sample(
unet,
img,
torch.full((b,), i, device = device, dtype = torch.long),
image_embed = image_embed,
text_encodings = text_encodings,
cond_scale = cond_scale,
lowres_cond_img = lowres_cond_img,
predict_x_start = predict_x_start
)
img = self.p_sample(unet, img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
return img
def q_sample(self, x_start, t, noise=None):
@@ -1391,7 +1151,7 @@ class Decoder(nn.Module):
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None):
def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, noise = None):
noise = default(noise, lambda: torch.randn_like(x_start))
x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)
@@ -1405,14 +1165,12 @@ class Decoder(nn.Module):
cond_drop_prob = self.cond_drop_prob
)
target = noise if not predict_x_start else x_start
if self.loss_type == 'l1':
loss = F.l1_loss(target, x_recon)
loss = F.l1_loss(noise, x_recon)
elif self.loss_type == 'l2':
loss = F.mse_loss(target, x_recon)
loss = F.mse_loss(noise, x_recon)
elif self.loss_type == "huber":
loss = F.smooth_l1_loss(target, x_recon)
loss = F.smooth_l1_loss(noise, x_recon)
else:
raise NotImplementedError()
@@ -1422,58 +1180,25 @@ class Decoder(nn.Module):
@eval_decorator
def sample(self, image_embed, text = None, cond_scale = 1.):
batch_size = image_embed.shape[0]
channels = self.channels
text_encodings = self.get_text_encodings(text) if exists(text) else None
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
img = None
for unet, vae, channel, image_size, predict_x_start in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
with self.one_unet_in_gpu(unet = unet):
lowres_cond_img = None
shape = (batch_size, channel, image_size, image_size)
if unet.lowres_cond:
lowres_cond_img = self.to_lowres_cond(img, target_image_size = image_size)
image_size = vae.get_encoded_fmap_size(image_size)
shape = (batch_size, vae.encoded_dim, image_size, image_size)
if exists(lowres_cond_img):
lowres_cond_img = vae.encode(lowres_cond_img)
img = self.p_sample_loop(
unet,
shape,
image_embed = image_embed,
text_encodings = text_encodings,
cond_scale = cond_scale,
predict_x_start = predict_x_start,
lowres_cond_img = lowres_cond_img
)
img = vae.decode(img)
for unet, image_size in tqdm(zip(self.unets, self.image_sizes)):
shape = (batch_size, channels, image_size, image_size)
img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img)
return img
def forward(
self,
image,
text = None,
image_embed = None,
text_encodings = None,
unet_number = None
):
def forward(self, image, text = None, unet_number = None):
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
unet_number = default(unet_number, 1)
unet_index = unet_number - 1
assert 1 <= unet_number <= len(self.unets)
unet = self.get_unet(unet_number)
target_image_size = self.image_sizes[unet_index]
vae = self.vaes[unet_index]
predict_x_start = self.predict_x_start[unet_index]
index = unet_number - 1
unet = self.unets[index]
target_image_size = self.image_sizes[index]
b, c, h, w, device, = *image.shape, image.device
@@ -1482,24 +1207,12 @@ class Decoder(nn.Module):
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
if not exists(image_embed):
image_embed = self.get_image_embed(image)
image_embed = self.get_image_embed(image)
text_encodings = self.get_text_encodings(text) if exists(text) else None
text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
image = resize_image_to(image, target_image_size)
vae.eval()
with torch.no_grad():
image = vae.encode(image)
if exists(lowres_cond_img):
lowres_cond_img = vae.encode(lowres_cond_img)
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start)
lowres_cond_img = image if index > 0 else None
ddpm_image = resize_image_to(image, target_image_size)
return self.p_losses(unet, ddpm_image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img)
# main class
@@ -1516,9 +1229,7 @@ class DALLE2(nn.Module):
assert isinstance(decoder, Decoder)
self.prior = prior
self.decoder = decoder
self.prior_num_samples = prior_num_samples
self.decoder_need_text_cond = self.decoder.condition_on_text_encodings
@torch.no_grad()
@eval_decorator
@@ -1528,18 +1239,11 @@ class DALLE2(nn.Module):
cond_scale = 1.
):
device = next(self.parameters()).device
one_text = isinstance(text, str) or (not is_list_str(text) and text.shape[0] == 1)
if isinstance(text, str) or is_list_str(text):
text = [text] if not isinstance(text, (list, tuple)) else text
text = tokenizer.tokenize(text).to(device)
image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
text_cond = text if self.decoder_need_text_cond else None
images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale)
if one_text:
return images[0]
images = self.decoder.sample(image_embed, cond_scale = cond_scale)
return images

View File

@@ -1,53 +0,0 @@
import copy
import torch
from torch import nn
# exponential moving average wrapper
class EMA(nn.Module):
def __init__(
self,
model,
beta = 0.99,
ema_update_after_step = 1000,
ema_update_every = 10,
):
super().__init__()
self.beta = beta
self.online_model = model
self.ema_model = copy.deepcopy(model)
self.ema_update_after_step = ema_update_after_step # only start EMA after this step number, starting at 0
self.ema_update_every = ema_update_every
self.register_buffer('initted', torch.Tensor([False]))
self.register_buffer('step', torch.tensor([0.]))
def update(self):
self.step += 1
if self.step <= self.ema_update_after_step or (self.step % self.ema_update_every) != 0:
return
if not self.initted:
self.ema_model.state_dict(self.online_model.state_dict())
self.initted.data.copy_(torch.Tensor([True]))
self.update_moving_average(self.ema_model, self.online_model)
def update_moving_average(ma_model, current_model):
def calculate_ema(beta, old, new):
if not exists(old):
return new
return old * beta + (1 - beta) * new
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
old_weight, up_weight = ma_params.data, current_params.data
ma_params.data = calculate_ema(self.beta, old_weight, up_weight)
for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer)
ma_buffer.copy_(new_buffer_value)
def __call__(self, *args, **kwargs):
return self.ema_model(*args, **kwargs)

View File

@@ -1,566 +0,0 @@
import copy
import math
from math import sqrt
from functools import partial, wraps
from vector_quantize_pytorch import VectorQuantize as VQ
import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.autograd import grad as torch_grad
import torchvision
from einops import rearrange, reduce, repeat
from dalle2_pytorch.attention import QueryAttnUpsample
# constants
MList = nn.ModuleList
# helper functions
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
# decorators
def eval_decorator(fn):
def inner(model, *args, **kwargs):
was_training = model.training
model.eval()
out = fn(model, *args, **kwargs)
model.train(was_training)
return out
return inner
def remove_vgg(fn):
@wraps(fn)
def inner(self, *args, **kwargs):
has_vgg = hasattr(self, 'vgg')
if has_vgg:
vgg = self.vgg
delattr(self, 'vgg')
out = fn(self, *args, **kwargs)
if has_vgg:
self.vgg = vgg
return out
return inner
# keyword argument helpers
def pick_and_pop(keys, d):
values = list(map(lambda key: d.pop(key), keys))
return dict(zip(keys, values))
def group_dict_by_key(cond, d):
return_val = [dict(),dict()]
for key in d.keys():
match = bool(cond(key))
ind = int(not match)
return_val[ind][key] = d[key]
return (*return_val,)
def string_begins_with(prefix, str):
return str.startswith(prefix)
def group_by_key_prefix(prefix, d):
return group_dict_by_key(partial(string_begins_with, prefix), d)
def groupby_prefix_and_trim(prefix, d):
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs
# tensor helper functions
def log(t, eps = 1e-10):
return torch.log(t + eps)
def gradient_penalty(images, output, weight = 10):
batch_size = images.shape[0]
gradients = torch_grad(outputs = output, inputs = images,
grad_outputs = torch.ones(output.size(), device = images.device),
create_graph = True, retain_graph = True, only_inputs = True)[0]
gradients = rearrange(gradients, 'b ... -> b (...)')
return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean()
def l2norm(t):
return F.normalize(t, dim = -1)
def leaky_relu(p = 0.1):
return nn.LeakyReLU(0.1)
def stable_softmax(t, dim = -1, alpha = 32 ** 2):
t = t / alpha
t = t - torch.amax(t, dim = dim, keepdim = True).detach()
return (t * alpha).softmax(dim = dim)
def safe_div(numer, denom, eps = 1e-8):
return numer / (denom + eps)
# gan losses
def hinge_discr_loss(fake, real):
return (F.relu(1 + fake) + F.relu(1 - real)).mean()
def hinge_gen_loss(fake):
return -fake.mean()
def bce_discr_loss(fake, real):
return (-log(1 - torch.sigmoid(fake)) - log(torch.sigmoid(real))).mean()
def bce_gen_loss(fake):
return -log(torch.sigmoid(fake)).mean()
def grad_layer_wrt_loss(loss, layer):
return torch_grad(
outputs = loss,
inputs = layer,
grad_outputs = torch.ones_like(loss),
retain_graph = True
)[0].detach()
# vqgan vae
class LayerNormChan(nn.Module):
def __init__(
self,
dim,
eps = 1e-5
):
super().__init__()
self.eps = eps
self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))
def forward(self, x):
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.gamma
class Discriminator(nn.Module):
def __init__(
self,
dims,
channels = 3,
groups = 16,
init_kernel_size = 5
):
super().__init__()
dim_pairs = zip(dims[:-1], dims[1:])
self.layers = MList([nn.Sequential(nn.Conv2d(channels, dims[0], init_kernel_size, padding = init_kernel_size // 2), leaky_relu())])
for dim_in, dim_out in dim_pairs:
self.layers.append(nn.Sequential(
nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1),
nn.GroupNorm(groups, dim_out),
leaky_relu()
))
dim = dims[-1]
self.to_logits = nn.Sequential( # return 5 x 5, for PatchGAN-esque training
nn.Conv2d(dim, dim, 1),
leaky_relu(),
nn.Conv2d(dim, 1, 4)
)
def forward(self, x):
for net in self.layers:
x = net(x)
return self.to_logits(x)
class ContinuousPositionBias(nn.Module):
""" from https://arxiv.org/abs/2111.09883 """
def __init__(self, *, dim, heads, layers = 2):
super().__init__()
self.net = MList([])
self.net.append(nn.Sequential(nn.Linear(2, dim), leaky_relu()))
for _ in range(layers - 1):
self.net.append(nn.Sequential(nn.Linear(dim, dim), leaky_relu()))
self.net.append(nn.Linear(dim, heads))
self.register_buffer('rel_pos', None, persistent = False)
def forward(self, x):
n, device = x.shape[-1], x.device
fmap_size = int(sqrt(n))
if not exists(self.rel_pos):
pos = torch.arange(fmap_size, device = device)
grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
grid = rearrange(grid, 'c i j -> (i j) c')
rel_pos = rearrange(grid, 'i c -> i 1 c') - rearrange(grid, 'j c -> 1 j c')
rel_pos = torch.sign(rel_pos) * torch.log(rel_pos.abs() + 1)
self.register_buffer('rel_pos', rel_pos, persistent = False)
rel_pos = self.rel_pos.float()
for layer in self.net:
rel_pos = layer(rel_pos)
bias = rearrange(rel_pos, 'i j h -> h i j')
return x + bias
class GLUResBlock(nn.Module):
def __init__(self, chan, groups = 16):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(chan, chan * 2, 3, padding = 1),
nn.GLU(dim = 1),
nn.GroupNorm(groups, chan),
nn.Conv2d(chan, chan * 2, 3, padding = 1),
nn.GLU(dim = 1),
nn.GroupNorm(groups, chan),
nn.Conv2d(chan, chan, 1)
)
def forward(self, x):
return self.net(x) + x
class ResBlock(nn.Module):
def __init__(self, chan, groups = 16):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(chan, chan, 3, padding = 1),
nn.GroupNorm(groups, chan),
leaky_relu(),
nn.Conv2d(chan, chan, 3, padding = 1),
nn.GroupNorm(groups, chan),
leaky_relu(),
nn.Conv2d(chan, chan, 1)
)
def forward(self, x):
return self.net(x) + x
# vqgan attention layer
class VQGanAttention(nn.Module):
def __init__(
self,
*,
dim,
dim_head = 64,
heads = 8,
dropout = 0.
):
super().__init__()
self.heads = heads
self.scale = dim_head ** -0.5
inner_dim = heads * dim_head
self.dropout = nn.Dropout(dropout)
self.pre_norm = LayerNormChan(dim)
self.cpb = ContinuousPositionBias(dim = dim // 4, heads = heads)
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)
def forward(self, x):
h = self.heads
height, width, residual = *x.shape[-2:], x.clone()
x = self.pre_norm(x)
q, k, v = self.to_qkv(x).chunk(3, dim = 1)
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = h), (q, k, v))
sim = einsum('b h c i, b h c j -> b h i j', q, k) * self.scale
sim = self.cpb(sim)
attn = stable_softmax(sim, dim = -1)
attn = self.dropout(attn)
out = einsum('b h i j, b h c j -> b h c i', attn, v)
out = rearrange(out, 'b h c (x y) -> b (h c) x y', x = height, y = width)
out = self.to_out(out)
return out + residual
class NullVQGanVAE(nn.Module):
def __init__(
self,
*,
channels
):
super().__init__()
self.encoded_dim = channels
self.layers = 0
def get_encoded_fmap_size(self, size):
return size
def copy_for_eval(self):
return self
def encode(self, x):
return x
def decode(self, x):
return x
class VQGanVAE(nn.Module):
def __init__(
self,
*,
dim,
image_size,
channels = 3,
layers = 4,
layer_mults = None,
l2_recon_loss = False,
use_hinge_loss = True,
num_resnet_blocks = 1,
vgg = None,
vq_codebook_size = 512,
vq_decay = 0.8,
vq_commitment_weight = 1.,
vq_kmeans_init = True,
vq_use_cosine_sim = True,
use_attn = True,
attn_dim_head = 64,
attn_heads = 8,
resnet_groups = 16,
attn_dropout = 0.,
first_conv_kernel_size = 5,
use_vgg_and_gan = True,
**kwargs
):
super().__init__()
assert dim % resnet_groups == 0, f'dimension {dim} must be divisible by {resnet_groups} (groups for the groupnorm)'
vq_kwargs, kwargs = groupby_prefix_and_trim('vq_', kwargs)
self.image_size = image_size
self.channels = channels
self.layers = layers
self.fmap_size = image_size // (layers ** 2)
self.codebook_size = vq_codebook_size
self.encoders = MList([])
self.decoders = MList([])
layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'
layer_dims = [dim * mult for mult in layer_mults]
dims = (dim, *layer_dims)
codebook_dim = layer_dims[-1]
self.encoded_dim = dims[-1]
dim_pairs = zip(dims[:-1], dims[1:])
append = lambda arr, t: arr.append(t)
prepend = lambda arr, t: arr.insert(0, t)
if not isinstance(num_resnet_blocks, tuple):
num_resnet_blocks = (*((0,) * (layers - 1)), num_resnet_blocks)
if not isinstance(use_attn, tuple):
use_attn = (*((False,) * (layers - 1)), use_attn)
assert len(num_resnet_blocks) == layers, 'number of resnet blocks config must be equal to number of layers'
assert len(use_attn) == layers
for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_resnet_blocks, use_attn):
append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
prepend(self.decoders, nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu()))
if layer_use_attn:
prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
for _ in range(layer_num_resnet_blocks):
append(self.encoders, ResBlock(dim_out, groups = resnet_groups))
prepend(self.decoders, GLUResBlock(dim_out, groups = resnet_groups))
if layer_use_attn:
append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
append(self.decoders, nn.Conv2d(dim, channels, 1))
self.vq = VQ(
dim = codebook_dim,
codebook_size = vq_codebook_size,
decay = vq_decay,
commitment_weight = vq_commitment_weight,
accept_image_fmap = True,
kmeans_init = vq_kmeans_init,
use_cosine_sim = vq_use_cosine_sim,
**vq_kwargs
)
# reconstruction loss
self.recon_loss_fn = F.mse_loss if l2_recon_loss else F.l1_loss
# turn off GAN and perceptual loss if grayscale
self.vgg = None
self.discr = None
self.use_vgg_and_gan = use_vgg_and_gan
if not use_vgg_and_gan:
return
# preceptual loss
if exists(vgg):
self.vgg = vgg
else:
self.vgg = torchvision.models.vgg16(pretrained = True)
self.vgg.classifier = nn.Sequential(*self.vgg.classifier[:-2])
# gan related losses
self.discr = Discriminator(dims = dims, channels = channels)
self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss
def get_encoded_fmap_size(self, image_size):
return image_size // (2 ** self.layers)
def copy_for_eval(self):
device = next(self.parameters()).device
vae_copy = copy.deepcopy(self.cpu())
if vae_copy.use_vgg_and_gan:
del vae_copy.discr
del vae_copy.vgg
vae_copy.eval()
return vae_copy.to(device)
@remove_vgg
def state_dict(self, *args, **kwargs):
return super().state_dict(*args, **kwargs)
@remove_vgg
def load_state_dict(self, *args, **kwargs):
return super().load_state_dict(*args, **kwargs)
@property
def codebook(self):
return self.vq.codebook
def encode(self, fmap):
for enc in self.encoders:
fmap = enc(fmap)
return fmap
def decode(self, fmap, return_indices_and_loss = False):
fmap, indices, commit_loss = self.vq(fmap)
for dec in self.decoders:
fmap = dec(fmap)
if not return_indices_and_loss:
return fmap
return fmap, indices, commit_loss
def forward(
self,
img,
return_loss = False,
return_discr_loss = False,
return_recons = False,
add_gradient_penalty = True
):
batch, channels, height, width, device = *img.shape, img.device
assert height == self.image_size and width == self.image_size, 'height and width of input image must be equal to {self.image_size}'
assert channels == self.channels, 'number of channels on image or sketch is not equal to the channels set on this VQGanVAE'
fmap = self.encode(img)
fmap, indices, commit_loss = self.decode(fmap, return_indices_and_loss = True)
if not return_loss and not return_discr_loss:
return fmap
assert return_loss ^ return_discr_loss, 'you should either return autoencoder loss or discriminator loss, but not both'
# whether to return discriminator loss
if return_discr_loss:
assert exists(self.discr), 'discriminator must exist to train it'
fmap.detach_()
img.requires_grad_()
fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))
discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits)
if add_gradient_penalty:
gp = gradient_penalty(img, img_discr_logits)
loss = discr_loss + gp
if return_recons:
return loss, fmap
return loss
# reconstruction loss
recon_loss = self.recon_loss_fn(fmap, img)
# early return if training on grayscale
if not self.use_vgg_and_gan:
if return_recons:
return recon_loss, fmap
return recon_loss
# perceptual loss
img_vgg_input = img
fmap_vgg_input = fmap
if img.shape[1] == 1:
# handle grayscale for vgg
img_vgg_input, fmap_vgg_input = map(lambda t: repeat(t, 'b 1 ... -> b c ...', c = 3), (img_vgg_input, fmap_vgg_input))
img_vgg_feats = self.vgg(img_vgg_input)
recon_vgg_feats = self.vgg(fmap_vgg_input)
perceptual_loss = F.mse_loss(img_vgg_feats, recon_vgg_feats)
# generator loss
gen_loss = self.gen_loss(self.discr(fmap))
# calculate adaptive weight
last_dec_layer = self.decoders[-1].weight
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2)
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)
adaptive_weight = safe_div(norm_grad_wrt_perceptual_loss, norm_grad_wrt_gen_loss)
adaptive_weight.clamp_(max = 1e4)
# combine losses
loss = recon_loss + perceptual_loss + commit_loss + adaptive_weight * gen_loss
if return_recons:
return loss, fmap
return loss

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.49',
version = '0.0.21',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
@@ -30,7 +30,6 @@ setup(
'torch>=1.10',
'torchvision',
'tqdm',
'vector-quantize-pytorch',
'x-clip>=0.4.4',
'youtokentome'
],