Compare commits


1 Commit

| Author | SHA1 | Message | Date |
|--------|------|---------|------|
| Phil Wang | a95de92964 | fix bug with @jihoonerd | 2022-04-15 06:40:46 -07:00 |
6 changed files with 116 additions and 1025 deletions

133
README.md

@@ -1,10 +1,8 @@
<img src="./dalle2.png" width="450px"></img>
## DALL-E 2 - Pytorch
## DALL-E 2 - Pytorch (wip)
Implementation of <a href="https://openai.com/dall-e-2/">DALL-E 2</a>, OpenAI's updated text-to-image synthesis neural network, in Pytorch.
<a href="https://youtu.be/RJwPN4qNi_Y?t=555">Yannic Kilcher summary</a> | <a href="https://www.youtube.com/watch?v=F1X4fHzF4mQ">AssemblyAI explainer</a>
Implementation of <a href="https://openai.com/dall-e-2/">DALL-E 2</a>, OpenAI's updated text-to-image synthesis neural network, in Pytorch. <a href="https://youtu.be/RJwPN4qNi_Y?t=555">Yannic Kilcher summary</a>
The main novelty seems to be an extra layer of indirection with the prior network (whether it is an autoregressive transformer or a diffusion network), which predicts an image embedding based on the text embedding from CLIP. Specifically, this repository will only build out the diffusion prior network, as it is the best performing variant (but which incidentally involves a causal transformer as the denoising network 😂)
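In other words, generation is a two stage pipeline: CLIP text embedding -> prior -> image embedding -> decoder -> image. Here is a minimal sketch of that indirection using random toy modules (nothing below is the repository's API; the module names and dimensions are purely illustrative):

```python
import torch
from torch import nn

# toy stand-ins for the three trained components (dimensions are illustrative)
clip_text_encoder = nn.Linear(256, 512)      # encoded text -> CLIP text embedding
prior             = nn.Linear(512, 512)      # text embedding -> predicted image embedding
decoder           = nn.Sequential(           # image embedding -> image
    nn.Linear(512, 3 * 64 * 64),
    nn.Unflatten(1, (3, 64, 64))
)

text_features = torch.randn(1, 256)               # stand-in for encoded text
text_embed    = clip_text_encoder(text_features)  # CLIP text embedding
image_embed   = prior(text_embed)                 # the extra layer of indirection
image         = decoder(image_embed)              # (1, 3, 64, 64) generated image
```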
@@ -14,7 +12,9 @@ It may also explore an extension of using <a href="https://huggingface.co/spaces
Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication
There was enough interest for a Jax version. It will be completed after the Pytorch version shows signs of life on my toy tasks. <a href="https://github.com/lucidrains/dalle2-jax">Placeholder repository</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
Do let me know if anyone is interested in a Jax version https://github.com/lucidrains/DALLE2-pytorch/discussions/8
For all of you emailing me (there is a lot), the best way to contribute is through pull requests. Everything is open sourced after all. All my thoughts are public. This is your moment to participate.
## Install
@@ -109,7 +109,7 @@ unet = Unet(
# decoder, which contains the unet and clip
decoder = Decoder(
unet = unet,
net = unet,
clip = clip,
timesteps = 100,
cond_drop_prob = 0.2
@@ -182,82 +182,7 @@ loss.backward()
# now the diffusion prior can generate image embeddings from the text embeddings
```
In the paper, they actually used a <a href="https://cascaded-diffusion.github.io/">recently discovered technique</a>, from <a href="http://www.jonathanho.me/">Jonathan Ho</a> himself (original author of DDPMs, the core technique used in DALL-E v2), for high resolution image synthesis.
This can easily be used within this framework like so
```python
import torch
from dalle2_pytorch import Unet, Decoder, CLIP
# trained clip from step 1
clip = CLIP(
dim_text = 512,
dim_image = 512,
dim_latent = 512,
num_text_tokens = 49408,
text_enc_depth = 6,
text_seq_len = 256,
text_heads = 8,
visual_enc_depth = 6,
visual_image_size = 256,
visual_patch_size = 32,
visual_heads = 8
).cuda()
# 2 unets for the decoder (a la cascading DDPM)
unet1 = Unet(
dim = 32,
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults = (1, 2, 4, 8)
).cuda()
unet2 = Unet(
dim = 32,
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults = (1, 2, 4, 8, 16)
).cuda()
# decoder, which contains the unet(s) and clip
decoder = Decoder(
clip = clip,
unet = (unet1, unet2), # insert both unets in order of low resolution to highest resolution (you can have as many stages as you want here)
image_sizes = (256, 512), # resolutions, 256 for first unet, 512 for second. these must be unique and in ascending order (matches with the unets passed in)
timesteps = 1000,
cond_drop_prob = 0.2
).cuda()
# mock images (get a lot of this)
images = torch.randn(4, 3, 512, 512).cuda()
# feed images into decoder, specifying which unet you want to train
# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme
loss = decoder(images, unet_number = 1)
loss.backward()
loss = decoder(images, unet_number = 2)
loss.backward()
# do the above for many steps for both unets
# then it will learn to generate images based on the CLIP image embeddings
# chaining the unets from lowest resolution to highest resolution (thus cascading)
mock_image_embed = torch.randn(1, 512).cuda()
images = decoder.sample(mock_image_embed) # (1, 3, 512, 512)
```
Finally, to generate the DALL-E 2 images from text, insert the trained `DiffusionPrior` as well as the `Decoder` (which wraps `CLIP`, the causal transformer, and unet(s))
Finally, to generate the DALL-E 2 images from text, insert the trained `DiffusionPrior` as well as the `Decoder` (which contains `CLIP`, a unet, and a causal transformer)
```python
from dalle2_pytorch import DALLE2
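# (the rest of this example is cut off by the diff; the lines below are only a sketch of how it
#  likely continues, inferred from the `dalle2(...)` call shown further down.
#  the keyword argument names and the prompt are assumptions, not taken from this diff)

dalle2 = DALLE2(
    prior = diffusion_prior,  # the trained DiffusionPrior from above
    decoder = decoder         # the trained Decoder from above
)

images = dalle2(
    ['a shiba inu wearing a beret'],  # any text prompt
    cond_scale = 2.                   # classifier free guidance strength
)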
@@ -336,7 +261,7 @@ loss.backward()
# decoder (with unet)
unet1 = Unet(
unet = Unet(
dim = 128,
image_embed_dim = 512,
cond_dim = 128,
@@ -344,25 +269,15 @@ unet1 = Unet(
dim_mults=(1, 2, 4, 8)
).cuda()
unet2 = Unet(
dim = 16,
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults = (1, 2, 4, 8, 16)
).cuda()
decoder = Decoder(
unet = (unet1, unet2),
image_sizes = (128, 256),
net = unet,
clip = clip,
timesteps = 100,
cond_drop_prob = 0.2
).cuda()
for unet_number in (1, 2):
loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
loss.backward()
loss = decoder(images) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
loss.backward()
# do above for many steps
@@ -376,13 +291,11 @@ images = dalle2(
cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)
)
# save your image (in this example, of size 256x256)
# save your image
```
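For reference, `cond_scale` is the classifier free guidance strength: the unet is evaluated both with the conditioning kept and with it dropped (replaced by the learned null embeddings), and the two predictions are blended. The `forward_with_cond_scale` body is not shown in this diff, but the blending rule is roughly the standard one:

```python
import torch

def guided_pred(cond_pred, null_pred, cond_scale = 2.):
    # cond_scale = 1. recovers the purely conditional prediction
    # values > 1. extrapolate further away from the unconditional prediction
    return null_pred + (cond_pred - null_pred) * cond_scale

cond_pred = torch.randn(1, 3, 256, 256)  # unet output with conditioning kept
null_pred = torch.randn(1, 3, 256, 256)  # unet output with conditioning dropped
out = guided_pred(cond_pred, null_pred, cond_scale = 2.)
```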
Everything in this readme should run without error
You can also train the decoder on images larger (say 512x512) than the size at which CLIP was trained (256x256). The images will be resized down to the CLIP image resolution when computing the image embeddings.
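That resizing step is just a bilinear interpolation of the copy that is fed to CLIP (the repository's `resize_image_to` helper wraps `F.interpolate` the same way). A standalone sketch of the idea:

```python
import torch
import torch.nn.functional as F

clip_image_size = 256
images = torch.randn(4, 3, 512, 512)  # the decoder itself trains on the full 512x512 images

# only the copy used for the CLIP image embedding is downsized
images_for_clip = F.interpolate(images, size = (clip_image_size, clip_image_size), mode = 'bilinear')
print(images_for_clip.shape)  # torch.Size([4, 3, 256, 256])
```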
For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.
## CLI Usage (work in progress)
@@ -407,16 +320,12 @@ Offer training wrappers
- [x] add what was proposed in the paper, where DDPM objective for image latent embedding predicts x0 directly (reread vq-diffusion paper and get caught up on that line of work)
- [x] make sure it works end to end to produce an output tensor, taking a single gradient step
- [x] augment unet so that it can also be conditioned on text encodings (although the paper hinted this didn't make much of a difference)
- [x] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)
- [x] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions
- [x] add efficient attention in unet
- [x] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning)
- [x] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
- [ ] build out latent diffusion architecture, with the vq-reg variant (vqgan-vae), make it completely optional
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
- [ ] look into Jonathan Ho's cascading DDPM for the decoder, as that seems to be what they are using. get caught up on DDPM literature
- [ ] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)
- [ ] train on a toy task, offer in colab
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
- [ ] add attention to unet - apply some personal tricks with efficient attention
- [ ] figure out the big idea behind latent diffusion and what can be ported over
- [ ] consider U2-net for decoder https://arxiv.org/abs/2005.09007
## Citations
@@ -465,12 +374,4 @@ Offer training wrappers
}
```
```bibtex
@inproceedings{Tu2022MaxViTMV,
title = {MaxViT: Multi-Axis Vision Transformer},
author = {Zhe-Wei Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
year = {2022}
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's <a href="https://arxiv.org/abs/2011.13456">paper</a>

dalle2_pytorch/__init__.py

@@ -1,4 +1,2 @@
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
from dalle2_pytorch.vqgan_vae import VQGanVAE
from x_clip import CLIP

dalle2_pytorch/cli.py

@@ -6,4 +6,4 @@ def main():
@click.command()
@click.argument('text')
def dream(text):
return 'not ready yet'
return image

dalle2_pytorch/dalle2_pytorch.py

@@ -1,8 +1,6 @@
import math
from tqdm import tqdm
from inspect import isfunction
from functools import partial
from contextlib import contextmanager
import torch
import torch.nn.functional as F
@@ -13,7 +11,7 @@ from einops.layers.torch import Rearrange
from einops_exts import rearrange_many, repeat_many, check_shape
from einops_exts.torch import EinopsToAndFrom
from kornia.filters import gaussian_blur2d
from kornia.filters import filter2d
from dalle2_pytorch.tokenizer import tokenizer
@@ -31,9 +29,6 @@ def default(val, d):
return val
return d() if isfunction(d) else d
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
def eval_decorator(fn):
def inner(model, *args, **kwargs):
was_training = model.training
@@ -69,15 +64,6 @@ def freeze_model_and_make_eval_(model):
def l2norm(t):
return F.normalize(t, dim = -1)
def resize_image_to(t, image_size, mode = 'bilinear'): # take a look at https://github.com/assafshocher/ResizeRight
shape = cast_tuple(image_size, 2)
orig_image_size = t.shape[-2:]
if orig_image_size == shape:
return t
return F.interpolate(t, size = shape, mode = mode)
# classifier free guidance functions
def prob_mask_like(shape, prob, device):
@@ -106,35 +92,12 @@ def cosine_beta_schedule(timesteps, s = 0.008):
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
"""
steps = timesteps + 1
x = torch.linspace(0, timesteps, steps)
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
x = torch.linspace(0, steps, steps)
alphas_cumprod = torch.cos(((x / steps) + s) / (1 + s) * torch.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return torch.clip(betas, 0, 0.999)
def linear_beta_schedule(timesteps):
scale = 1000 / timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
return torch.linspace(beta_start, beta_end, timesteps)
def quadratic_beta_schedule(timesteps):
scale = 1000 / timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
return torch.linspace(beta_start**2, beta_end**2, timesteps) ** 2
def sigmoid_beta_schedule(timesteps):
scale = 1000 / timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
betas = torch.linspace(-6, 6, timesteps)
return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
# diffusion prior
class RMSNorm(nn.Module):
@@ -466,9 +429,8 @@ class DiffusionPrior(nn.Module):
clip,
timesteps = 1000,
cond_drop_prob = 0.2,
loss_type = "l1",
predict_x0 = True,
beta_schedule = "cosine",
loss_type = 'l1',
predict_x0 = True
):
super().__init__()
assert isinstance(clip, CLIP)
@@ -484,21 +446,10 @@ class DiffusionPrior(nn.Module):
self.predict_x0 = predict_x0
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
if beta_schedule == "cosine":
betas = cosine_beta_schedule(timesteps)
elif beta_schedule == "linear":
betas = linear_beta_schedule(timesteps)
elif beta_schedule == "quadratic":
betas = quadratic_beta_schedule(timesteps)
elif beta_schedule == "jsd":
betas = 1.0 / torch.linspace(timesteps, 1, timesteps)
elif beta_schedule == "sigmoid":
betas = sigmoid_beta_schedule(timesteps)
else:
raise NotImplementedError()
betas = cosine_beta_schedule(timesteps)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis = 0)
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
timesteps, = betas.shape
@@ -599,6 +550,31 @@ class DiffusionPrior(nn.Module):
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), text_cond = text_cond)
return img
@torch.no_grad()
def sample(self, text, num_samples_per_batch = 2):
# in the paper, what they did was
# sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP
text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch)
batch_size = text.shape[0]
image_embed_dim = self.image_embed_dim
text_cond = self.get_text_cond(text)
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
text_embeds = text_cond['text_embed']
text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
image_embeds = rearrange(image_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
text_image_sims = einsum('b r d, b r d -> b r', l2norm(text_embeds), l2norm(image_embeds))
top_sim_indices = text_image_sims.topk(k = 1).indices
top_sim_indices = repeat(top_sim_indices, 'b 1 -> b 1 d', d = image_embed_dim)
top_image_embeds = image_embeds.gather(1, top_sim_indices)
return rearrange(top_image_embeds, 'b 1 d -> b d')
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
@@ -625,39 +601,11 @@ class DiffusionPrior(nn.Module):
loss = F.l1_loss(to_predict, x_recon)
elif self.loss_type == 'l2':
loss = F.mse_loss(to_predict, x_recon)
elif self.loss_type == "huber":
loss = F.smooth_l1_loss(to_predict, x_recon)
else:
raise NotImplementedError()
return loss
@torch.no_grad()
@eval_decorator
def sample(self, text, num_samples_per_batch = 2):
# in the paper, what they did was
# sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP
text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch)
batch_size = text.shape[0]
image_embed_dim = self.image_embed_dim
text_cond = self.get_text_cond(text)
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
text_embeds = text_cond['text_embed']
text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
image_embeds = rearrange(image_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
text_image_sims = einsum('b r d, b r d -> b r', l2norm(text_embeds), l2norm(image_embeds))
top_sim_indices = text_image_sims.topk(k = 1).indices
top_sim_indices = repeat(top_sim_indices, 'b 1 -> b 1 d', d = image_embed_dim)
top_image_embeds = image_embeds.gather(1, top_sim_indices)
return rearrange(top_image_embeds, 'b 1 d -> b d')
def forward(self, text, image, *args, **kwargs):
b, device, img_size, = image.shape[0], image.device, self.image_size
check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)
@@ -677,6 +625,17 @@ def Upsample(dim):
def Downsample(dim):
return nn.Conv2d(dim, dim, 4, 2, 1)
class Blur(nn.Module):
def __init__(self):
super().__init__()
filt = torch.Tensor([1, 2, 1])
self.register_buffer('filt', filt)
def forward(self, x):
filt = self.filt
filt = rearrange(filt, '... j -> ... 1 j') * rearrange(filt, '... i -> ... i 1') # outer product of [1, 2, 1] gives the 3x3 binomial blur kernel
return filter2d(x, filt, normalized = True)
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
@@ -799,20 +758,6 @@ class CrossAttention(nn.Module):
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class GridAttention(nn.Module):
def __init__(self, *args, window_size = 8, **kwargs):
super().__init__()
self.window_size = window_size
self.attn = Attention(*args, **kwargs)
def forward(self, x):
h, w = x.shape[-2:]
wsz = self.window_size
x = rearrange(x, 'b c (w1 h) (w2 w) -> (b h w) (w1 w2) c', w1 = wsz, w2 = wsz)
out = self.attn(x)
out = rearrange(out, '(b h w) (w1 w2) c -> b c (w1 h) (w2 w)', w1 = wsz, w2 = wsz, h = h // wsz, w = w // wsz)
return out
class Unet(nn.Module):
def __init__(
self,
@@ -821,37 +766,14 @@ class Unet(nn.Module):
image_embed_dim,
cond_dim = None,
num_image_tokens = 4,
num_time_tokens = 2,
out_dim = None,
dim_mults=(1, 2, 4, 8),
channels = 3,
attn_dim_head = 32,
attn_heads = 8,
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
sparse_attn = False,
sparse_attn_window = 8, # window size for sparse attention
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
cond_on_text_encodings = False,
cond_on_image_embeds = False,
):
super().__init__()
# save locals to take care of some hyperparameters for cascading DDPM
self._locals = locals()
del self._locals['self']
del self._locals['__class__']
# for eventual cascading diffusion
self.lowres_cond = lowres_cond
# determine dimensions
self.channels = channels
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
dims = [init_channels, *map(lambda m: dim * m, dim_mults)]
dims = [channels, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
# time, image embeddings, and optional text encoding
@@ -862,8 +784,8 @@ class Unet(nn.Module):
SinusoidalPosEmb(dim),
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Linear(dim * 4, cond_dim * num_time_tokens),
Rearrange('b (r d) -> b r d', r = num_time_tokens)
nn.Linear(dim * 4, cond_dim),
Rearrange('b d -> b 1 d')
)
self.image_to_cond = nn.Sequential(
@@ -873,21 +795,11 @@ class Unet(nn.Module):
self.text_to_cond = nn.LazyLinear(cond_dim)
# finer control over whether to condition on image embeddings and text encodings
# so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
self.cond_on_text_encodings = cond_on_text_encodings
self.cond_on_image_embeds = cond_on_image_embeds
# for classifier free guidance
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
self.null_text_embed = nn.Parameter(torch.randn(1, 1, cond_dim))
# attention related params
attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
# layers
self.downs = nn.ModuleList([])
@@ -895,32 +807,27 @@ class Unet(nn.Module):
num_resolutions = len(in_out)
for ind, (dim_in, dim_out) in enumerate(in_out):
is_first = ind == 0
is_last = ind >= (num_resolutions - 1)
layer_cond_dim = cond_dim if not is_first else None
self.downs.append(nn.ModuleList([
ConvNextBlock(dim_in, dim_out, norm = ind != 0),
Residual(GridAttention(dim_out, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim),
ConvNextBlock(dim_out, dim_out, cond_dim = cond_dim),
Downsample(dim_out) if not is_last else nn.Identity()
]))
mid_dim = dims[-1]
self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim)))
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (num_resolutions - 2)
layer_cond_dim = cond_dim if not is_last else None
is_last = ind >= (num_resolutions - 1)
self.ups.append(nn.ModuleList([
ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim),
Residual(GridAttention(dim_in, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim),
Upsample(dim_in)
ConvNextBlock(dim_out * 2, dim_in, cond_dim = cond_dim),
ConvNextBlock(dim_in, dim_in, cond_dim = cond_dim),
Upsample(dim_in) if not is_last else nn.Identity()
]))
out_dim = default(out_dim, channels)
@@ -929,15 +836,6 @@ class Unet(nn.Module):
nn.Conv2d(dim, out_dim, 1)
)
# if the current settings for the unet are not correct
# for cascading DDPM, then reinit the unet with the right settings
def force_lowres_cond(self, lowres_cond):
if lowres_cond == self.lowres_cond:
return self
updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond}
return self.__class__(**updated_kwargs)
def forward_with_cond_scale(
self,
*args,
@@ -958,49 +856,29 @@ class Unet(nn.Module):
time,
*,
image_embed,
lowres_cond_img = None,
text_encodings = None,
cond_drop_prob = 0.,
blur_sigma = None,
blur_kernel_size = None
cond_drop_prob = 0.
):
batch_size, device = x.shape[0], x.device
# add low resolution conditioning, if present
assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'
if exists(lowres_cond_img):
x = torch.cat((x, lowres_cond_img), dim = 1)
# time conditioning
time_tokens = self.time_mlp(time)
# conditional dropout
cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device)
cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1 1')
# mask out image embedding depending on condition dropout
# for classifier free guidance
image_tokens = None
image_tokens = self.image_to_cond(image_embed)
if self.cond_on_image_embeds:
image_tokens = self.image_to_cond(image_embed)
image_tokens = torch.where(
cond_prob_mask,
image_tokens,
self.null_image_embed
)
image_tokens = torch.where(
cond_prob_mask,
image_tokens,
self.null_image_embed
)
# take care of text encodings (optional)
text_tokens = None
if exists(text_encodings) and self.cond_on_text_encodings:
if exists(text_encodings):
text_tokens = self.text_to_cond(text_encodings)
text_tokens = torch.where(
cond_prob_mask,
@@ -1010,155 +888,59 @@ class Unet(nn.Module):
# main conditioning tokens (c)
c = time_tokens
if exists(image_tokens):
c = torch.cat((c, image_tokens), dim = -2)
c = torch.cat((time_tokens, image_tokens), dim = -2)
# text and image conditioning tokens (mid_c)
# to save on compute, only do cross attention based conditioning on the inner most layers of the Unet
mid_c = c if not exists(text_tokens) else torch.cat((c, text_tokens), dim = -2)
mid_c = c if not exists(text_encodings) else torch.cat((c, text_tokens), dim = -2)
# go through the layers of the unet, down and up
hiddens = []
for convnext, sparse_attn, convnext2, downsample in self.downs:
for convnext, convnext2, downsample in self.downs:
x = convnext(x, c)
x = sparse_attn(x)
x = convnext2(x, c)
hiddens.append(x)
x = downsample(x)
x = self.mid_block1(x, mid_c)
if exists(self.mid_attn):
x = self.mid_attn(x)
x = self.mid_attn(x)
x = self.mid_block2(x, mid_c)
for convnext, sparse_attn, convnext2, upsample in self.ups:
for convnext, convnext2, upsample in self.ups:
x = torch.cat((x, hiddens.pop()), dim=1)
x = convnext(x, c)
x = sparse_attn(x)
x = convnext2(x, c)
x = upsample(x)
return self.final_conv(x)
class LowresConditioner(nn.Module):
def __init__(
self,
cond_upsample_mode = 'bilinear',
downsample_first = True,
blur_sigma = 0.1,
blur_kernel_size = 3,
):
super().__init__()
self.cond_upsample_mode = cond_upsample_mode
self.downsample_first = downsample_first
self.blur_sigma = blur_sigma
self.blur_kernel_size = blur_kernel_size
def forward(
self,
cond_fmap,
*,
target_image_size,
downsample_image_size = None,
blur_sigma = None,
blur_kernel_size = None
):
target_image_size = cast_tuple(target_image_size, 2)
if self.training and self.downsample_first and exists(downsample_image_size):
cond_fmap = resize_image_to(cond_fmap, downsample_image_size, mode = self.cond_upsample_mode)
if self.training:
# when training, blur the low resolution conditional image
blur_sigma = default(blur_sigma, self.blur_sigma)
blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)
cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode)
return cond_fmap
class Decoder(nn.Module):
def __init__(
self,
unet,
net,
*,
clip,
timesteps = 1000,
cond_drop_prob = 0.2,
loss_type = 'l1',
beta_schedule = 'cosine',
image_sizes = None, # for cascading ddpm, image size at each stage
lowres_cond_upsample_mode = 'bilinear', # cascading ddpm - low resolution upsample mode
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
blur_sigma = 0.1, # cascading ddpm - blur sigma
blur_kernel_size = 3, # cascading ddpm - blur kernel size
loss_type = 'l1'
):
super().__init__()
assert isinstance(clip, CLIP)
freeze_model_and_make_eval_(clip)
self.clip = clip
self.clip_image_size = clip.image_size
self.net = net
self.channels = clip.image_channels
# automatically take care of ensuring that first unet is unconditional
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
self.unets = nn.ModuleList([])
for ind, one_unet in enumerate(cast_tuple(unet)):
is_first = ind == 0
one_unet = one_unet.force_lowres_cond(not is_first)
self.unets.append(one_unet)
# unet image sizes
image_sizes = default(image_sizes, (clip.image_size,))
image_sizes = tuple(sorted(set(image_sizes)))
assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
self.image_sizes = image_sizes
self.sample_channels = cast_tuple(self.channels, len(image_sizes))
# cascading ddpm related stuff
lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
assert lowres_conditions == (False, *((True,) * (len(self.unets) - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
self.to_lowres_cond = LowresConditioner(
cond_upsample_mode = lowres_cond_upsample_mode,
downsample_first = lowres_downsample_first,
blur_sigma = blur_sigma,
blur_kernel_size = blur_kernel_size,
)
# classifier free guidance
self.image_size = clip.image_size
self.cond_drop_prob = cond_drop_prob
# noise schedule
if beta_schedule == "cosine":
betas = cosine_beta_schedule(timesteps)
elif beta_schedule == "linear":
betas = linear_beta_schedule(timesteps)
elif beta_schedule == "quadratic":
betas = quadratic_beta_schedule(timesteps)
elif beta_schedule == "jsd":
betas = 1.0 / torch.linspace(timesteps, 1, timesteps)
elif beta_schedule == "sigmoid":
betas = sigmoid_beta_schedule(timesteps)
else:
raise NotImplementedError()
betas = cosine_beta_schedule(timesteps)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis = 0)
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
timesteps, = betas.shape
@@ -1191,31 +973,11 @@ class Decoder(nn.Module):
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
def get_unet(self, unet_number):
assert 0 < unet_number <= len(self.unets)
index = unet_number - 1
return self.unets[index]
@contextmanager
def one_unet_in_gpu(self, unet_number = None, unet = None):
assert exists(unet_number) ^ exists(unet)
if exists(unet_number):
unet = self.get_unet(unet_number)
self.cuda()
self.unets.cpu()
unet.cuda()
yield
unet.cpu()
def get_text_encodings(self, text):
text_encodings = self.clip.text_transformer(text)
return text_encodings[:, 1:]
def get_image_embed(self, image):
image = resize_image_to(image, self.clip_image_size)
image_encoding = self.clip.visual_transformer(image)
image_cls = image_encoding[:, 0]
image_embed = self.clip.to_visual_latent(image_cls)
@@ -1242,9 +1004,8 @@ class Decoder(nn.Module):
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, cond_scale = 1.):
pred_noise = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
x_recon = self.predict_start_from_noise(x, t = t, noise = pred_noise)
def p_mean_variance(self, x, t, image_embed, text_encodings = None, clip_denoised = True, cond_scale = 1.):
x_recon = self.predict_start_from_noise(x, t = t, noise = self.net.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale))
if clip_denoised:
x_recon.clamp_(-1., 1.)
@@ -1253,26 +1014,33 @@ class Decoder(nn.Module):
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, unet, x, t, image_embed, text_encodings = None, cond_scale = 1., lowres_cond_img = None, clip_denoised = True, repeat_noise = False):
def p_sample(self, x, t, image_embed, text_encodings = None, cond_scale = 1., clip_denoised = True, repeat_noise = False):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised)
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, clip_denoised = clip_denoised)
noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop(self, unet, shape, image_embed, lowres_cond_img = None, text_encodings = None, cond_scale = 1):
def p_sample_loop(self, shape, image_embed, text_encodings = None, cond_scale = 1):
device = self.betas.device
b = shape[0]
img = torch.randn(shape, device = device)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
img = self.p_sample(unet, img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
img = torch.randn(shape, device=device)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale)
return img
@torch.no_grad()
def sample(self, image_embed, text = None, cond_scale = 1.):
batch_size = image_embed.shape[0]
image_size = self.image_size
channels = self.channels
text_encodings = self.get_text_encodings(text) if exists(text) else None
return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale)
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
@@ -1281,17 +1049,16 @@ class Decoder(nn.Module):
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, noise = None):
def p_losses(self, x_start, t, *, image_embed, text_encodings = None, noise = None):
noise = default(noise, lambda: torch.randn_like(x_start))
x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)
x_recon = unet(
x_recon = self.net(
x_noisy,
t,
image_embed = image_embed,
text_encodings = text_encodings,
lowres_cond_img = lowres_cond_img,
cond_drop_prob = self.cond_drop_prob
)
@@ -1299,64 +1066,22 @@ class Decoder(nn.Module):
loss = F.l1_loss(noise, x_recon)
elif self.loss_type == 'l2':
loss = F.mse_loss(noise, x_recon)
elif self.loss_type == "huber":
loss = F.smooth_l1_loss(noise, x_recon)
else:
raise NotImplementedError()
return loss
@torch.no_grad()
@eval_decorator
def sample(self, image_embed, text = None, cond_scale = 1.):
batch_size = image_embed.shape[0]
text_encodings = self.get_text_encodings(text) if exists(text) else None
img = None
for unet, channel, image_size in tqdm(zip(self.unets, self.sample_channels, self.image_sizes)):
with self.one_unet_in_gpu(unet = unet):
lowres_cond_img = self.to_lowres_cond(
img,
target_image_size = image_size
) if unet.lowres_cond else None
img = self.p_sample_loop(
unet,
(batch_size, channel, image_size, image_size),
image_embed = image_embed,
text_encodings = text_encodings,
cond_scale = cond_scale,
lowres_cond_img = lowres_cond_img
)
return img
def forward(self, image, text = None, image_embed = None, text_encodings = None, unet_number = None):
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
unet_number = default(unet_number, 1)
unet_index = unet_number - 1
unet = self.get_unet(unet_number)
target_image_size = self.image_sizes[unet_index]
b, c, h, w, device, = *image.shape, image.device
check_shape(image, 'b c h w', c = self.channels)
assert h >= target_image_size and w >= target_image_size
def forward(self, image, text = None):
b, device, img_size, = image.shape[0], image.device, self.image_size
check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
if not exists(image_embed):
image_embed = self.get_image_embed(image)
image_embed = self.get_image_embed(image)
text_encodings = self.get_text_encodings(text) if exists(text) else None
text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None
lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
ddpm_image = resize_image_to(image, target_image_size)
return self.p_losses(unet, ddpm_image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img)
loss = self.p_losses(image, times, image_embed = image_embed, text_encodings = text_encodings)
return loss
# main class

dalle2_pytorch/vqgan_vae.py

@@ -1,532 +0,0 @@
import copy
import math
from math import sqrt
from functools import partial, wraps
from vector_quantize_pytorch import VectorQuantize as VQ
import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.autograd import grad as torch_grad
import torchvision
from einops import rearrange, reduce, repeat
# constants
MList = nn.ModuleList
# helper functions
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
# decorators
def eval_decorator(fn):
def inner(model, *args, **kwargs):
was_training = model.training
model.eval()
out = fn(model, *args, **kwargs)
model.train(was_training)
return out
return inner
def remove_vgg(fn):
@wraps(fn)
def inner(self, *args, **kwargs):
has_vgg = hasattr(self, 'vgg')
if has_vgg:
vgg = self.vgg
delattr(self, 'vgg')
out = fn(self, *args, **kwargs)
if has_vgg:
self.vgg = vgg
return out
return inner
# keyword argument helpers
def pick_and_pop(keys, d):
values = list(map(lambda key: d.pop(key), keys))
return dict(zip(keys, values))
def group_dict_by_key(cond, d):
return_val = [dict(),dict()]
for key in d.keys():
match = bool(cond(key))
ind = int(not match)
return_val[ind][key] = d[key]
return (*return_val,)
def string_begins_with(prefix, str):
return str.startswith(prefix)
def group_by_key_prefix(prefix, d):
return group_dict_by_key(partial(string_begins_with, prefix), d)
def groupby_prefix_and_trim(prefix, d):
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs
# tensor helper functions
def log(t, eps = 1e-10):
return torch.log(t + eps)
def gradient_penalty(images, output, weight = 10):
batch_size = images.shape[0]
gradients = torch_grad(outputs = output, inputs = images,
grad_outputs = torch.ones(output.size(), device = images.device),
create_graph = True, retain_graph = True, only_inputs = True)[0]
gradients = rearrange(gradients, 'b ... -> b (...)')
return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean()
def l2norm(t):
return F.normalize(t, dim = -1)
def leaky_relu(p = 0.1):
return nn.LeakyReLU(p)
def stable_softmax(t, dim = -1, alpha = 32 ** 2):
t = t / alpha
t = t - torch.amax(t, dim = dim, keepdim = True).detach()
return (t * alpha).softmax(dim = dim)
def safe_div(numer, denom, eps = 1e-8):
return numer / (denom + eps)
# gan losses
def hinge_discr_loss(fake, real):
return (F.relu(1 + fake) + F.relu(1 - real)).mean()
def hinge_gen_loss(fake):
return -fake.mean()
def bce_discr_loss(fake, real):
return (-log(1 - torch.sigmoid(fake)) - log(torch.sigmoid(real))).mean()
def bce_gen_loss(fake):
return -log(torch.sigmoid(fake)).mean()
def grad_layer_wrt_loss(loss, layer):
return torch_grad(
outputs = loss,
inputs = layer,
grad_outputs = torch.ones_like(loss),
retain_graph = True
)[0].detach()
# vqgan vae
class LayerNormChan(nn.Module):
def __init__(
self,
dim,
eps = 1e-5
):
super().__init__()
self.eps = eps
self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))
def forward(self, x):
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.gamma
class Discriminator(nn.Module):
def __init__(
self,
dims,
channels = 3,
groups = 16,
init_kernel_size = 5
):
super().__init__()
dim_pairs = zip(dims[:-1], dims[1:])
self.layers = MList([nn.Sequential(nn.Conv2d(channels, dims[0], init_kernel_size, padding = init_kernel_size // 2), leaky_relu())])
for dim_in, dim_out in dim_pairs:
self.layers.append(nn.Sequential(
nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1),
nn.GroupNorm(groups, dim_out),
leaky_relu()
))
dim = dims[-1]
self.to_logits = nn.Sequential( # return 5 x 5, for PatchGAN-esque training
nn.Conv2d(dim, dim, 1),
leaky_relu(),
nn.Conv2d(dim, 1, 4)
)
def forward(self, x):
for net in self.layers:
x = net(x)
return self.to_logits(x)
class ContinuousPositionBias(nn.Module):
""" from https://arxiv.org/abs/2111.09883 """
def __init__(self, *, dim, heads, layers = 2):
super().__init__()
self.net = MList([])
self.net.append(nn.Sequential(nn.Linear(2, dim), leaky_relu()))
for _ in range(layers - 1):
self.net.append(nn.Sequential(nn.Linear(dim, dim), leaky_relu()))
self.net.append(nn.Linear(dim, heads))
self.register_buffer('rel_pos', None, persistent = False)
def forward(self, x):
n, device = x.shape[-1], x.device
fmap_size = int(sqrt(n))
if not exists(self.rel_pos):
pos = torch.arange(fmap_size, device = device)
grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
grid = rearrange(grid, 'c i j -> (i j) c')
rel_pos = rearrange(grid, 'i c -> i 1 c') - rearrange(grid, 'j c -> 1 j c')
rel_pos = torch.sign(rel_pos) * torch.log(rel_pos.abs() + 1)
self.register_buffer('rel_pos', rel_pos, persistent = False)
rel_pos = self.rel_pos.float()
for layer in self.net:
rel_pos = layer(rel_pos)
bias = rearrange(rel_pos, 'i j h -> h i j')
return x + bias
class GLUResBlock(nn.Module):
def __init__(self, chan, groups = 16):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(chan, chan * 2, 3, padding = 1),
nn.GLU(dim = 1),
nn.GroupNorm(groups, chan),
nn.Conv2d(chan, chan * 2, 3, padding = 1),
nn.GLU(dim = 1),
nn.GroupNorm(groups, chan),
nn.Conv2d(chan, chan, 1)
)
def forward(self, x):
return self.net(x) + x
class ResBlock(nn.Module):
def __init__(self, chan, groups = 16):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(chan, chan, 3, padding = 1),
nn.GroupNorm(groups, chan),
leaky_relu(),
nn.Conv2d(chan, chan, 3, padding = 1),
nn.GroupNorm(groups, chan),
leaky_relu(),
nn.Conv2d(chan, chan, 1)
)
def forward(self, x):
return self.net(x) + x
class VQGanAttention(nn.Module):
def __init__(
self,
*,
dim,
dim_head = 64,
heads = 8,
dropout = 0.
):
super().__init__()
self.heads = heads
self.scale = dim_head ** -0.5
inner_dim = heads * dim_head
self.dropout = nn.Dropout(dropout)
self.pre_norm = LayerNormChan(dim)
self.cpb = ContinuousPositionBias(dim = dim // 4, heads = heads)
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)
def forward(self, x):
h = self.heads
height, width, residual = *x.shape[-2:], x.clone()
x = self.pre_norm(x)
q, k, v = self.to_qkv(x).chunk(3, dim = 1)
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = h), (q, k, v))
sim = einsum('b h c i, b h c j -> b h i j', q, k) * self.scale
sim = self.cpb(sim)
attn = stable_softmax(sim, dim = -1)
attn = self.dropout(attn)
out = einsum('b h i j, b h c j -> b h c i', attn, v)
out = rearrange(out, 'b h c (x y) -> b (h c) x y', x = height, y = width)
out = self.to_out(out)
return out + residual
class VQGanVAE(nn.Module):
def __init__(
self,
*,
dim,
image_size,
channels = 3,
num_layers = 4,
layer_mults = None,
l2_recon_loss = False,
use_hinge_loss = True,
num_resnet_blocks = 1,
vgg = None,
vq_codebook_size = 512,
vq_decay = 0.8,
vq_commitment_weight = 1.,
vq_kmeans_init = True,
vq_use_cosine_sim = True,
use_attn = True,
attn_dim_head = 64,
attn_heads = 8,
resnet_groups = 16,
attn_dropout = 0.,
first_conv_kernel_size = 5,
use_vgg_and_gan = True,
**kwargs
):
super().__init__()
assert dim % resnet_groups == 0, f'dimension {dim} must be divisible by {resnet_groups} (groups for the groupnorm)'
vq_kwargs, kwargs = groupby_prefix_and_trim('vq_', kwargs)
self.image_size = image_size
self.channels = channels
self.num_layers = num_layers
self.fmap_size = image_size // (num_layers ** 2)
self.codebook_size = vq_codebook_size
self.encoders = MList([])
self.decoders = MList([])
layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(num_layers))))
assert len(layer_mults) == num_layers, 'layer multipliers must be equal to designated number of layers'
layer_dims = [dim * mult for mult in layer_mults]
dims = (dim, *layer_dims)
codebook_dim = layer_dims[-1]
dim_pairs = zip(dims[:-1], dims[1:])
append = lambda arr, t: arr.append(t)
prepend = lambda arr, t: arr.insert(0, t)
if not isinstance(num_resnet_blocks, tuple):
num_resnet_blocks = (*((0,) * (num_layers - 1)), num_resnet_blocks)
if not isinstance(use_attn, tuple):
use_attn = (*((False,) * (num_layers - 1)), use_attn)
assert len(num_resnet_blocks) == num_layers, 'number of resnet blocks config must be equal to number of layers'
assert len(use_attn) == num_layers
for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(num_layers), dim_pairs, num_resnet_blocks, use_attn):
append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
prepend(self.decoders, nn.Sequential(nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = False), nn.Conv2d(dim_out, dim_in, 3, padding = 1), leaky_relu()))
if layer_use_attn:
prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
for _ in range(layer_num_resnet_blocks):
append(self.encoders, ResBlock(dim_out, groups = resnet_groups))
prepend(self.decoders, GLUResBlock(dim_out, groups = resnet_groups))
if layer_use_attn:
append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
append(self.decoders, nn.Conv2d(dim, channels, 1))
self.vq = VQ(
dim = codebook_dim,
codebook_size = vq_codebook_size,
decay = vq_decay,
commitment_weight = vq_commitment_weight,
accept_image_fmap = True,
kmeans_init = vq_kmeans_init,
use_cosine_sim = vq_use_cosine_sim,
**vq_kwargs
)
# reconstruction loss
self.recon_loss_fn = F.mse_loss if l2_recon_loss else F.l1_loss
# turn off GAN and perceptual loss if grayscale
self.vgg = None
self.discr = None
self.use_vgg_and_gan = use_vgg_and_gan
if not use_vgg_and_gan:
return
# perceptual loss
if exists(vgg):
self.vgg = vgg
else:
self.vgg = torchvision.models.vgg16(pretrained = True)
self.vgg.classifier = nn.Sequential(*self.vgg.classifier[:-2])
# gan related losses
self.discr = Discriminator(dims = dims, channels = channels)
self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss
def copy_for_eval(self):
device = next(self.parameters()).device
vae_copy = copy.deepcopy(self.cpu())
if vae_copy.use_vgg_and_gan:
del vae_copy.discr
del vae_copy.vgg
vae_copy.eval()
return vae_copy.to(device)
@remove_vgg
def state_dict(self, *args, **kwargs):
return super().state_dict(*args, **kwargs)
@remove_vgg
def load_state_dict(self, *args, **kwargs):
return super().load_state_dict(*args, **kwargs)
@property
def codebook(self):
return self.vq.codebook
def encode(self, fmap):
for enc in self.encoders:
fmap = enc(fmap)
return fmap
def decode(self, fmap):
fmap = self.vq(fmap)
for dec in self.decoders:
fmap = dec(fmap)
return fmap
def forward(
self,
img,
return_loss = False,
return_discr_loss = False,
return_recons = False
):
batch, channels, height, width, device = *img.shape, img.device
assert height == self.image_size and width == self.image_size, f'height and width of input image must be equal to {self.image_size}'
assert channels == self.channels, 'number of channels on image or sketch is not equal to the channels set on this VQGanVAE'
fmap, indices, commit_loss = self.encode(img)
fmap = self.decode(fmap)
if not return_loss and not return_discr_loss:
return fmap
assert return_loss ^ return_discr_loss, 'you should either return autoencoder loss or discriminator loss, but not both'
# whether to return discriminator loss
if return_discr_loss:
assert exists(self.discr), 'discriminator must exist to train it'
fmap.detach_()
img.requires_grad_()
fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))
gp = gradient_penalty(img, img_discr_logits)
discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits)
loss = discr_loss + gp
if return_recons:
return loss, fmap
return loss
# reconstruction loss
recon_loss = self.recon_loss_fn(fmap, img)
# early return if training on grayscale
if not self.use_vgg_and_gan:
if return_recons:
return recon_loss, fmap
return recon_loss
# perceptual loss
img_vgg_input = img
fmap_vgg_input = fmap
if img.shape[1] == 1:
# handle grayscale for vgg
img_vgg_input, fmap_vgg_input = map(lambda t: repeat(t, 'b 1 ... -> b c ...', c = 3), (img_vgg_input, fmap_vgg_input))
img_vgg_feats = self.vgg(img_vgg_input)
recon_vgg_feats = self.vgg(fmap_vgg_input)
perceptual_loss = F.mse_loss(img_vgg_feats, recon_vgg_feats)
# generator loss
gen_loss = self.gen_loss(self.discr(fmap))
# calculate adaptive weight
last_dec_layer = self.decoders[-1].weight
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2)
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)
adaptive_weight = safe_div(norm_grad_wrt_perceptual_loss, norm_grad_wrt_gen_loss)
adaptive_weight.clamp_(max = 1e4)
# combine losses
loss = recon_loss + perceptual_loss + commit_loss + adaptive_weight * gen_loss
if return_recons:
return loss, fmap
return loss

setup.py

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.33',
version = '0.0.16',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
@@ -30,7 +30,6 @@ setup(
'torch>=1.10',
'torchvision',
'tqdm',
'vector-quantize-pytorch',
'x-clip>=0.4.4',
'youtokentome'
],