Compare commits

...

20 Commits

Author SHA1 Message Date
Phil Wang
05b74be69a use null container pattern to cleanup some conditionals, save more cleanup for next week 2022-04-22 15:23:18 -07:00
Phil Wang
a8b5d5d753 last tweak of readme 2022-04-22 14:16:43 -07:00
Phil Wang
976ef7f87c project management 2022-04-22 14:15:42 -07:00
Phil Wang
fd175bcc0e readme 2022-04-22 14:13:33 -07:00
Phil Wang
76b32f18b3 first pass at complete DALL-E2 + Latent Diffusion integration, latent diffusion on any layer(s) of the cascading ddpm in the decoder. 2022-04-22 13:53:13 -07:00
Phil Wang
f2d5b87677 todo 2022-04-22 11:39:58 -07:00
Phil Wang
461347c171 fix vqgan-vae for latent diffusion 2022-04-22 11:38:57 -07:00
Phil Wang
46cef31c86 optional projection out for prior network causal transformer 2022-04-22 11:16:30 -07:00
Phil Wang
59b1a77d4d be a bit more conservative and stick with layernorm (without bias) for now, given @borisdayma results https://twitter.com/borisdayma/status/1517227191477571585 2022-04-22 11:14:54 -07:00
Phil Wang
7f338319fd makes more sense for blur augmentation to happen before the upsampling 2022-04-22 11:10:47 -07:00
Phil Wang
2c6c91829d refactor blurring training augmentation to be taken care of by the decoder, with option to downsample to previous resolution before upsampling (cascading ddpm). this opens up the possibility of cascading latent ddpm 2022-04-22 11:09:17 -07:00
Phil Wang
ad17c69ab6 prepare for latent diffusion in the first DDPM of the cascade in the Decoder 2022-04-21 17:54:31 -07:00
Phil Wang
0b4ec34efb todo 2022-04-20 12:24:23 -07:00
Phil Wang
f027b82e38 remove wip as main networks (prior and decoder) are completed 2022-04-20 12:12:16 -07:00
Phil Wang
8cc9016cb0 Merge pull request #17 from kashif/patch-2
added diffusion-gan thoughts
2022-04-20 12:10:26 -07:00
Kashif Rasul
1d8f37befe added diffusion-gan thoughts
https://github.com/NVlabs/denoising-diffusion-gan
2022-04-20 21:01:11 +02:00
Phil Wang
faebf4c8b8 from my vision transformer experience, dimension of attention head of 32 is sufficient for image feature maps 2022-04-20 11:40:32 -07:00
Phil Wang
b8e8d3c164 thoughts 2022-04-20 11:34:51 -07:00
Phil Wang
8e2416b49b commit to generalizing latent diffusion to one model 2022-04-20 11:27:42 -07:00
Phil Wang
f37c26e856 cleanup and DRY a little 2022-04-20 10:56:32 -07:00
6 changed files with 897 additions and 90 deletions

137
README.md
View File

