Compare commits

...

19 Commits

Author SHA1 Message Date
Phil Wang
59b1a77d4d be a bit more conservative and stick with layernorm (without bias) for now, given @borisdayma results https://twitter.com/borisdayma/status/1517227191477571585 2022-04-22 11:14:54 -07:00
Phil Wang
7f338319fd makes more sense for blur augmentation to happen before the upsampling 2022-04-22 11:10:47 -07:00
Phil Wang
2c6c91829d refactor blurring training augmentation to be taken care of by the decoder, with option to downsample to previous resolution before upsampling (cascading ddpm). this opens up the possibility of cascading latent ddpm 2022-04-22 11:09:17 -07:00
Phil Wang
ad17c69ab6 prepare for latent diffusion in the first DDPM of the cascade in the Decoder 2022-04-21 17:54:31 -07:00
Phil Wang
0b4ec34efb todo 2022-04-20 12:24:23 -07:00
Phil Wang
f027b82e38 remove wip as main networks (prior and decoder) are completed 2022-04-20 12:12:16 -07:00
Phil Wang
8cc9016cb0 Merge pull request #17 from kashif/patch-2
added diffusion-gan thoughts
2022-04-20 12:10:26 -07:00
Kashif Rasul
1d8f37befe added diffusion-gan thoughts
https://github.com/NVlabs/denoising-diffusion-gan
2022-04-20 21:01:11 +02:00
Phil Wang
faebf4c8b8 from my vision transformer experience, dimension of attention head of 32 is sufficient for image feature maps 2022-04-20 11:40:32 -07:00
Phil Wang
b8e8d3c164 thoughts 2022-04-20 11:34:51 -07:00
Phil Wang
8e2416b49b commit to generalizing latent diffusion to one model 2022-04-20 11:27:42 -07:00
Phil Wang
f37c26e856 cleanup and DRY a little 2022-04-20 10:56:32 -07:00
Phil Wang
27a33e1b20 complete contextmanager method for keeping only one unet in GPU during training or inference 2022-04-20 10:46:13 -07:00
Phil Wang
6f941a219a give time tokens a surface area of 2 tokens as default, make it so researcher can customize which unet actually is conditioned on image embeddings and/or text encodings 2022-04-20 10:04:47 -07:00
Phil Wang
ddde8ca1bf fix cosine beta schedule, thanks to @Zhengxinyang 2022-04-19 20:54:28 -07:00
Phil Wang
c26b77ad20 todo 2022-04-19 13:07:32 -07:00
Phil Wang
c5b4aab8e5 intent 2022-04-19 11:00:05 -07:00
Phil Wang
a35c309b5f add sparse attention layers in between convnext blocks in unet (grid like attention, used in mobilevit, maxvit [bytedance ai], as well as a growing number of attention-based GANs) 2022-04-19 09:49:03 -07:00
Phil Wang
55bdcb98b9 scaffold for latent diffusion 2022-04-19 09:26:58 -07:00
6 changed files with 744 additions and 72 deletions

View File

@@ -1,6 +1,6 @@
<img src="./dalle2.png" width="450px"></img>
## DALL-E 2 - Pytorch (wip)
## DALL-E 2 - Pytorch
Implementation of <a href="https://openai.com/dall-e-2/">DALL-E 2</a>, OpenAI's updated text-to-image synthesis neural network, in Pytorch.
@@ -14,7 +14,7 @@ It may also explore an extension of using <a href="https://huggingface.co/spaces
Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication
There was enough interest for a Jax version. It will be completed after the Pytorch version shows signs of life on my toy tasks. <a href="https://github.com/lucidrains/dalle2-jax">Placeholder repository</a>
There was enough interest for a Jax version. It will be completed after the Pytorch version shows signs of life on my toy tasks. <a href="https://github.com/lucidrains/dalle2-jax">Placeholder repository</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
## Install
@@ -409,10 +409,14 @@ Offer training wrappers
- [x] augment unet so that it can also be conditioned on text encodings (although in paper they hinted this didn't make much of a difference)
- [x] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)
- [x] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions
- [ ] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
- [ ] build out latent diffusion architecture in separate file, as it is not faithful to dalle-2 (but offer it as a setting)
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, add efficient attention (conditional on resolution), port all learnings over to https://github.com/lucidrains/x-unet
- [x] add efficient attention in unet
- [x] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning)
- [x] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
- [ ] build out latent diffusion architecture, with the vq-reg variant (vqgan-vae), make it completely optional
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
- [ ] train on a toy task, offer in colab
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
## Citations
@@ -461,4 +465,12 @@ Offer training wrappers
}
```
```bibtex
@inproceedings{Tu2022MaxViTMV,
title = {MaxViT: Multi-Axis Vision Transformer},
author = {Zhengzhong Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
year = {2022}
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's <a href="https://arxiv.org/abs/2011.13456">paper</a>

View File

@@ -1,2 +1,4 @@
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
from dalle2_pytorch.vqgan_vae import VQGanVAE
from x_clip import CLIP

View File

@@ -6,4 +6,4 @@ def main():
@click.command()
@click.argument('text')
def dream(text):
return image
return 'not ready yet'

View File

@@ -2,6 +2,7 @@ import math
from tqdm import tqdm
from inspect import isfunction
from functools import partial
from contextlib import contextmanager
import torch
import torch.nn.functional as F
@@ -105,8 +106,8 @@ def cosine_beta_schedule(timesteps, s = 0.008):
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
"""
steps = timesteps + 1
x = torch.linspace(0, steps, steps)
alphas_cumprod = torch.cos(((x / steps) + s) / (1 + s) * torch.pi * 0.5) ** 2
x = torch.linspace(0, timesteps, steps)
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return torch.clip(betas, 0, 0.999)
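For reference, the fix above corrects the grid the cosine is evaluated on: the old code spaced points over [0, steps] (with steps = timesteps + 1) while dividing by steps, skewing the schedule; the new code spaces steps points over [0, timesteps] and divides by timesteps. A minimal standalone sketch of the corrected schedule:

```python
import torch

def cosine_beta_schedule(timesteps, s = 0.008):
    # timesteps + 1 points over [0, timesteps] yield exactly `timesteps` betas
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)

betas = cosine_beta_schedule(1000)
assert betas.shape == (1000,) and (betas > 0).all()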
@@ -136,23 +137,27 @@ def sigmoid_beta_schedule(timesteps):
# diffusion prior
class RMSNorm(nn.Module):
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.eps = eps
self.scale = dim ** 0.5
self.gamma = nn.Parameter(torch.ones(dim))
def forward(self, x):
squared_sum = (x ** 2).sum(dim = -1, keepdim = True)
inv_norm = torch.rsqrt(squared_sum + self.eps)
return x * inv_norm * self.gamma * self.scale
class ChanRMSNorm(RMSNorm):
def forward(self, x):
squared_sum = (x ** 2).sum(dim = 1, keepdim = True)
inv_norm = torch.rsqrt(squared_sum + self.eps)
return x * inv_norm * rearrange(self.gamma, 'c -> 1 c 1 1') * self.scale
class LayerNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.gamma = nn.Parameter(torch.ones(dim))
self.register_buffer("beta", torch.zeros(dim))
def forward(self, x):
return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
class ChanLayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.eps = eps
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
def forward(self, x):
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g
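Per commit 59b1a77 above, RMSNorm is swapped for a layer norm without bias: gamma is learned while beta is a frozen zero buffer, so at initialization it matches a stock nn.LayerNorm. A quick standalone check:

```python
import torch
import torch.nn.functional as F
from torch import nn

class LayerNorm(nn.Module):
    # bias-free layer norm: gamma is learned, beta is a frozen zero buffer
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer('beta', torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

x = torch.randn(2, 16, 64)
assert torch.allclose(LayerNorm(64)(x), nn.LayerNorm(64)(x), atol = 1e-6)  # identical at init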
class Residual(nn.Module):
def __init__(self, fn):
@@ -248,10 +253,10 @@ def FeedForward(dim, mult = 4, dropout = 0., post_activation_norm = False):
inner_dim = int(mult * dim)
return nn.Sequential(
RMSNorm(dim),
LayerNorm(dim),
nn.Linear(dim, inner_dim * 2, bias = False),
SwiGLU(),
RMSNorm(inner_dim) if post_activation_norm else nn.Identity(),
LayerNorm(inner_dim) if post_activation_norm else nn.Identity(),
nn.Dropout(dropout),
nn.Linear(inner_dim, dim, bias = False)
)
@@ -274,7 +279,8 @@ class Attention(nn.Module):
inner_dim = dim_head * heads
self.causal = causal
self.norm = RMSNorm(dim)
self.norm = LayerNorm(dim)
self.post_norm = LayerNorm(dim) # sandwich norm from CogView paper + NormFormer
self.dropout = nn.Dropout(dropout)
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
@@ -330,7 +336,8 @@ class Attention(nn.Module):
out = einsum('b h i j, b j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
out = self.to_out(out)
return self.post_norm(out)
class CausalTransformer(nn.Module):
def __init__(
@@ -355,7 +362,7 @@ class CausalTransformer(nn.Module):
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
]))
self.norm = RMSNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
def forward(
self,
@@ -463,11 +470,11 @@ class DiffusionPrior(nn.Module):
net,
*,
clip,
timesteps=1000,
cond_drop_prob=0.2,
loss_type="l1",
predict_x0=True,
beta_schedule="cosine",
timesteps = 1000,
cond_drop_prob = 0.2,
loss_type = "l1",
predict_x0 = True,
beta_schedule = "cosine",
):
super().__init__()
assert isinstance(clip, CLIP)
@@ -497,7 +504,7 @@ class DiffusionPrior(nn.Module):
raise NotImplementedError()
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod = torch.cumprod(alphas, axis = 0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
timesteps, = betas.shape
@@ -719,7 +726,7 @@ class ConvNextBlock(nn.Module):
inner_dim = int(dim_out * mult)
self.net = nn.Sequential(
ChanRMSNorm(dim) if norm else nn.Identity(),
ChanLayerNorm(dim) if norm else nn.Identity(),
nn.Conv2d(dim, inner_dim, 3, padding = 1),
nn.GELU(),
nn.Conv2d(inner_dim, dim_out, 3, padding = 1)
@@ -755,8 +762,8 @@ class CrossAttention(nn.Module):
context_dim = default(context_dim, dim)
self.norm = RMSNorm(dim)
self.norm_context = RMSNorm(context_dim)
self.norm = LayerNorm(dim)
self.norm_context = LayerNorm(context_dim)
self.dropout = nn.Dropout(dropout)
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
@@ -798,6 +805,20 @@ class CrossAttention(nn.Module):
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class GridAttention(nn.Module):
def __init__(self, *args, window_size = 8, **kwargs):
super().__init__()
self.window_size = window_size
self.attn = Attention(*args, **kwargs)
def forward(self, x):
h, w = x.shape[-2:]
wsz = self.window_size
x = rearrange(x, 'b c (w1 h) (w2 w) -> (b h w) (w1 w2) c', w1 = wsz, w2 = wsz)
out = self.attn(x)
out = rearrange(out, '(b h w) (w1 w2) c -> b c (w1 h) (w2 w)', w1 = wsz, w2 = wsz, h = h // wsz, w = w // wsz)
return out
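The rearrange above implements MaxViT-style grid attention: each window gathers positions sampled at a stride across the whole feature map rather than from one local patch, so attention is global but sparse. A standalone sanity check of the reshape (assumes height and width divisible by the window size):

```python
import torch
from einops import rearrange

wsz = 8
fmap = torch.randn(2, 64, 32, 32)  # (batch, channels, height, width)

# group strided positions into windows of wsz * wsz tokens
tokens = rearrange(fmap, 'b c (w1 h) (w2 w) -> (b h w) (w1 w2) c', w1 = wsz, w2 = wsz)
assert tokens.shape == (2 * 4 * 4, wsz * wsz, 64)

# the inverse reshape restores the original feature map losslessly
restored = rearrange(tokens, '(b h w) (w1 w2) c -> b c (w1 h) (w2 w)', w1 = wsz, w2 = wsz, h = 32 // wsz, w = 32 // wsz)
assert torch.equal(fmap, restored)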
class Unet(nn.Module):
def __init__(
self,
@@ -806,14 +827,18 @@ class Unet(nn.Module):
image_embed_dim,
cond_dim = None,
num_image_tokens = 4,
num_time_tokens = 2,
out_dim = None,
dim_mults=(1, 2, 4, 8),
channels = 3,
attn_dim_head = 32,
attn_heads = 8,
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
lowres_cond_upsample_mode = 'bilinear',
blur_sigma = 0.1,
blur_kernel_size = 3,
sparse_attn = False,
sparse_attn_window = 8, # window size for sparse attention
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
cond_on_text_encodings = False,
cond_on_image_embeds = False,
):
super().__init__()
# save locals to take care of some hyperparameters for cascading DDPM
@@ -825,9 +850,6 @@ class Unet(nn.Module):
# for eventual cascading diffusion
self.lowres_cond = lowres_cond
self.lowres_cond_upsample_mode = lowres_cond_upsample_mode
self.lowres_blur_kernel_size = blur_kernel_size
self.lowres_blur_sigma = blur_sigma
# determine dimensions
@@ -846,8 +868,8 @@ class Unet(nn.Module):
SinusoidalPosEmb(dim),
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Linear(dim * 4, cond_dim),
Rearrange('b d -> b 1 d')
nn.Linear(dim * 4, cond_dim * num_time_tokens),
Rearrange('b (r d) -> b r d', r = num_time_tokens)
)
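Per commit 6f941a2 above, time conditioning now has a surface area of num_time_tokens tokens (default 2) instead of one. A standalone shape sketch, using a random vector as a stand-in for the SinusoidalPosEmb output:

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

dim, cond_dim, num_time_tokens = 64, 128, 2

to_time_tokens = nn.Sequential(
    nn.Linear(dim, dim * 4),
    nn.GELU(),
    nn.Linear(dim * 4, cond_dim * num_time_tokens),
    Rearrange('b (r d) -> b r d', r = num_time_tokens)  # split into r time tokens
)

time_emb = torch.randn(3, dim)  # stand-in for the sinusoidal position embedding
assert to_time_tokens(time_emb).shape == (3, num_time_tokens, cond_dim)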
self.image_to_cond = nn.Sequential(
@@ -857,11 +879,21 @@ class Unet(nn.Module):
self.text_to_cond = nn.LazyLinear(cond_dim)
# finer control over whether to condition on image embeddings and text encodings
# so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
self.cond_on_text_encodings = cond_on_text_encodings
self.cond_on_image_embeds = cond_on_image_embeds
# for classifier free guidance
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
self.null_text_embed = nn.Parameter(torch.randn(1, 1, cond_dim))
# attention related params
attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
# layers
self.downs = nn.ModuleList([])
@@ -875,6 +907,7 @@ class Unet(nn.Module):
self.downs.append(nn.ModuleList([
ConvNextBlock(dim_in, dim_out, norm = ind != 0),
Residual(GridAttention(dim_out, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim),
Downsample(dim_out) if not is_last else nn.Identity()
]))
@@ -882,7 +915,7 @@ class Unet(nn.Module):
mid_dim = dims[-1]
self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim))) if attend_at_middle else None
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
@@ -891,6 +924,7 @@ class Unet(nn.Module):
self.ups.append(nn.ModuleList([
ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim),
Residual(GridAttention(dim_in, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim),
Upsample(dim_in)
]))
@@ -943,13 +977,6 @@ class Unet(nn.Module):
assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'
if exists(lowres_cond_img):
if self.training:
# when training, blur the low resolution conditional image
blur_sigma = default(blur_sigma, self.lowres_blur_sigma)
blur_kernel_size = default(blur_kernel_size, self.lowres_blur_kernel_size)
lowres_cond_img = gaussian_blur2d(lowres_cond_img, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
lowres_cond_img = resize_image_to(lowres_cond_img, x.shape[-2:], mode = self.lowres_cond_upsample_mode)
x = torch.cat((x, lowres_cond_img), dim = 1)
# time conditioning
@@ -964,17 +991,22 @@ class Unet(nn.Module):
# mask out image embedding depending on condition dropout
# for classifier free guidance
image_tokens = self.image_to_cond(image_embed)
image_tokens = None
image_tokens = torch.where(
cond_prob_mask,
image_tokens,
self.null_image_embed
)
if self.cond_on_image_embeds:
image_tokens = self.image_to_cond(image_embed)
image_tokens = torch.where(
cond_prob_mask,
image_tokens,
self.null_image_embed
)
# take care of text encodings (optional)
if exists(text_encodings):
text_tokens = None
if exists(text_encodings) and self.cond_on_text_encodings:
text_tokens = self.text_to_cond(text_encodings)
text_tokens = torch.where(
cond_prob_mask,
@@ -984,19 +1016,23 @@ class Unet(nn.Module):
# main conditioning tokens (c)
c = torch.cat((time_tokens, image_tokens), dim = -2)
c = time_tokens
if exists(image_tokens):
c = torch.cat((c, image_tokens), dim = -2)
# text and image conditioning tokens (mid_c)
# to save on compute, only do cross attention based conditioning on the inner most layers of the Unet
mid_c = c if not exists(text_encodings) else torch.cat((c, text_tokens), dim = -2)
mid_c = c if not exists(text_tokens) else torch.cat((c, text_tokens), dim = -2)
# go through the layers of the unet, down and up
hiddens = []
for convnext, convnext2, downsample in self.downs:
for convnext, sparse_attn, convnext2, downsample in self.downs:
x = convnext(x, c)
x = sparse_attn(x)
x = convnext2(x, c)
hiddens.append(x)
x = downsample(x)
@@ -1008,14 +1044,53 @@ class Unet(nn.Module):
x = self.mid_block2(x, mid_c)
for convnext, convnext2, upsample in self.ups:
for convnext, sparse_attn, convnext2, upsample in self.ups:
x = torch.cat((x, hiddens.pop()), dim=1)
x = convnext(x, c)
x = sparse_attn(x)
x = convnext2(x, c)
x = upsample(x)
return self.final_conv(x)
class LowresConditioner(nn.Module):
def __init__(
self,
cond_upsample_mode = 'bilinear',
downsample_first = True,
blur_sigma = 0.1,
blur_kernel_size = 3,
):
super().__init__()
self.cond_upsample_mode = cond_upsample_mode
self.downsample_first = downsample_first
self.blur_sigma = blur_sigma
self.blur_kernel_size = blur_kernel_size
def forward(
self,
cond_fmap,
*,
target_image_size,
downsample_image_size = None,
blur_sigma = None,
blur_kernel_size = None
):
target_image_size = cast_tuple(target_image_size, 2)
if self.training and self.downsample_first and exists(downsample_image_size):
cond_fmap = resize_image_to(cond_fmap, downsample_image_size, mode = self.cond_upsample_mode)
if self.training:
# when training, blur the low resolution conditional image
blur_sigma = default(blur_sigma, self.blur_sigma)
blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)
cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode)
return cond_fmap
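Functionally, the training path above is: optionally resize down to the previous stage's resolution, blur, then upsample to the current stage's resolution. A standalone mimic, assuming the repo's gaussian_blur2d comes from kornia and with plain F.interpolate standing in for resize_image_to:

```python
import torch
import torch.nn.functional as F
from kornia.filters import gaussian_blur2d  # assumed source of the blur helper

image = torch.randn(1, 3, 256, 256)                          # ground truth at the current stage
lowres = F.interpolate(image, size = 64, mode = 'bilinear')  # down to the previous stage's resolution
lowres = gaussian_blur2d(lowres, (3, 3), (0.1, 0.1))         # blur augmentation (training only)
cond = F.interpolate(lowres, size = 256, mode = 'bilinear')  # back up to the current resolution
assert cond.shape == image.shape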
class Decoder(nn.Module):
def __init__(
self,
@@ -1026,7 +1101,11 @@ class Decoder(nn.Module):
cond_drop_prob = 0.2,
loss_type = 'l1',
beta_schedule = 'cosine',
image_sizes = None # for cascading ddpm, image size at each stage
image_sizes = None, # for cascading ddpm, image size at each stage
lowres_cond_upsample_mode = 'bilinear', # cascading ddpm - low resolution upsample mode
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
blur_sigma = 0.1, # cascading ddpm - blur sigma
blur_kernel_size = 3, # cascading ddpm - blur kernel size
):
super().__init__()
assert isinstance(clip, CLIP)
@@ -1051,12 +1130,26 @@ class Decoder(nn.Module):
assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
self.image_sizes = image_sizes
self.sample_channels = cast_tuple(self.channels, len(image_sizes))
# cascading ddpm related stuff
lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
assert lowres_conditions == (False, *((True,) * (len(self.unets) - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
self.to_lowres_cond = LowresConditioner(
cond_upsample_mode = lowres_cond_upsample_mode,
downsample_first = lowres_downsample_first,
blur_sigma = blur_sigma,
blur_kernel_size = blur_kernel_size,
)
# classifier free guidance
self.cond_drop_prob = cond_drop_prob
# noise schedule
if beta_schedule == "cosine":
betas = cosine_beta_schedule(timesteps)
elif beta_schedule == "linear":
@@ -1071,7 +1164,7 @@ class Decoder(nn.Module):
raise NotImplementedError()
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod = torch.cumprod(alphas, axis = 0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
timesteps, = betas.shape
@@ -1104,6 +1197,25 @@ class Decoder(nn.Module):
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
def get_unet(self, unet_number):
assert 0 < unet_number <= len(self.unets)
index = unet_number - 1
return self.unets[index]
@contextmanager
def one_unet_in_gpu(self, unet_number = None, unet = None):
assert exists(unet_number) ^ exists(unet)
if exists(unet_number):
unet = self.get_unet(unet_number)
self.cuda()
self.unets.cpu()
unet.cuda()
yield
unet.cpu()
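A minimal standalone sketch of the offloading pattern above (toy module with illustrative names, not from the repo; requires a CUDA device to actually move weights):

```python
import torch
from torch import nn
from contextlib import contextmanager

class ToyCascade(nn.Module):
    def __init__(self):
        super().__init__()
        self.unets = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])

    @contextmanager
    def one_unet_in_gpu(self, unet_number):
        unet = self.unets[unet_number - 1]
        self.unets.cpu()   # offload every unet to CPU ...
        unet.cuda()        # ... then move only the active one to the GPU
        yield unet
        unet.cpu()         # move it back on exit

model = ToyCascade()
if torch.cuda.is_available():
    with model.one_unet_in_gpu(unet_number = 2) as unet:
        out = unet(torch.randn(1, 8).cuda())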
def get_text_encodings(self, text):
text_encodings = self.clip.text_transformer(text)
return text_encodings[:, 1:]
@@ -1164,6 +1276,7 @@ class Decoder(nn.Module):
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
img = self.p_sample(unet, img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
return img
def q_sample(self, x_start, t, noise=None):
@@ -1203,25 +1316,37 @@ class Decoder(nn.Module):
@eval_decorator
def sample(self, image_embed, text = None, cond_scale = 1.):
batch_size = image_embed.shape[0]
channels = self.channels
text_encodings = self.get_text_encodings(text) if exists(text) else None
img = None
for unet, image_size in tqdm(zip(self.unets, self.image_sizes)):
shape = (batch_size, channels, image_size, image_size)
img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img)
for unet, channel, image_size in tqdm(zip(self.unets, self.sample_channels, self.image_sizes)):
with self.one_unet_in_gpu(unet = unet):
lowres_cond_img = self.to_lowres_cond(
img,
target_image_size = image_size
) if unet.lowres_cond else None
img = self.p_sample_loop(
unet,
(batch_size, channel, image_size, image_size),
image_embed = image_embed,
text_encodings = text_encodings,
cond_scale = cond_scale,
lowres_cond_img = lowres_cond_img
)
return img
def forward(self, image, text = None, image_embed = None, text_encodings = None, unet_number = None):
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
unet_number = default(unet_number, 1)
assert 1 <= unet_number <= len(self.unets)
index = unet_number - 1
unet = self.unets[index]
target_image_size = self.image_sizes[index]
unet_index = unet_number - 1
unet = self.get_unet(unet_number)
target_image_size = self.image_sizes[unet_index]
b, c, h, w, device, = *image.shape, image.device
@@ -1235,7 +1360,7 @@ class Decoder(nn.Module):
text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None
lowres_cond_img = image if index > 0 else None
lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
ddpm_image = resize_image_to(image, target_image_size)
return self.p_losses(unet, ddpm_image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img)
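Putting the decoder-side changes together, a hedged usage fragment (the unet1, unet2, clip, images and image_embed objects are assumed to be constructed as in the repo's README, so this is a sketch rather than a runnable script):

```python
# two-stage cascading decoder; the first unet is unconditioned on a low
# resolution image, the rest must have lowres_cond = True (asserted above)
decoder = Decoder(unet = (unet1, unet2), clip = clip, timesteps = 1000, image_sizes = (64, 256))

# each unet of the cascade is trained separately, selected by unet_number (1-indexed)
for unet_number in (1, 2):
    loss = decoder(images, unet_number = unet_number)
    loss.backward()

# cascaded sampling; only one unet is kept on the GPU at a time
sampled = decoder.sample(image_embed, cond_scale = 2.)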

dalle2_pytorch/vqgan_vae.py Normal file (532 additions)
View File

@@ -0,0 +1,532 @@
import copy
import math
from math import sqrt
from functools import partial, wraps
from vector_quantize_pytorch import VectorQuantize as VQ
import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.autograd import grad as torch_grad
import torchvision
from einops import rearrange, reduce, repeat
# constants
MList = nn.ModuleList
# helper functions
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
# decorators
def eval_decorator(fn):
def inner(model, *args, **kwargs):
was_training = model.training
model.eval()
out = fn(model, *args, **kwargs)
model.train(was_training)
return out
return inner
def remove_vgg(fn):
@wraps(fn)
def inner(self, *args, **kwargs):
has_vgg = hasattr(self, 'vgg')
if has_vgg:
vgg = self.vgg
delattr(self, 'vgg')
out = fn(self, *args, **kwargs)
if has_vgg:
self.vgg = vgg
return out
return inner
# keyword argument helpers
def pick_and_pop(keys, d):
values = list(map(lambda key: d.pop(key), keys))
return dict(zip(keys, values))
def group_dict_by_key(cond, d):
return_val = [dict(),dict()]
for key in d.keys():
match = bool(cond(key))
ind = int(not match)
return_val[ind][key] = d[key]
return (*return_val,)
def string_begins_with(prefix, str):
return str.startswith(prefix)
def group_by_key_prefix(prefix, d):
return group_dict_by_key(partial(string_begins_with, prefix), d)
def groupby_prefix_and_trim(prefix, d):
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs
# tensor helper functions
def log(t, eps = 1e-10):
return torch.log(t + eps)
def gradient_penalty(images, output, weight = 10):
batch_size = images.shape[0]
gradients = torch_grad(outputs = output, inputs = images,
grad_outputs = torch.ones(output.size(), device = images.device),
create_graph = True, retain_graph = True, only_inputs = True)[0]
gradients = rearrange(gradients, 'b ... -> b (...)')
return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean()
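A quick sanity check for the helper above: it penalizes the squared deviation of the discriminator's input-gradient norm from 1 (WGAN-GP style), later applied to real images in the discriminator step. Sketch reusing the function and imports above, with a stand-in for the discriminator output:

```python
images = torch.randn(4, 3, 32, 32, requires_grad = True)
output = (images ** 2).mean(dim = (1, 2, 3))  # stand-in discriminator logits, differentiable w.r.t. the input
gp = gradient_penalty(images, output)
assert gp.ndim == 0 and gp.item() >= 0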
def l2norm(t):
return F.normalize(t, dim = -1)
def leaky_relu(p = 0.1):
return nn.LeakyReLU(p)
def stable_softmax(t, dim = -1, alpha = 32 ** 2):
t = t / alpha
t = t - torch.amax(t, dim = dim, keepdim = True).detach()
return (t * alpha).softmax(dim = dim)
def safe_div(numer, denom, eps = 1e-8):
return numer / (denom + eps)
# gan losses
def hinge_discr_loss(fake, real):
return (F.relu(1 + fake) + F.relu(1 - real)).mean()
def hinge_gen_loss(fake):
return -fake.mean()
def bce_discr_loss(fake, real):
return (-log(1 - torch.sigmoid(fake)) - log(torch.sigmoid(real))).mean()
def bce_gen_loss(fake):
return -log(torch.sigmoid(fake)).mean()
def grad_layer_wrt_loss(loss, layer):
return torch_grad(
outputs = loss,
inputs = layer,
grad_outputs = torch.ones_like(loss),
retain_graph = True
)[0].detach()
# vqgan vae
class LayerNormChan(nn.Module):
def __init__(
self,
dim,
eps = 1e-5
):
super().__init__()
self.eps = eps
self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))
def forward(self, x):
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.gamma
class Discriminator(nn.Module):
def __init__(
self,
dims,
channels = 3,
groups = 16,
init_kernel_size = 5
):
super().__init__()
dim_pairs = zip(dims[:-1], dims[1:])
self.layers = MList([nn.Sequential(nn.Conv2d(channels, dims[0], init_kernel_size, padding = init_kernel_size // 2), leaky_relu())])
for dim_in, dim_out in dim_pairs:
self.layers.append(nn.Sequential(
nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1),
nn.GroupNorm(groups, dim_out),
leaky_relu()
))
dim = dims[-1]
self.to_logits = nn.Sequential( # return 5 x 5, for PatchGAN-esque training
nn.Conv2d(dim, dim, 1),
leaky_relu(),
nn.Conv2d(dim, 1, 4)
)
def forward(self, x):
for net in self.layers:
x = net(x)
return self.to_logits(x)
class ContinuousPositionBias(nn.Module):
""" from https://arxiv.org/abs/2111.09883 """
def __init__(self, *, dim, heads, layers = 2):
super().__init__()
self.net = MList([])
self.net.append(nn.Sequential(nn.Linear(2, dim), leaky_relu()))
for _ in range(layers - 1):
self.net.append(nn.Sequential(nn.Linear(dim, dim), leaky_relu()))
self.net.append(nn.Linear(dim, heads))
self.register_buffer('rel_pos', None, persistent = False)
def forward(self, x):
n, device = x.shape[-1], x.device
fmap_size = int(sqrt(n))
if not exists(self.rel_pos):
pos = torch.arange(fmap_size, device = device)
grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
grid = rearrange(grid, 'c i j -> (i j) c')
rel_pos = rearrange(grid, 'i c -> i 1 c') - rearrange(grid, 'j c -> 1 j c')
rel_pos = torch.sign(rel_pos) * torch.log(rel_pos.abs() + 1)
self.register_buffer('rel_pos', rel_pos, persistent = False)
rel_pos = self.rel_pos.float()
for layer in self.net:
rel_pos = layer(rel_pos)
bias = rearrange(rel_pos, 'i j h -> h i j')
return x + bias
class GLUResBlock(nn.Module):
def __init__(self, chan, groups = 16):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(chan, chan * 2, 3, padding = 1),
nn.GLU(dim = 1),
nn.GroupNorm(groups, chan),
nn.Conv2d(chan, chan * 2, 3, padding = 1),
nn.GLU(dim = 1),
nn.GroupNorm(groups, chan),
nn.Conv2d(chan, chan, 1)
)
def forward(self, x):
return self.net(x) + x
class ResBlock(nn.Module):
def __init__(self, chan, groups = 16):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(chan, chan, 3, padding = 1),
nn.GroupNorm(groups, chan),
leaky_relu(),
nn.Conv2d(chan, chan, 3, padding = 1),
nn.GroupNorm(groups, chan),
leaky_relu(),
nn.Conv2d(chan, chan, 1)
)
def forward(self, x):
return self.net(x) + x
class VQGanAttention(nn.Module):
def __init__(
self,
*,
dim,
dim_head = 64,
heads = 8,
dropout = 0.
):
super().__init__()
self.heads = heads
self.scale = dim_head ** -0.5
inner_dim = heads * dim_head
self.dropout = nn.Dropout(dropout)
self.pre_norm = LayerNormChan(dim)
self.cpb = ContinuousPositionBias(dim = dim // 4, heads = heads)
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)
def forward(self, x):
h = self.heads
height, width, residual = *x.shape[-2:], x.clone()
x = self.pre_norm(x)
q, k, v = self.to_qkv(x).chunk(3, dim = 1)
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = h), (q, k, v))
sim = einsum('b h c i, b h c j -> b h i j', q, k) * self.scale
sim = self.cpb(sim)
attn = stable_softmax(sim, dim = -1)
attn = self.dropout(attn)
out = einsum('b h i j, b h c j -> b h c i', attn, v)
out = rearrange(out, 'b h c (x y) -> b (h c) x y', x = height, y = width)
out = self.to_out(out)
return out + residual
class VQGanVAE(nn.Module):
def __init__(
self,
*,
dim,
image_size,
channels = 3,
num_layers = 4,
layer_mults = None,
l2_recon_loss = False,
use_hinge_loss = True,
num_resnet_blocks = 1,
vgg = None,
vq_codebook_size = 512,
vq_decay = 0.8,
vq_commitment_weight = 1.,
vq_kmeans_init = True,
vq_use_cosine_sim = True,
use_attn = True,
attn_dim_head = 64,
attn_heads = 8,
resnet_groups = 16,
attn_dropout = 0.,
first_conv_kernel_size = 5,
use_vgg_and_gan = True,
**kwargs
):
super().__init__()
assert dim % resnet_groups == 0, f'dimension {dim} must be divisible by {resnet_groups} (groups for the groupnorm)'
vq_kwargs, kwargs = groupby_prefix_and_trim('vq_', kwargs)
self.image_size = image_size
self.channels = channels
self.num_layers = num_layers
self.fmap_size = image_size // (2 ** num_layers)
self.codebook_size = vq_codebook_size
self.encoders = MList([])
self.decoders = MList([])
layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(num_layers))))
assert len(layer_mults) == num_layers, 'layer multipliers must be equal to designated number of layers'
layer_dims = [dim * mult for mult in layer_mults]
dims = (dim, *layer_dims)
codebook_dim = layer_dims[-1]
dim_pairs = zip(dims[:-1], dims[1:])
append = lambda arr, t: arr.append(t)
prepend = lambda arr, t: arr.insert(0, t)
if not isinstance(num_resnet_blocks, tuple):
num_resnet_blocks = (*((0,) * (num_layers - 1)), num_resnet_blocks)
if not isinstance(use_attn, tuple):
use_attn = (*((False,) * (num_layers - 1)), use_attn)
assert len(num_resnet_blocks) == num_layers, 'number of resnet blocks config must be equal to number of layers'
assert len(use_attn) == num_layers
for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(num_layers), dim_pairs, num_resnet_blocks, use_attn):
append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
prepend(self.decoders, nn.Sequential(nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = False), nn.Conv2d(dim_out, dim_in, 3, padding = 1), leaky_relu()))
if layer_use_attn:
prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
for _ in range(layer_num_resnet_blocks):
append(self.encoders, ResBlock(dim_out, groups = resnet_groups))
prepend(self.decoders, GLUResBlock(dim_out, groups = resnet_groups))
if layer_use_attn:
append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
append(self.decoders, nn.Conv2d(dim, channels, 1))
self.vq = VQ(
dim = codebook_dim,
codebook_size = vq_codebook_size,
decay = vq_decay,
commitment_weight = vq_commitment_weight,
accept_image_fmap = True,
kmeans_init = vq_kmeans_init,
use_cosine_sim = vq_use_cosine_sim,
**vq_kwargs
)
# reconstruction loss
self.recon_loss_fn = F.mse_loss if l2_recon_loss else F.l1_loss
# turn off GAN and perceptual loss if grayscale
self.vgg = None
self.discr = None
self.use_vgg_and_gan = use_vgg_and_gan
if not use_vgg_and_gan:
return
# perceptual loss
if exists(vgg):
self.vgg = vgg
else:
self.vgg = torchvision.models.vgg16(pretrained = True)
self.vgg.classifier = nn.Sequential(*self.vgg.classifier[:-2])
# gan related losses
self.discr = Discriminator(dims = dims, channels = channels)
self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss
def copy_for_eval(self):
device = next(self.parameters()).device
vae_copy = copy.deepcopy(self.cpu())
if vae_copy.use_vgg_and_gan:
del vae_copy.discr
del vae_copy.vgg
vae_copy.eval()
return vae_copy.to(device)
@remove_vgg
def state_dict(self, *args, **kwargs):
return super().state_dict(*args, **kwargs)
@remove_vgg
def load_state_dict(self, *args, **kwargs):
return super().load_state_dict(*args, **kwargs)
@property
def codebook(self):
return self.vq.codebook
def encode(self, fmap):
for enc in self.encoders:
fmap = enc(fmap)
fmap, indices, commit_loss = self.vq(fmap)
return fmap, indices, commit_loss
def decode(self, fmap):
for dec in self.decoders:
fmap = dec(fmap)
return fmap
def forward(
self,
img,
return_loss = False,
return_discr_loss = False,
return_recons = False
):
batch, channels, height, width, device = *img.shape, img.device
assert height == self.image_size and width == self.image_size, f'height and width of input image must be equal to {self.image_size}'
assert channels == self.channels, 'number of channels on image or sketch is not equal to the channels set on this VQGanVAE'
fmap, indices, commit_loss = self.encode(img)
fmap = self.decode(fmap)
if not return_loss and not return_discr_loss:
return fmap
assert return_loss ^ return_discr_loss, 'you should either return autoencoder loss or discriminator loss, but not both'
# whether to return discriminator loss
if return_discr_loss:
assert exists(self.discr), 'discriminator must exist to train it'
fmap.detach_()
img.requires_grad_()
fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))
gp = gradient_penalty(img, img_discr_logits)
discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits)
loss = discr_loss + gp
if return_recons:
return loss, fmap
return loss
# reconstruction loss
recon_loss = self.recon_loss_fn(fmap, img)
# early return if training on grayscale
if not self.use_vgg_and_gan:
if return_recons:
return recon_loss, fmap
return recon_loss
# perceptual loss
img_vgg_input = img
fmap_vgg_input = fmap
if img.shape[1] == 1:
# handle grayscale for vgg
img_vgg_input, fmap_vgg_input = map(lambda t: repeat(t, 'b 1 ... -> b c ...', c = 3), (img_vgg_input, fmap_vgg_input))
img_vgg_feats = self.vgg(img_vgg_input)
recon_vgg_feats = self.vgg(fmap_vgg_input)
perceptual_loss = F.mse_loss(img_vgg_feats, recon_vgg_feats)
# generator loss
gen_loss = self.gen_loss(self.discr(fmap))
# calculate adaptive weight
last_dec_layer = self.decoders[-1].weight
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2)
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)
adaptive_weight = safe_div(norm_grad_wrt_perceptual_loss, norm_grad_wrt_gen_loss)
adaptive_weight.clamp_(max = 1e4)
# combine losses
loss = recon_loss + perceptual_loss + commit_loss + adaptive_weight * gen_loss
if return_recons:
return loss, fmap
return loss
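A hedged usage sketch for the new module (exported via the __init__.py change above; constructor values here are illustrative, and the VGG-16 weights download on first use):

```python
import torch
from dalle2_pytorch import VQGanVAE

vae = VQGanVAE(dim = 32, image_size = 256, num_layers = 3)
images = torch.randn(2, 3, 256, 256)

loss = vae(images, return_loss = True)              # generator step: recon + perceptual + commit + adaptive gan loss
discr_loss = vae(images, return_discr_loss = True)  # discriminator step: hinge loss + gradient penalty
recons = vae(images)                                # plain reconstruction, shape (2, 3, 256, 256)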

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.24',
version = '0.0.34',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
@@ -30,6 +30,7 @@ setup(
'torch>=1.10',
'torchvision',
'tqdm',
'vector-quantize-pytorch',
'x-clip>=0.4.4',
'youtokentome'
],