@@ -1,6 +1,6 @@
<img src="./dalle2.png" width="450px"></img>
## DALL-E 2 - Pytorch (wip)
## DALL-E 2 - Pytorch
Implementation of <a href="https://openai.com/dall-e-2/">DALL-E 2</a>, OpenAI's updated text-to-image synthesis neural network, in Pytorch.
@@ -10,8 +10,6 @@ The main novelty seems to be an extra layer of indirection with the prior networ
This model is SOTA for text-to-image for now.
It may also explore an extension of using <a href="https://huggingface.co/spaces/multimodalart/latentdiffusion">latent diffusion</a> in the decoder from Rombach et al.
Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication
There was enough interest for a Jax version. It will be completed after the Pytorch version shows signs of life on my toy tasks. <a href="https://github.com/lucidrains/dalle2-jax">Placeholder repository</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
@@ -385,7 +383,127 @@ You can also train the decoder on images of greater than the size (say 512x512)
For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.
## CLI Usage (work in progress)
## Experimental
### DALL-E2 with Latent Diffusion
This repository decides to take the next step and offer DALL-E2 combined with <a href="https://huggingface.co/spaces/multimodalart/latentdiffusion">latent diffusion</a>, from Rombach et al.
You can use it as follows. Latent diffusion can be limited to just the first U-Net in the cascade, or to any number you wish.
```python
import torch
from dalle2_pytorch import Unet, Decoder, CLIP, VQGanVAE
# trained clip from step 1
clip = CLIP(
dim_text = 512,
dim_image = 512,
dim_latent = 512,
num_text_tokens = 49408,
text_enc_depth = 1,
text_seq_len = 256,
text_heads = 8,
visual_enc_depth = 1,
visual_image_size = 256,
visual_patch_size = 32,
visual_heads = 8
)
# 3 unets for the decoder (a la cascading DDPM)
# first two unets are doing latent diffusion
# vqgan-vae must be trained before hand
vae1 = VQGanVAE(
dim = 32,
image_size = 256,
layers = 3,
layer_mults = (1, 2, 4)
)
vae2 = VQGanVAE(
dim = 32,
image_size = 512,
layers = 3,
layer_mults = (1, 2, 4)
)
unet1 = Unet(
dim = 32,
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
sparse_attn = True,
sparse_attn_window = 2,
dim_mults = (1, 2, 4, 8)
)
unet2 = Unet(
dim = 32,
image_embed_dim = 512,
channels = 3,
dim_mults = (1, 2, 4, 8, 16),
cond_on_image_embeds = True,
cond_on_text_encodings = False
)
unet3 = Unet(
dim = 32,
image_embed_dim = 512,
channels = 3,
dim_mults = (1, 2, 4, 8, 16),
cond_on_image_embeds = True,
cond_on_text_encodings = False,
attend_at_middle = False
)
# decoder, which contains the unet(s) and clip
decoder = Decoder(
clip = clip,
vae = (vae1, vae2), # latent diffusion for unet1 (vae1) and unet2 (vae2), but not for the last unet3
unet = (unet1, unet2, unet3), # insert unets in order of low resolution to highest resolution (you can have as many stages as you want here)
image_sizes = (256, 512, 1024), # resolutions, 256 for first unet, 512 for second, 1024 for third
timesteps = 100,
cond_drop_prob = 0.2
).cuda()
# mock images (get a lot of this)
images = torch.randn(1, 3, 1024, 1024).cuda()
# feed images into decoder, specifying which unet you want to train
# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme
with decoder.one_unet_in_gpu(1):
loss = decoder(images, unet_number = 1)
loss.backward()
with decoder.one_unet_in_gpu(2):
loss = decoder(images, unet_number = 2)
loss.backward()
with decoder.one_unet_in_gpu(3):
loss = decoder(images, unet_number = 3)
loss.backward()
# do the above for many steps for both unets
# then it will learn to generate images based on the CLIP image embeddings
# chaining the unets from lowest resolution to highest resolution (thus cascading)
mock_image_embed = torch.randn(1, 512).cuda()
images = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)
```
## Training wrapper (wip)
Offer training wrappers
## CLI (wip)
```bash
$ dream 'sharing a sunset at the summit of mount everest with my dog'
@@ -393,9 +511,7 @@ $ dream 'sharing a sunset at the summit of mount everest with my dog'
Once built, images will be saved to the same directory the command is invoked
## Training wrapper (wip)
Offer training wrappers
<a href="https://github.com/lucidrains/big-sleep">template</a>
## Training CLI (wip)
@@ -412,9 +528,14 @@ Offer training wrappers
- [x] add efficient attention in unet
- [x] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning)
- [x] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
- [ ] build out latent diffusion architecture in separate file, as it is not faithful to dalle-2 (but offer it as as setting)
- [x] build out latent diffusion architecture, with the vq-reg variant (vqgan-vae), make it completely optional and compatible with cascading ddpms
- [ ] spend one day cleaning up tech debt in decoder
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
- [ ] train on a toy task, offer in colab
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
- [ ] bring in tools to train vqgan-vae
- [ ] for decoder, allow ability to customize objective (predict epsilon vs x0), in case latent diffusion does better with prediction of x0
## Citations

View File

@@ -1,2 +1,4 @@
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
from dalle2_pytorch.vqgan_vae import VQGanVAE
from x_clip import CLIP

View File

@@ -16,6 +16,7 @@ from einops_exts.torch import EinopsToAndFrom
from kornia.filters import gaussian_blur2d
from dalle2_pytorch.tokenizer import tokenizer
from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE
# use x-clip
@@ -48,6 +49,12 @@ def is_list_str(x):
return False
return all([type(el) == str for el in x])
def pad_tuple_to_length(t, length, fillvalue = None):
remain_length = length - len(t)
if remain_length <= 0:
return t
return (*t, *((fillvalue,) * remain_length))
# for controlling freezing of CLIP
def set_module_requires_grad_(module, requires_grad):
@@ -137,23 +144,27 @@ def sigmoid_beta_schedule(timesteps):
# diffusion prior
class RMSNorm(nn.Module):
class LayerNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.gamma = nn.Parameter(torch.ones(dim))
self.register_buffer("beta", torch.zeros(dim))
def forward(self, x):
return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
class ChanLayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.eps = eps
self.scale = dim ** 0.5
self.gamma = nn.Parameter(torch.ones(dim))
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
def forward(self, x):
squared_sum = (x ** 2).sum(dim = -1, keepdim = True)
inv_norm = torch.rsqrt(squared_sum + self.eps)
return x * inv_norm * self.gamma * self.scale
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g
class ChanRMSNorm(RMSNorm):
def forward(self, x):
squared_sum = (x ** 2).sum(dim = 1, keepdim = True)
inv_norm = torch.rsqrt(squared_sum + self.eps)
return x * inv_norm * rearrange(self.gamma, 'c -> 1 c 1 1') * self.scale
class Residual(nn.Module):
def __init__(self, fn):
@@ -249,10 +260,10 @@ def FeedForward(dim, mult = 4, dropout = 0., post_activation_norm = False):
inner_dim = int(mult * dim)
return nn.Sequential(
RMSNorm(dim),
LayerNorm(dim),
nn.Linear(dim, inner_dim * 2, bias = False),
SwiGLU(),
RMSNorm(inner_dim) if post_activation_norm else nn.Identity(),
LayerNorm(inner_dim) if post_activation_norm else nn.Identity(),
nn.Dropout(dropout),
nn.Linear(inner_dim, dim, bias = False)
)
@@ -275,7 +286,8 @@ class Attention(nn.Module):
inner_dim = dim_head * heads
self.causal = causal
self.norm = RMSNorm(dim)
self.norm = LayerNorm(dim)
self.post_norm = LayerNorm(dim) # sandwich norm from Coqview paper + Normformer
self.dropout = nn.Dropout(dropout)
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
@@ -331,7 +343,8 @@ class Attention(nn.Module):
out = einsum('b h i j, b j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
out = self.to_out(out)
return self.post_norm(out)
class CausalTransformer(nn.Module):
def __init__(
@@ -344,7 +357,8 @@ class CausalTransformer(nn.Module):
ff_mult = 4,
norm_out = False,
attn_dropout = 0.,
ff_dropout = 0.
ff_dropout = 0.,
final_proj = True
):
super().__init__()
self.rel_pos_bias = RelPosBias(heads = heads)
@@ -356,7 +370,8 @@ class CausalTransformer(nn.Module):
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
]))
self.norm = RMSNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity()
def forward(
self,
@@ -371,7 +386,8 @@ class CausalTransformer(nn.Module):
x = attn(x, mask = mask, attn_bias = attn_bias) + x
x = ff(x) + x
return self.norm(x)
out = self.norm(x)
return self.project_out(out)
class DiffusionPriorNetwork(nn.Module):
def __init__(
@@ -464,11 +480,11 @@ class DiffusionPrior(nn.Module):
net,
*,
clip,
timesteps=1000,
cond_drop_prob=0.2,
loss_type="l1",
predict_x0=True,
beta_schedule="cosine",
timesteps = 1000,
cond_drop_prob = 0.2,
loss_type = "l1",
predict_x0 = True,
beta_schedule = "cosine",
):
super().__init__()
assert isinstance(clip, CLIP)
@@ -498,7 +514,7 @@ class DiffusionPrior(nn.Module):
raise NotImplementedError()
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod = torch.cumprod(alphas, axis = 0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
timesteps, = betas.shape
@@ -531,12 +547,14 @@ class DiffusionPrior(nn.Module):
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
@torch.no_grad()
def get_image_embed(self, image):
image_encoding = self.clip.visual_transformer(image)
image_cls = image_encoding[:, 0]
image_embed = self.clip.to_visual_latent(image_cls)
return l2norm(image_embed)
@torch.no_grad()
def get_text_cond(self, text):
text_encodings = self.clip.text_transformer(text)
text_cls, text_encodings = text_encodings[:, 0], text_encodings[:, 1:]
@@ -720,7 +738,7 @@ class ConvNextBlock(nn.Module):
inner_dim = int(dim_out * mult)
self.net = nn.Sequential(
ChanRMSNorm(dim) if norm else nn.Identity(),
ChanLayerNorm(dim) if norm else nn.Identity(),
nn.Conv2d(dim, inner_dim, 3, padding = 1),
nn.GELU(),
nn.Conv2d(inner_dim, dim_out, 3, padding = 1)
@@ -756,8 +774,8 @@ class CrossAttention(nn.Module):
context_dim = default(context_dim, dim)
self.norm = RMSNorm(dim)
self.norm_context = RMSNorm(context_dim)
self.norm = LayerNorm(dim)
self.norm_context = LayerNorm(context_dim)
self.dropout = nn.Dropout(dropout)
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
@@ -825,10 +843,9 @@ class Unet(nn.Module):
out_dim = None,
dim_mults=(1, 2, 4, 8),
channels = 3,
attn_dim_head = 32,
attn_heads = 8,
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
lowres_cond_upsample_mode = 'bilinear',
blur_sigma = 0.1,
blur_kernel_size = 3,
sparse_attn = False,
sparse_attn_window = 8, # window size for sparse attention
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
@@ -845,9 +862,6 @@ class Unet(nn.Module):
# for eventual cascading diffusion
self.lowres_cond = lowres_cond
self.lowres_cond_upsample_mode = lowres_cond_upsample_mode
self.lowres_blur_kernel_size = blur_kernel_size
self.lowres_blur_sigma = blur_sigma
# determine dimensions
@@ -888,6 +902,10 @@ class Unet(nn.Module):
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
self.null_text_embed = nn.Parameter(torch.randn(1, 1, cond_dim))
# attention related params
attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
# layers
self.downs = nn.ModuleList([])
@@ -901,7 +919,7 @@ class Unet(nn.Module):
self.downs.append(nn.ModuleList([
ConvNextBlock(dim_in, dim_out, norm = ind != 0),
Residual(GridAttention(dim_out, window_size = sparse_attn_window)) if sparse_attn else nn.Identity(),
Residual(GridAttention(dim_out, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim),
Downsample(dim_out) if not is_last else nn.Identity()
]))
@@ -909,7 +927,7 @@ class Unet(nn.Module):
mid_dim = dims[-1]
self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim))) if attend_at_middle else None
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
@@ -918,7 +936,7 @@ class Unet(nn.Module):
self.ups.append(nn.ModuleList([
ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim),
Residual(GridAttention(dim_in, window_size = sparse_attn_window)) if sparse_attn else nn.Identity(),
Residual(GridAttention(dim_in, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim),
Upsample(dim_in)
]))
@@ -931,11 +949,16 @@ class Unet(nn.Module):
# if the current settings for the unet are not correct
# for cascading DDPM, then reinit the unet with the right settings
def force_lowres_cond(self, lowres_cond):
if lowres_cond == self.lowres_cond:
def cast_model_parameters(
self,
*,
lowres_cond,
channels
):
if lowres_cond == self.lowres_cond and channels == self.channels:
return self
updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond}
updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond, 'channels': channels}
return self.__class__(**updated_kwargs)
def forward_with_cond_scale(
@@ -971,13 +994,6 @@ class Unet(nn.Module):
assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'
if exists(lowres_cond_img):
if self.training:
# when training, blur the low resolution conditional image
blur_sigma = default(blur_sigma, self.lowres_blur_sigma)
blur_kernel_size = default(blur_kernel_size, self.lowres_blur_kernel_size)
lowres_cond_img = gaussian_blur2d(lowres_cond_img, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
lowres_cond_img = resize_image_to(lowres_cond_img, x.shape[-2:], mode = self.lowres_cond_upsample_mode)
x = torch.cat((x, lowres_cond_img), dim = 1)
# time conditioning
@@ -1054,17 +1070,60 @@ class Unet(nn.Module):
return self.final_conv(x)
class LowresConditioner(nn.Module):
def __init__(
self,
cond_upsample_mode = 'bilinear',
downsample_first = True,
blur_sigma = 0.1,
blur_kernel_size = 3,
):
super().__init__()
self.cond_upsample_mode = cond_upsample_mode
self.downsample_first = downsample_first
self.blur_sigma = blur_sigma
self.blur_kernel_size = blur_kernel_size
def forward(
self,
cond_fmap,
*,
target_image_size,
downsample_image_size = None,
blur_sigma = None,
blur_kernel_size = None
):
target_image_size = cast_tuple(target_image_size, 2)
if self.training and self.downsample_first and exists(downsample_image_size):
cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode)
if self.training:
# when training, blur the low resolution conditional image
blur_sigma = default(blur_sigma, self.blur_sigma)
blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)
cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode)
return cond_fmap
class Decoder(nn.Module):
def __init__(
self,
unet,
*,
clip,
vae = None,
timesteps = 1000,
cond_drop_prob = 0.2,
loss_type = 'l1',
beta_schedule = 'cosine',
image_sizes = None # for cascading ddpm, image size at each stage
image_sizes = None, # for cascading ddpm, image size at each stage
lowres_cond_upsample_mode = 'bilinear', # cascading ddpm - low resolution upsample mode
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
blur_sigma = 0.1, # cascading ddpm - blur sigma
blur_kernel_size = 3, # cascading ddpm - blur kernel size
):
super().__init__()
assert isinstance(clip, CLIP)
@@ -1076,11 +1135,28 @@ class Decoder(nn.Module):
# automatically take care of ensuring that first unet is unconditional
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
unets = cast_tuple(unet)
vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))
self.unets = nn.ModuleList([])
for ind, one_unet in enumerate(cast_tuple(unet)):
self.vaes = nn.ModuleList([])
for ind, (one_unet, one_vae) in enumerate(zip(unets, vaes)):
assert isinstance(one_unet, Unet)
assert isinstance(one_vae, (VQGanVAE, NullVQGanVAE))
is_first = ind == 0
one_unet = one_unet.force_lowres_cond(not is_first)
latent_dim = one_vae.encoded_dim if exists(one_vae) else None
unet_channels = default(latent_dim, self.channels)
one_unet = one_unet.cast_model_parameters(
lowres_cond = not is_first,
channels = unet_channels
)
self.unets.append(one_unet)
self.vaes.append(one_vae.copy_for_eval())
# unet image sizes
@@ -1089,12 +1165,26 @@ class Decoder(nn.Module):
assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
self.image_sizes = image_sizes
self.sample_channels = cast_tuple(self.channels, len(image_sizes))
# cascading ddpm related stuff
lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
assert lowres_conditions == (False, *((True,) * (len(self.unets) - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
self.to_lowres_cond = LowresConditioner(
cond_upsample_mode = lowres_cond_upsample_mode,
downsample_first = lowres_downsample_first,
blur_sigma = blur_sigma,
blur_kernel_size = blur_kernel_size,
)
# classifier free guidance
self.cond_drop_prob = cond_drop_prob
# noise schedule
if beta_schedule == "cosine":
betas = cosine_beta_schedule(timesteps)
elif beta_schedule == "linear":
@@ -1109,7 +1199,7 @@ class Decoder(nn.Module):
raise NotImplementedError()
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod = torch.cumprod(alphas, axis = 0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
timesteps, = betas.shape
@@ -1142,24 +1232,31 @@ class Decoder(nn.Module):
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
@contextmanager
def one_unet_in_gpu(self, unet_number):
def get_unet(self, unet_number):
assert 0 < unet_number <= len(self.unets)
index = unet_number - 1
return self.unets[index]
@contextmanager
def one_unet_in_gpu(self, unet_number = None, unet = None):
assert exists(unet_number) ^ exists(unet)
if exists(unet_number):
unet = self.get_unet(unet_number)
self.cuda()
self.unets.cpu()
unet = self.unets[index]
unet.cuda()
yield
unet.cpu()
self.unets.cpu()
@torch.no_grad()
def get_text_encodings(self, text):
text_encodings = self.clip.text_transformer(text)
return text_encodings[:, 1:]
@torch.no_grad()
def get_image_embed(self, image):
image = resize_image_to(image, self.clip_image_size)
image_encoding = self.clip.visual_transformer(image)
@@ -1216,6 +1313,7 @@ class Decoder(nn.Module):
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
img = self.p_sample(unet, img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
return img
def q_sample(self, x_start, t, noise=None):
@@ -1255,27 +1353,54 @@ class Decoder(nn.Module):
@eval_decorator
def sample(self, image_embed, text = None, cond_scale = 1.):
batch_size = image_embed.shape[0]
channels = self.channels
text_encodings = self.get_text_encodings(text) if exists(text) else None
img = None
for ind, (unet, image_size) in tqdm(enumerate(zip(self.unets, self.image_sizes))):
with self.one_unet_in_gpu(ind + 1):
shape = (batch_size, channels, image_size, image_size)
img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img)
for unet, vae, channel, image_size in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes)):
with self.one_unet_in_gpu(unet = unet):
lowres_cond_img = None
shape = (batch_size, channel, image_size, image_size)
if unet.lowres_cond:
lowres_cond_img = self.to_lowres_cond(img, target_image_size = image_size)
image_size = vae.get_encoded_fmap_size(image_size)
shape = (batch_size, vae.encoded_dim, image_size, image_size)
if exists(lowres_cond_img):
lowres_cond_img = vae.encode(lowres_cond_img)
img = self.p_sample_loop(
unet,
shape,
image_embed = image_embed,
text_encodings = text_encodings,
cond_scale = cond_scale,
lowres_cond_img = lowres_cond_img
)
img = vae.decode(img)
return img
def forward(self, image, text = None, image_embed = None, text_encodings = None, unet_number = None):
def forward(
self,
image,
text = None,
image_embed = None,
text_encodings = None,
unet_number = None
):
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
unet_number = default(unet_number, 1)
assert 1 <= unet_number <= len(self.unets)
unet_index = unet_number - 1
index = unet_number - 1
unet = self.unets[index]
target_image_size = self.image_sizes[index]
unet = self.get_unet(unet_number)
target_image_size = self.image_sizes[unet_index]
vae = self.vaes[unet_index]
b, c, h, w, device, = *image.shape, image.device
@@ -1289,9 +1414,17 @@ class Decoder(nn.Module):
text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None
lowres_cond_img = image if index > 0 else None
ddpm_image = resize_image_to(image, target_image_size)
return self.p_losses(unet, ddpm_image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img)
lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
image = resize_image_to(image, target_image_size)
vae.eval()
with torch.no_grad():
image = vae.encode(image)
if exists(lowres_cond_img):
lowres_cond_img = vae.encode(lowres_cond_img)
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img)
# main class

View File

@@ -1,12 +0,0 @@
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange
class LatentDiffusion(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x

562
dalle2_pytorch/vqgan_vae.py Normal file
View File

@@ -0,0 +1,562 @@
import copy
import math
from math import sqrt
from functools import partial, wraps
from vector_quantize_pytorch import VectorQuantize as VQ
import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.autograd import grad as torch_grad
import torchvision
from einops import rearrange, reduce, repeat
# constants
MList = nn.ModuleList
# helper functions
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
# decorators
def eval_decorator(fn):
def inner(model, *args, **kwargs):
was_training = model.training
model.eval()
out = fn(model, *args, **kwargs)
model.train(was_training)
return out
return inner
def remove_vgg(fn):
@wraps(fn)
def inner(self, *args, **kwargs):
has_vgg = hasattr(self, 'vgg')
if has_vgg:
vgg = self.vgg
delattr(self, 'vgg')
out = fn(self, *args, **kwargs)
if has_vgg:
self.vgg = vgg
return out
return inner
# keyword argument helpers
def pick_and_pop(keys, d):
values = list(map(lambda key: d.pop(key), keys))
return dict(zip(keys, values))
def group_dict_by_key(cond, d):
return_val = [dict(),dict()]
for key in d.keys():
match = bool(cond(key))
ind = int(not match)
return_val[ind][key] = d[key]
return (*return_val,)
def string_begins_with(prefix, str):
return str.startswith(prefix)
def group_by_key_prefix(prefix, d):
return group_dict_by_key(partial(string_begins_with, prefix), d)
def groupby_prefix_and_trim(prefix, d):
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs
# tensor helper functions
def log(t, eps = 1e-10):
return torch.log(t + eps)
def gradient_penalty(images, output, weight = 10):
batch_size = images.shape[0]
gradients = torch_grad(outputs = output, inputs = images,
grad_outputs = torch.ones(output.size(), device = images.device),
create_graph = True, retain_graph = True, only_inputs = True)[0]
gradients = rearrange(gradients, 'b ... -> b (...)')
return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean()
def l2norm(t):
return F.normalize(t, dim = -1)
def leaky_relu(p = 0.1):
return nn.LeakyReLU(0.1)
def stable_softmax(t, dim = -1, alpha = 32 ** 2):
t = t / alpha
t = t - torch.amax(t, dim = dim, keepdim = True).detach()
return (t * alpha).softmax(dim = dim)
def safe_div(numer, denom, eps = 1e-8):
return numer / (denom + eps)
# gan losses
def hinge_discr_loss(fake, real):
return (F.relu(1 + fake) + F.relu(1 - real)).mean()
def hinge_gen_loss(fake):
return -fake.mean()
def bce_discr_loss(fake, real):
return (-log(1 - torch.sigmoid(fake)) - log(torch.sigmoid(real))).mean()
def bce_gen_loss(fake):
return -log(torch.sigmoid(fake)).mean()
def grad_layer_wrt_loss(loss, layer):
return torch_grad(
outputs = loss,
inputs = layer,
grad_outputs = torch.ones_like(loss),
retain_graph = True
)[0].detach()
# vqgan vae
class LayerNormChan(nn.Module):
def __init__(
self,
dim,
eps = 1e-5
):
super().__init__()
self.eps = eps
self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))
def forward(self, x):
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.gamma
class Discriminator(nn.Module):
def __init__(
self,
dims,
channels = 3,
groups = 16,
init_kernel_size = 5
):
super().__init__()
dim_pairs = zip(dims[:-1], dims[1:])
self.layers = MList([nn.Sequential(nn.Conv2d(channels, dims[0], init_kernel_size, padding = init_kernel_size // 2), leaky_relu())])
for dim_in, dim_out in dim_pairs:
self.layers.append(nn.Sequential(
nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1),
nn.GroupNorm(groups, dim_out),
leaky_relu()
))
dim = dims[-1]
self.to_logits = nn.Sequential( # return 5 x 5, for PatchGAN-esque training
nn.Conv2d(dim, dim, 1),
leaky_relu(),
nn.Conv2d(dim, 1, 4)
)
def forward(self, x):
for net in self.layers:
x = net(x)
return self.to_logits(x)
class ContinuousPositionBias(nn.Module):
""" from https://arxiv.org/abs/2111.09883 """
def __init__(self, *, dim, heads, layers = 2):
super().__init__()
self.net = MList([])
self.net.append(nn.Sequential(nn.Linear(2, dim), leaky_relu()))
for _ in range(layers - 1):
self.net.append(nn.Sequential(nn.Linear(dim, dim), leaky_relu()))
self.net.append(nn.Linear(dim, heads))
self.register_buffer('rel_pos', None, persistent = False)
def forward(self, x):
n, device = x.shape[-1], x.device
fmap_size = int(sqrt(n))
if not exists(self.rel_pos):
pos = torch.arange(fmap_size, device = device)
grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
grid = rearrange(grid, 'c i j -> (i j) c')
rel_pos = rearrange(grid, 'i c -> i 1 c') - rearrange(grid, 'j c -> 1 j c')
rel_pos = torch.sign(rel_pos) * torch.log(rel_pos.abs() + 1)
self.register_buffer('rel_pos', rel_pos, persistent = False)
rel_pos = self.rel_pos.float()
for layer in self.net:
rel_pos = layer(rel_pos)
bias = rearrange(rel_pos, 'i j h -> h i j')
return x + bias
class GLUResBlock(nn.Module):
def __init__(self, chan, groups = 16):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(chan, chan * 2, 3, padding = 1),
nn.GLU(dim = 1),
nn.GroupNorm(groups, chan),
nn.Conv2d(chan, chan * 2, 3, padding = 1),
nn.GLU(dim = 1),
nn.GroupNorm(groups, chan),
nn.Conv2d(chan, chan, 1)
)
def forward(self, x):
return self.net(x) + x
class ResBlock(nn.Module):
def __init__(self, chan, groups = 16):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(chan, chan, 3, padding = 1),
nn.GroupNorm(groups, chan),
leaky_relu(),
nn.Conv2d(chan, chan, 3, padding = 1),
nn.GroupNorm(groups, chan),
leaky_relu(),
nn.Conv2d(chan, chan, 1)
)
def forward(self, x):
return self.net(x) + x
class VQGanAttention(nn.Module):
def __init__(
self,
*,
dim,
dim_head = 64,
heads = 8,
dropout = 0.
):
super().__init__()
self.heads = heads
self.scale = dim_head ** -0.5
inner_dim = heads * dim_head
self.dropout = nn.Dropout(dropout)
self.pre_norm = LayerNormChan(dim)
self.cpb = ContinuousPositionBias(dim = dim // 4, heads = heads)
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)
def forward(self, x):
h = self.heads
height, width, residual = *x.shape[-2:], x.clone()
x = self.pre_norm(x)
q, k, v = self.to_qkv(x).chunk(3, dim = 1)
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = h), (q, k, v))
sim = einsum('b h c i, b h c j -> b h i j', q, k) * self.scale
sim = self.cpb(sim)
attn = stable_softmax(sim, dim = -1)
attn = self.dropout(attn)
out = einsum('b h i j, b h c j -> b h c i', attn, v)
out = rearrange(out, 'b h c (x y) -> b (h c) x y', x = height, y = width)
out = self.to_out(out)
return out + residual
class NullVQGanVAE(nn.Module):
def __init__(
self,
*,
channels
):
super().__init__()
self.encoded_dim = channels
self.layers = 0
def get_encoded_fmap_size(self, size):
return size
def copy_for_eval(self):
return self
def encode(self, x):
return x
def decode(self, x):
return x
class VQGanVAE(nn.Module):
def __init__(
self,
*,
dim,
image_size,
channels = 3,
layers = 4,
layer_mults = None,
l2_recon_loss = False,
use_hinge_loss = True,
num_resnet_blocks = 1,
vgg = None,
vq_codebook_size = 512,
vq_decay = 0.8,
vq_commitment_weight = 1.,
vq_kmeans_init = True,
vq_use_cosine_sim = True,
use_attn = True,
attn_dim_head = 64,
attn_heads = 8,
resnet_groups = 16,
attn_dropout = 0.,
first_conv_kernel_size = 5,
use_vgg_and_gan = True,
**kwargs
):
super().__init__()
assert dim % resnet_groups == 0, f'dimension {dim} must be divisible by {resnet_groups} (groups for the groupnorm)'
vq_kwargs, kwargs = groupby_prefix_and_trim('vq_', kwargs)
self.image_size = image_size
self.channels = channels
self.layers = layers
self.fmap_size = image_size // (layers ** 2)
self.codebook_size = vq_codebook_size
self.encoders = MList([])
self.decoders = MList([])
layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'
layer_dims = [dim * mult for mult in layer_mults]
dims = (dim, *layer_dims)
codebook_dim = layer_dims[-1]
self.encoded_dim = dims[-1]
dim_pairs = zip(dims[:-1], dims[1:])
append = lambda arr, t: arr.append(t)
prepend = lambda arr, t: arr.insert(0, t)
if not isinstance(num_resnet_blocks, tuple):
num_resnet_blocks = (*((0,) * (layers - 1)), num_resnet_blocks)
if not isinstance(use_attn, tuple):
use_attn = (*((False,) * (layers - 1)), use_attn)
assert len(num_resnet_blocks) == layers, 'number of resnet blocks config must be equal to number of layers'
assert len(use_attn) == layers
for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_resnet_blocks, use_attn):
append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
prepend(self.decoders, nn.Sequential(nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = False), nn.Conv2d(dim_out, dim_in, 3, padding = 1), leaky_relu()))
if layer_use_attn:
prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
for _ in range(layer_num_resnet_blocks):
append(self.encoders, ResBlock(dim_out, groups = resnet_groups))
prepend(self.decoders, GLUResBlock(dim_out, groups = resnet_groups))
if layer_use_attn:
append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
append(self.decoders, nn.Conv2d(dim, channels, 1))
self.vq = VQ(
dim = codebook_dim,
codebook_size = vq_codebook_size,
decay = vq_decay,
commitment_weight = vq_commitment_weight,
accept_image_fmap = True,
kmeans_init = vq_kmeans_init,
use_cosine_sim = vq_use_cosine_sim,
**vq_kwargs
)
# reconstruction loss
self.recon_loss_fn = F.mse_loss if l2_recon_loss else F.l1_loss
# turn off GAN and perceptual loss if grayscale
self.vgg = None
self.discr = None
self.use_vgg_and_gan = use_vgg_and_gan
if not use_vgg_and_gan:
return
# preceptual loss
if exists(vgg):
self.vgg = vgg
else:
self.vgg = torchvision.models.vgg16(pretrained = True)
self.vgg.classifier = nn.Sequential(*self.vgg.classifier[:-2])
# gan related losses
self.discr = Discriminator(dims = dims, channels = channels)
self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss
def get_encoded_fmap_size(self, image_size):
return image_size // (2 ** self.layers)
def copy_for_eval(self):
device = next(self.parameters()).device
vae_copy = copy.deepcopy(self.cpu())
if vae_copy.use_vgg_and_gan:
del vae_copy.discr
del vae_copy.vgg
vae_copy.eval()
return vae_copy.to(device)
@remove_vgg
def state_dict(self, *args, **kwargs):
return super().state_dict(*args, **kwargs)
@remove_vgg
def load_state_dict(self, *args, **kwargs):
return super().load_state_dict(*args, **kwargs)
@property
def codebook(self):
return self.vq.codebook
def encode(self, fmap):
for enc in self.encoders:
fmap = enc(fmap)
return fmap
def decode(self, fmap, return_indices_and_loss = False):
fmap, indices, commit_loss = self.vq(fmap)
for dec in self.decoders:
fmap = dec(fmap)
if not return_indices_and_loss:
return fmap
return fmap, indices, commit_loss
def forward(
self,
img,
return_loss = False,
return_discr_loss = False,
return_recons = False
):
batch, channels, height, width, device = *img.shape, img.device
assert height == self.image_size and width == self.image_size, 'height and width of input image must be equal to {self.image_size}'
assert channels == self.channels, 'number of channels on image or sketch is not equal to the channels set on this VQGanVAE'
fmap = self.encode(img)
fmap, indices, commit_loss = self.decode(fmap, return_indices_and_loss = True)
if not return_loss and not return_discr_loss:
return fmap
assert return_loss ^ return_discr_loss, 'you should either return autoencoder loss or discriminator loss, but not both'
# whether to return discriminator loss
if return_discr_loss:
assert exists(self.discr), 'discriminator must exist to train it'
fmap.detach_()
img.requires_grad_()
fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))
gp = gradient_penalty(img, img_discr_logits)
discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits)
loss = discr_loss + gp
if return_recons:
return loss, fmap
return loss
# reconstruction loss
recon_loss = self.recon_loss_fn(fmap, img)
# early return if training on grayscale
if not self.use_vgg_and_gan:
if return_recons:
return recon_loss, fmap
return recon_loss
# perceptual loss
img_vgg_input = img
fmap_vgg_input = fmap
if img.shape[1] == 1:
# handle grayscale for vgg
img_vgg_input, fmap_vgg_input = map(lambda t: repeat(t, 'b 1 ... -> b c ...', c = 3), (img_vgg_input, fmap_vgg_input))
img_vgg_feats = self.vgg(img_vgg_input)
recon_vgg_feats = self.vgg(fmap_vgg_input)
perceptual_loss = F.mse_loss(img_vgg_feats, recon_vgg_feats)
# generator loss
gen_loss = self.gen_loss(self.discr(fmap))
# calculate adaptive weight
last_dec_layer = self.decoders[-1].weight
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2)
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)
adaptive_weight = safe_div(norm_grad_wrt_perceptual_loss, norm_grad_wrt_gen_loss)
adaptive_weight.clamp_(max = 1e4)
# combine losses
loss = recon_loss + perceptual_loss + commit_loss + adaptive_weight * gen_loss
if return_recons:
return loss, fmap
return loss

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.28',
version = '0.0.38',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
@@ -30,6 +30,7 @@ setup(
'torch>=1.10',
'torchvision',
'tqdm',
'vector-quantize-pytorch',
'x-clip>=0.4.4',
'youtokentome'
],