mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 19:44:26 +01:00
Compare commits
42 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8c610aad9a | ||
|
|
6700381a37 | ||
|
|
20377f889a | ||
|
|
6edb1c5dd0 | ||
|
|
b093f92182 | ||
|
|
fa3bb6ba5c | ||
|
|
2705e7c9b0 | ||
|
|
77141882c8 | ||
|
|
4075d02139 | ||
|
|
de0296106b | ||
|
|
eafb136214 | ||
|
|
bfbcc283a3 | ||
|
|
c30544b73a | ||
|
|
bdf5e9c009 | ||
|
|
9878be760b | ||
|
|
7ba6357c05 | ||
|
|
76e063e8b7 | ||
|
|
4d25976f33 | ||
|
|
0b28ee0d01 | ||
|
|
45262a4bb7 | ||
|
|
13a58a78c4 | ||
|
|
f75d49c781 | ||
|
|
3b520dfa85 | ||
|
|
79198c6ae4 | ||
|
|
77a246b1b9 | ||
|
|
f93a3f6ed8 | ||
|
|
8f2a0c7e00 | ||
|
|
863f4ef243 | ||
|
|
fb8a66a2de | ||
|
|
579d4b42dd | ||
|
|
473808850a | ||
|
|
d5318aef4f | ||
|
|
f82917e1fd | ||
|
|
05b74be69a | ||
|
|
a8b5d5d753 | ||
|
|
976ef7f87c | ||
|
|
fd175bcc0e | ||
|
|
76b32f18b3 | ||
|
|
f2d5b87677 | ||
|
|
461347c171 | ||
|
|
46cef31c86 | ||
|
|
59b1a77d4d |
293
README.md
293
README.md
@@ -10,11 +10,9 @@ The main novelty seems to be an extra layer of indirection with the prior networ
|
||||
|
||||
This model is SOTA for text-to-image for now.
|
||||
|
||||
It may also explore an extension of using <a href="https://huggingface.co/spaces/multimodalart/latentdiffusion">latent diffusion</a> in the decoder from Rombach et al.
|
||||
|
||||
Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication
|
||||
|
||||
There was enough interest for a Jax version. It will be completed after the Pytorch version shows signs of life on my toy tasks. <a href="https://github.com/lucidrains/dalle2-jax">Placeholder repository</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
|
||||
There was enough interest for a <a href="https://github.com/lucidrains/dalle2-jax">Jax version</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
|
||||
|
||||
## Install
|
||||
|
||||
@@ -248,13 +246,6 @@ loss = decoder(images, unet_number = 2)
|
||||
loss.backward()
|
||||
|
||||
# do the above for many steps for both unets
|
||||
|
||||
# then it will learn to generate images based on the CLIP image embeddings
|
||||
|
||||
# chaining the unets from lowest resolution to highest resolution (thus cascading)
|
||||
|
||||
mock_image_embed = torch.randn(1, 512).cuda()
|
||||
images = decoder.sample(mock_image_embed) # (1, 3, 512, 512)
|
||||
```
|
||||
|
||||
Finally, to generate the DALL-E2 images from text. Insert the trained `DiffusionPrior` as well as the `Decoder` (which wraps `CLIP`, the causal transformer, and unet(s))
|
||||
@@ -357,7 +348,8 @@ decoder = Decoder(
|
||||
image_sizes = (128, 256),
|
||||
clip = clip,
|
||||
timesteps = 100,
|
||||
cond_drop_prob = 0.2
|
||||
cond_drop_prob = 0.2,
|
||||
condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
|
||||
).cuda()
|
||||
|
||||
for unet_number in (1, 2):
|
||||
@@ -385,7 +377,247 @@ You can also train the decoder on images of greater than the size (say 512x512)
|
||||
|
||||
For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.
|
||||
|
||||
## CLI Usage (work in progress)
|
||||
## Training on Preprocessed CLIP Embeddings
|
||||
|
||||
It is likely, when scaling up, that you would first preprocess your images and text into corresponding embeddings before training the prior network. You can do so easily by simply passing in `image_embed`, `text_embed`, and optionally `text_encodings` and `text_mask`
|
||||
|
||||
Working example below
|
||||
|
||||
```python
|
||||
import torch
|
||||
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, CLIP
|
||||
|
||||
# get trained CLIP from step one
|
||||
|
||||
clip = CLIP(
|
||||
dim_text = 512,
|
||||
dim_image = 512,
|
||||
dim_latent = 512,
|
||||
num_text_tokens = 49408,
|
||||
text_enc_depth = 6,
|
||||
text_seq_len = 256,
|
||||
text_heads = 8,
|
||||
visual_enc_depth = 6,
|
||||
visual_image_size = 256,
|
||||
visual_patch_size = 32,
|
||||
visual_heads = 8,
|
||||
).cuda()
|
||||
|
||||
# setup prior network, which contains an autoregressive transformer
|
||||
|
||||
prior_network = DiffusionPriorNetwork(
|
||||
dim = 512,
|
||||
depth = 6,
|
||||
dim_head = 64,
|
||||
heads = 8
|
||||
).cuda()
|
||||
|
||||
# diffusion prior network, which contains the CLIP and network (with transformer) above
|
||||
|
||||
diffusion_prior = DiffusionPrior(
|
||||
net = prior_network,
|
||||
clip = clip,
|
||||
timesteps = 100,
|
||||
cond_drop_prob = 0.2,
|
||||
condition_on_text_encodings = False # this probably should be true, but just to get Laion started
|
||||
).cuda()
|
||||
|
||||
# mock data
|
||||
|
||||
text = torch.randint(0, 49408, (4, 256)).cuda()
|
||||
images = torch.randn(4, 3, 256, 256).cuda()
|
||||
|
||||
# precompute the text and image embeddings
|
||||
# here using the diffusion prior class, but could be done with CLIP alone
|
||||
|
||||
clip_image_embeds = diffusion_prior.get_image_embed(images)
|
||||
clip_text_embeds = diffusion_prior.get_text_cond(text).get('text_embed')
|
||||
|
||||
# feed text and images into diffusion prior network
|
||||
|
||||
loss = diffusion_prior(
|
||||
text_embed = clip_text_embeds,
|
||||
image_embed = clip_image_embeds
|
||||
)
|
||||
|
||||
loss.backward()
|
||||
|
||||
# do the above for many many many steps
|
||||
# now the diffusion prior can generate image embeddings from the text embeddings
|
||||
```
|
||||
|
||||
You can also completely go `CLIP`-less, in which case you will need to pass in the `image_embed_dim` into the `DiffusionPrior` on initialization
|
||||
|
||||
```python
|
||||
import torch
|
||||
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior
|
||||
|
||||
# setup prior network, which contains an autoregressive transformer
|
||||
|
||||
prior_network = DiffusionPriorNetwork(
|
||||
dim = 512,
|
||||
depth = 6,
|
||||
dim_head = 64,
|
||||
heads = 8
|
||||
).cuda()
|
||||
|
||||
# diffusion prior network, which contains the CLIP and network (with transformer) above
|
||||
|
||||
diffusion_prior = DiffusionPrior(
|
||||
net = prior_network,
|
||||
image_embed_dim = 512, # this needs to be set
|
||||
timesteps = 100,
|
||||
cond_drop_prob = 0.2,
|
||||
condition_on_text_encodings = False # this probably should be true, but just to get Laion started
|
||||
).cuda()
|
||||
|
||||
# mock data
|
||||
|
||||
text = torch.randint(0, 49408, (4, 256)).cuda()
|
||||
images = torch.randn(4, 3, 256, 256).cuda()
|
||||
|
||||
# precompute the text and image embeddings
|
||||
# here using the diffusion prior class, but could be done with CLIP alone
|
||||
|
||||
clip_image_embeds = torch.randn(4, 512).cuda()
|
||||
clip_text_embeds = torch.randn(4, 512).cuda()
|
||||
|
||||
# feed text and images into diffusion prior network
|
||||
|
||||
loss = diffusion_prior(
|
||||
text_embed = clip_text_embeds,
|
||||
image_embed = clip_image_embeds
|
||||
)
|
||||
|
||||
loss.backward()
|
||||
|
||||
# do the above for many many many steps
|
||||
# now the diffusion prior can generate image embeddings from the text embeddings
|
||||
```
|
||||
|
||||
## Experimental
|
||||
|
||||
### DALL-E2 with Latent Diffusion
|
||||
|
||||
This repository decides to take the next step and offer DALL-E v2 combined with <a href="https://huggingface.co/spaces/multimodalart/latentdiffusion">latent diffusion</a>, from Rombach et al.
|
||||
|
||||
You can use it as follows. Latent diffusion can be limited to just the first U-Net in the cascade, or to any number you wish.
|
||||
|
||||
The repository also comes equipped with all the necessary settings to recreate `ViT-VQGan` from the <a href="https://arxiv.org/abs/2110.04627">Improved VQGans</a> paper. Furthermore, the <a href="https://github.com/lucidrains/vector-quantize-pytorch">vector quantization</a> library also comes equipped to do <a href="https://arxiv.org/abs/2203.01941">residual or multi-headed quantization</a>, which I believe will give an even further boost in performance to the autoencoder.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from dalle2_pytorch import Unet, Decoder, CLIP, VQGanVAE
|
||||
|
||||
# trained clip from step 1
|
||||
|
||||
clip = CLIP(
|
||||
dim_text = 512,
|
||||
dim_image = 512,
|
||||
dim_latent = 512,
|
||||
num_text_tokens = 49408,
|
||||
text_enc_depth = 1,
|
||||
text_seq_len = 256,
|
||||
text_heads = 8,
|
||||
visual_enc_depth = 1,
|
||||
visual_image_size = 256,
|
||||
visual_patch_size = 32,
|
||||
visual_heads = 8
|
||||
)
|
||||
|
||||
# 3 unets for the decoder (a la cascading DDPM)
|
||||
|
||||
# first two unets are doing latent diffusion
|
||||
# vqgan-vae must be trained before hand
|
||||
|
||||
vae1 = VQGanVAE(
|
||||
dim = 32,
|
||||
image_size = 256,
|
||||
layers = 3,
|
||||
layer_mults = (1, 2, 4)
|
||||
)
|
||||
|
||||
vae2 = VQGanVAE(
|
||||
dim = 32,
|
||||
image_size = 512,
|
||||
layers = 3,
|
||||
layer_mults = (1, 2, 4)
|
||||
)
|
||||
|
||||
unet1 = Unet(
|
||||
dim = 32,
|
||||
image_embed_dim = 512,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
sparse_attn = True,
|
||||
sparse_attn_window = 2,
|
||||
dim_mults = (1, 2, 4, 8)
|
||||
)
|
||||
|
||||
unet2 = Unet(
|
||||
dim = 32,
|
||||
image_embed_dim = 512,
|
||||
channels = 3,
|
||||
dim_mults = (1, 2, 4, 8, 16),
|
||||
cond_on_image_embeds = True,
|
||||
cond_on_text_encodings = False
|
||||
)
|
||||
|
||||
unet3 = Unet(
|
||||
dim = 32,
|
||||
image_embed_dim = 512,
|
||||
channels = 3,
|
||||
dim_mults = (1, 2, 4, 8, 16),
|
||||
cond_on_image_embeds = True,
|
||||
cond_on_text_encodings = False,
|
||||
attend_at_middle = False
|
||||
)
|
||||
|
||||
# decoder, which contains the unet(s) and clip
|
||||
|
||||
decoder = Decoder(
|
||||
clip = clip,
|
||||
vae = (vae1, vae2), # latent diffusion for unet1 (vae1) and unet2 (vae2), but not for the last unet3
|
||||
unet = (unet1, unet2, unet3), # insert unets in order of low resolution to highest resolution (you can have as many stages as you want here)
|
||||
image_sizes = (256, 512, 1024), # resolutions, 256 for first unet, 512 for second, 1024 for third
|
||||
timesteps = 100,
|
||||
cond_drop_prob = 0.2
|
||||
).cuda()
|
||||
|
||||
# mock images (get a lot of this)
|
||||
|
||||
images = torch.randn(1, 3, 1024, 1024).cuda()
|
||||
|
||||
# feed images into decoder, specifying which unet you want to train
|
||||
# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme
|
||||
|
||||
with decoder.one_unet_in_gpu(1):
|
||||
loss = decoder(images, unet_number = 1)
|
||||
loss.backward()
|
||||
|
||||
with decoder.one_unet_in_gpu(2):
|
||||
loss = decoder(images, unet_number = 2)
|
||||
loss.backward()
|
||||
|
||||
with decoder.one_unet_in_gpu(3):
|
||||
loss = decoder(images, unet_number = 3)
|
||||
loss.backward()
|
||||
|
||||
# do the above for many steps for both unets
|
||||
|
||||
# then it will learn to generate images based on the CLIP image embeddings
|
||||
|
||||
# chaining the unets from lowest resolution to highest resolution (thus cascading)
|
||||
|
||||
mock_image_embed = torch.randn(1, 512).cuda()
|
||||
images = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)
|
||||
```
|
||||
|
||||
## Training wrapper (wip)
|
||||
|
||||
Offer training wrappers
|
||||
|
||||
## CLI (wip)
|
||||
|
||||
```bash
|
||||
$ dream 'sharing a sunset at the summit of mount everest with my dog'
|
||||
@@ -393,9 +625,7 @@ $ dream 'sharing a sunset at the summit of mount everest with my dog'
|
||||
|
||||
Once built, images will be saved to the same directory the command is invoked
|
||||
|
||||
## Training wrapper (wip)
|
||||
|
||||
Offer training wrappers
|
||||
<a href="https://github.com/lucidrains/big-sleep">template</a>
|
||||
|
||||
## Training CLI (wip)
|
||||
|
||||
@@ -412,11 +642,19 @@ Offer training wrappers
|
||||
- [x] add efficient attention in unet
|
||||
- [x] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning)
|
||||
- [x] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
|
||||
- [ ] build out latent diffusion architecture, with the vq-reg variant (vqgan-vae), make it completely optional
|
||||
- [x] build out latent diffusion architecture, with the vq-reg variant (vqgan-vae), make it completely optional and compatible with cascading ddpms
|
||||
- [x] for decoder, allow ability to customize objective (predict epsilon vs x0), in case latent diffusion does better with prediction of x0
|
||||
- [x] use attention-based upsampling https://arxiv.org/abs/2112.11435
|
||||
- [x] use inheritance just this once for sharing logic between decoder and prior network ddpms
|
||||
- [x] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
|
||||
- [x] abstract interface for CLIP adapter class, so other CLIPs can be brought in
|
||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
|
||||
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
|
||||
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
|
||||
- [ ] train on a toy task, offer in colab
|
||||
- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
|
||||
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
|
||||
- [ ] bring in tools to train vqgan-vae
|
||||
|
||||
## Citations
|
||||
|
||||
@@ -448,23 +686,12 @@ Offer training wrappers
|
||||
|
||||
```bibtex
|
||||
@inproceedings{Liu2022ACF,
|
||||
title = {A ConvNet for the 2020s},
|
||||
title = {A ConvNet for the 2020https://arxiv.org/abs/2112.11435s},
|
||||
author = {Zhuang Liu and Hanzi Mao and Chaozheng Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie},
|
||||
year = {2022}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{zhang2019root,
|
||||
title = {Root Mean Square Layer Normalization},
|
||||
author = {Biao Zhang and Rico Sennrich},
|
||||
year = {2019},
|
||||
eprint = {1910.07467},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.LG}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@inproceedings{Tu2022MaxViTMV,
|
||||
title = {MaxViT: Multi-Axis Vision Transformer},
|
||||
@@ -473,4 +700,14 @@ Offer training wrappers
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{Yu2021VectorquantizedIM,
|
||||
title = {Vector-quantized Image Modeling with Improved VQGAN},
|
||||
author = {Jiahui Yu and Xin Li and Jing Yu Koh and Han Zhang and Ruoming Pang and James Qin and Alexander Ku and Yuanzhong Xu and Jason Baldridge and Yonghui Wu},
|
||||
journal = {ArXiv},
|
||||
year = {2021},
|
||||
volume = {abs/2110.04627}
|
||||
}
|
||||
```
|
||||
|
||||
*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's <a href="https://arxiv.org/abs/2011.13456">paper</a>
|
||||
|
||||
@@ -1,9 +1,51 @@
|
||||
import click
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from pathlib import Path
|
||||
|
||||
from dalle2_pytorch import DALLE2, Decoder, DiffusionPrior
|
||||
|
||||
def safeget(dictionary, keys, default = None):
|
||||
return reduce(lambda d, key: d.get(key, default) if isinstance(d, dict) else default, keys.split('.'), dictionary)
|
||||
|
||||
def simple_slugify(text, max_length = 255):
|
||||
return text.replace("-", "_").replace(",", "").replace(" ", "_").replace("|", "--").strip('-_')[:max_length]
|
||||
|
||||
def get_pkg_version():
|
||||
from pkg_resources import get_distribution
|
||||
return get_distribution('dalle2_pytorch').version
|
||||
|
||||
def main():
|
||||
pass
|
||||
|
||||
@click.command()
|
||||
@click.option('--model', default = './dalle2.pt', help = 'path to trained DALL-E2 model')
|
||||
@click.option('--cond_scale', default = 2, help = 'conditioning scale (classifier free guidance) in decoder')
|
||||
@click.argument('text')
|
||||
def dream(text):
|
||||
return 'not ready yet'
|
||||
def dream(
|
||||
model,
|
||||
cond_scale,
|
||||
text
|
||||
):
|
||||
model_path = Path(model)
|
||||
full_model_path = str(model_path.resolve())
|
||||
assert model_path.exists(), f'model not found at {full_model_path}'
|
||||
loaded = torch.load(str(model_path))
|
||||
|
||||
version = safeget(loaded, 'version')
|
||||
print(f'loading DALL-E2 from {full_model_path}, saved at version {version} - current package version is {get_pkg_version()}')
|
||||
|
||||
prior_init_params = safeget(loaded, 'init_params.prior')
|
||||
decoder_init_params = safeget(loaded, 'init_params.decoder')
|
||||
model_params = safeget(loaded, 'model_params')
|
||||
|
||||
prior = DiffusionPrior(**prior_init_params)
|
||||
decoder = Decoder(**decoder_init_params)
|
||||
|
||||
dalle2 = DALLE2(prior, decoder)
|
||||
dalle2.load_state_dict(model_params)
|
||||
|
||||
image = dalle2(text, cond_scale = cond_scale)
|
||||
|
||||
pil_image = T.ToPILImage()(image)
|
||||
return pil_image.save(f'./{simple_slugify(text)}.png')
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,53 @@
|
||||
import copy
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
# exponential moving average wrapper
|
||||
|
||||
class EMA(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
beta = 0.99,
|
||||
ema_update_after_step = 1000,
|
||||
ema_update_every = 10,
|
||||
):
|
||||
super().__init__()
|
||||
self.beta = beta
|
||||
self.online_model = model
|
||||
self.ema_model = copy.deepcopy(model)
|
||||
|
||||
self.ema_update_after_step = ema_update_after_step # only start EMA after this step number, starting at 0
|
||||
self.ema_update_every = ema_update_every
|
||||
|
||||
self.register_buffer('initted', torch.Tensor([False]))
|
||||
self.register_buffer('step', torch.tensor([0.]))
|
||||
|
||||
def update(self):
|
||||
self.step += 1
|
||||
|
||||
if self.step <= self.ema_update_after_step or (self.step % self.ema_update_every) != 0:
|
||||
return
|
||||
|
||||
if not self.initted:
|
||||
self.ema_model.state_dict(self.online_model.state_dict())
|
||||
self.initted.data.copy_(torch.Tensor([True]))
|
||||
|
||||
self.update_moving_average(self.ema_model, self.online_model)
|
||||
|
||||
def update_moving_average(self, ma_model, current_model):
|
||||
def calculate_ema(beta, old, new):
|
||||
if not exists(old):
|
||||
return new
|
||||
return old * beta + (1 - beta) * new
|
||||
|
||||
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
|
||||
old_weight, up_weight = ma_params.data, current_params.data
|
||||
ma_params.data = calculate_ema(self.beta, old_weight, up_weight)
|
||||
|
||||
for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
|
||||
new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer)
|
||||
ma_buffer.copy_(new_buffer_value)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.ema_model(*args, **kwargs)
|
||||
|
||||
@@ -12,6 +12,8 @@ from torch.autograd import grad as torch_grad
|
||||
import torchvision
|
||||
|
||||
from einops import rearrange, reduce, repeat
|
||||
from einops_exts import rearrange_many
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# constants
|
||||
|
||||
@@ -144,6 +146,8 @@ class LayerNormChan(nn.Module):
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (var + self.eps).sqrt() * self.gamma
|
||||
|
||||
# discriminator
|
||||
|
||||
class Discriminator(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -177,6 +181,8 @@ class Discriminator(nn.Module):
|
||||
|
||||
return self.to_logits(x)
|
||||
|
||||
# positional encoding
|
||||
|
||||
class ContinuousPositionBias(nn.Module):
|
||||
""" from https://arxiv.org/abs/2111.09883 """
|
||||
|
||||
@@ -211,6 +217,84 @@ class ContinuousPositionBias(nn.Module):
|
||||
bias = rearrange(rel_pos, 'i j h -> h i j')
|
||||
return x + bias
|
||||
|
||||
# resnet encoder / decoder
|
||||
|
||||
class ResnetEncDec(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
*,
|
||||
channels = 3,
|
||||
layers = 4,
|
||||
layer_mults = None,
|
||||
num_resnet_blocks = 1,
|
||||
resnet_groups = 16,
|
||||
first_conv_kernel_size = 5,
|
||||
use_attn = True,
|
||||
attn_dim_head = 64,
|
||||
attn_heads = 8,
|
||||
attn_dropout = 0.,
|
||||
):
|
||||
super().__init__()
|
||||
assert dim % resnet_groups == 0, f'dimension {dim} must be divisible by {resnet_groups} (groups for the groupnorm)'
|
||||
|
||||
self.layers = layers
|
||||
|
||||
self.encoders = MList([])
|
||||
self.decoders = MList([])
|
||||
|
||||
layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
|
||||
assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'
|
||||
|
||||
layer_dims = [dim * mult for mult in layer_mults]
|
||||
dims = (dim, *layer_dims)
|
||||
|
||||
self.encoded_dim = dims[-1]
|
||||
|
||||
dim_pairs = zip(dims[:-1], dims[1:])
|
||||
|
||||
append = lambda arr, t: arr.append(t)
|
||||
prepend = lambda arr, t: arr.insert(0, t)
|
||||
|
||||
if not isinstance(num_resnet_blocks, tuple):
|
||||
num_resnet_blocks = (*((0,) * (layers - 1)), num_resnet_blocks)
|
||||
|
||||
if not isinstance(use_attn, tuple):
|
||||
use_attn = (*((False,) * (layers - 1)), use_attn)
|
||||
|
||||
assert len(num_resnet_blocks) == layers, 'number of resnet blocks config must be equal to number of layers'
|
||||
assert len(use_attn) == layers
|
||||
|
||||
for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_resnet_blocks, use_attn):
|
||||
append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
|
||||
prepend(self.decoders, nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu()))
|
||||
|
||||
if layer_use_attn:
|
||||
prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
|
||||
|
||||
for _ in range(layer_num_resnet_blocks):
|
||||
append(self.encoders, ResBlock(dim_out, groups = resnet_groups))
|
||||
prepend(self.decoders, GLUResBlock(dim_out, groups = resnet_groups))
|
||||
|
||||
if layer_use_attn:
|
||||
append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
|
||||
|
||||
prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
|
||||
append(self.decoders, nn.Conv2d(dim, channels, 1))
|
||||
|
||||
def get_encoded_fmap_size(self, image_size):
|
||||
return image_size // (2 ** self.layers)
|
||||
|
||||
def encode(self, x):
|
||||
for enc in self.encoders:
|
||||
x = enc(x)
|
||||
return x
|
||||
|
||||
def decode(self, x):
|
||||
for dec in self.decoders:
|
||||
x = dec(x)
|
||||
return x
|
||||
|
||||
class GLUResBlock(nn.Module):
|
||||
def __init__(self, chan, groups = 16):
|
||||
super().__init__()
|
||||
@@ -243,6 +327,8 @@ class ResBlock(nn.Module):
|
||||
def forward(self, x):
|
||||
return self.net(x) + x
|
||||
|
||||
# vqgan attention layer
|
||||
|
||||
class VQGanAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -287,6 +373,167 @@ class VQGanAttention(nn.Module):
|
||||
|
||||
return out + residual
|
||||
|
||||
# ViT encoder / decoder
|
||||
|
||||
class RearrangeImage(nn.Module):
|
||||
def forward(self, x):
|
||||
n = x.shape[1]
|
||||
w = h = int(sqrt(n))
|
||||
return rearrange(x, 'b (h w) ... -> b h w ...', h = h, w = w)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
*,
|
||||
heads = 8,
|
||||
dim_head = 32
|
||||
):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim)
|
||||
|
||||
def forward(self, x):
|
||||
h = self.heads
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
q, k, v = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = h)
|
||||
|
||||
q = q * self.scale
|
||||
sim = einsum('b h i d, b h j d -> b h i j', q, k)
|
||||
|
||||
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
|
||||
attn = sim.softmax(dim = -1)
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
def FeedForward(dim, mult = 4):
|
||||
return nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, dim * mult, bias = False),
|
||||
nn.GELU(),
|
||||
nn.Linear(dim * mult, dim, bias = False)
|
||||
)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
*,
|
||||
layers,
|
||||
dim_head = 32,
|
||||
heads = 8,
|
||||
ff_mult = 4
|
||||
):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(layers):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim = dim, dim_head = dim_head, heads = heads),
|
||||
FeedForward(dim = dim, mult = ff_mult)
|
||||
]))
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
class ViTEncDec(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
channels = 3,
|
||||
layers = 4,
|
||||
patch_size = 8,
|
||||
dim_head = 32,
|
||||
heads = 8,
|
||||
ff_mult = 4
|
||||
):
|
||||
super().__init__()
|
||||
self.encoded_dim = dim
|
||||
self.patch_size = patch_size
|
||||
|
||||
input_dim = channels * (patch_size ** 2)
|
||||
|
||||
self.encoder = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
|
||||
nn.Linear(input_dim, dim),
|
||||
Transformer(
|
||||
dim = dim,
|
||||
dim_head = dim_head,
|
||||
heads = heads,
|
||||
ff_mult = ff_mult,
|
||||
layers = layers
|
||||
),
|
||||
RearrangeImage(),
|
||||
Rearrange('b h w c -> b c h w')
|
||||
)
|
||||
|
||||
self.decoder = nn.Sequential(
|
||||
Rearrange('b c h w -> b (h w) c'),
|
||||
Transformer(
|
||||
dim = dim,
|
||||
dim_head = dim_head,
|
||||
heads = heads,
|
||||
ff_mult = ff_mult,
|
||||
layers = layers
|
||||
),
|
||||
nn.Sequential(
|
||||
nn.Linear(dim, dim * 4, bias = False),
|
||||
nn.Tanh(),
|
||||
nn.Linear(dim * 4, input_dim, bias = False),
|
||||
),
|
||||
RearrangeImage(),
|
||||
Rearrange('b h w (p1 p2 c) -> b c (h p1) (w p2)', p1 = patch_size, p2 = patch_size)
|
||||
)
|
||||
|
||||
def get_encoded_fmap_size(self, image_size):
|
||||
return image_size // self.patch_size
|
||||
|
||||
def encode(self, x):
|
||||
return self.encoder(x)
|
||||
|
||||
def decode(self, x):
|
||||
return self.decoder(x)
|
||||
|
||||
# main vqgan-vae classes
|
||||
|
||||
class NullVQGanVAE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
channels
|
||||
):
|
||||
super().__init__()
|
||||
self.encoded_dim = channels
|
||||
self.layers = 0
|
||||
|
||||
def get_encoded_fmap_size(self, size):
|
||||
return size
|
||||
|
||||
def copy_for_eval(self):
|
||||
return self
|
||||
|
||||
def encode(self, x):
|
||||
return x
|
||||
|
||||
def decode(self, x):
|
||||
return x
|
||||
|
||||
class VQGanVAE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -294,80 +541,44 @@ class VQGanVAE(nn.Module):
|
||||
dim,
|
||||
image_size,
|
||||
channels = 3,
|
||||
num_layers = 4,
|
||||
layer_mults = None,
|
||||
layers = 4,
|
||||
l2_recon_loss = False,
|
||||
use_hinge_loss = True,
|
||||
num_resnet_blocks = 1,
|
||||
vgg = None,
|
||||
vq_codebook_size = 512,
|
||||
vq_decay = 0.8,
|
||||
vq_commitment_weight = 1.,
|
||||
vq_kmeans_init = True,
|
||||
vq_use_cosine_sim = True,
|
||||
use_attn = True,
|
||||
attn_dim_head = 64,
|
||||
attn_heads = 8,
|
||||
resnet_groups = 16,
|
||||
attn_dropout = 0.,
|
||||
first_conv_kernel_size = 5,
|
||||
use_vgg_and_gan = True,
|
||||
vae_type = 'resnet',
|
||||
discr_layers = 4,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
assert dim % resnet_groups == 0, f'dimension {dim} must be divisible by {resnet_groups} (groups for the groupnorm)'
|
||||
|
||||
vq_kwargs, kwargs = groupby_prefix_and_trim('vq_', kwargs)
|
||||
encdec_kwargs, kwargs = groupby_prefix_and_trim('encdec_', kwargs)
|
||||
|
||||
self.image_size = image_size
|
||||
self.channels = channels
|
||||
self.num_layers = num_layers
|
||||
self.fmap_size = image_size // (num_layers ** 2)
|
||||
self.codebook_size = vq_codebook_size
|
||||
|
||||
self.encoders = MList([])
|
||||
self.decoders = MList([])
|
||||
if vae_type == 'resnet':
|
||||
enc_dec_klass = ResnetEncDec
|
||||
elif vae_type == 'vit':
|
||||
enc_dec_klass = ViTEncDec
|
||||
else:
|
||||
raise ValueError(f'{vae_type} not valid')
|
||||
|
||||
layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(num_layers))))
|
||||
assert len(layer_mults) == num_layers, 'layer multipliers must be equal to designated number of layers'
|
||||
|
||||
layer_dims = [dim * mult for mult in layer_mults]
|
||||
dims = (dim, *layer_dims)
|
||||
codebook_dim = layer_dims[-1]
|
||||
|
||||
dim_pairs = zip(dims[:-1], dims[1:])
|
||||
|
||||
append = lambda arr, t: arr.append(t)
|
||||
prepend = lambda arr, t: arr.insert(0, t)
|
||||
|
||||
if not isinstance(num_resnet_blocks, tuple):
|
||||
num_resnet_blocks = (*((0,) * (num_layers - 1)), num_resnet_blocks)
|
||||
|
||||
if not isinstance(use_attn, tuple):
|
||||
use_attn = (*((False,) * (num_layers - 1)), use_attn)
|
||||
|
||||
assert len(num_resnet_blocks) == num_layers, 'number of resnet blocks config must be equal to number of layers'
|
||||
assert len(use_attn) == num_layers
|
||||
|
||||
for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(num_layers), dim_pairs, num_resnet_blocks, use_attn):
|
||||
append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
|
||||
prepend(self.decoders, nn.Sequential(nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = False), nn.Conv2d(dim_out, dim_in, 3, padding = 1), leaky_relu()))
|
||||
|
||||
if layer_use_attn:
|
||||
prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
|
||||
|
||||
for _ in range(layer_num_resnet_blocks):
|
||||
append(self.encoders, ResBlock(dim_out, groups = resnet_groups))
|
||||
prepend(self.decoders, GLUResBlock(dim_out, groups = resnet_groups))
|
||||
|
||||
if layer_use_attn:
|
||||
append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
|
||||
|
||||
prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
|
||||
append(self.decoders, nn.Conv2d(dim, channels, 1))
|
||||
self.enc_dec = enc_dec_klass(
|
||||
dim = dim,
|
||||
channels = channels,
|
||||
layers = layers,
|
||||
**encdec_kwargs
|
||||
)
|
||||
|
||||
self.vq = VQ(
|
||||
dim = codebook_dim,
|
||||
dim = self.enc_dec.encoded_dim,
|
||||
codebook_size = vq_codebook_size,
|
||||
decay = vq_decay,
|
||||
commitment_weight = vq_commitment_weight,
|
||||
@@ -400,11 +611,22 @@ class VQGanVAE(nn.Module):
|
||||
|
||||
# gan related losses
|
||||
|
||||
layer_mults = list(map(lambda t: 2 ** t, range(discr_layers)))
|
||||
layer_dims = [dim * mult for mult in layer_mults]
|
||||
dims = (dim, *layer_dims)
|
||||
|
||||
self.discr = Discriminator(dims = dims, channels = channels)
|
||||
|
||||
self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
|
||||
self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss
|
||||
|
||||
@property
|
||||
def encoded_dim(self):
|
||||
return self.enc_dec.encoded_dim
|
||||
|
||||
def get_encoded_fmap_size(self, image_size):
|
||||
return self.enc_dec.get_encoded_fmap_size(image_size)
|
||||
|
||||
def copy_for_eval(self):
|
||||
device = next(self.parameters()).device
|
||||
vae_copy = copy.deepcopy(self.cpu())
|
||||
@@ -429,33 +651,34 @@ class VQGanVAE(nn.Module):
|
||||
return self.vq.codebook
|
||||
|
||||
def encode(self, fmap):
|
||||
for enc in self.encoders:
|
||||
fmap = enc(fmap)
|
||||
|
||||
fmap = self.enc_dec.encode(fmap)
|
||||
return fmap
|
||||
|
||||
def decode(self, fmap):
|
||||
fmap = self.vq(fmap)
|
||||
def decode(self, fmap, return_indices_and_loss = False):
|
||||
fmap, indices, commit_loss = self.vq(fmap)
|
||||
|
||||
for dec in self.decoders:
|
||||
fmap = dec(fmap)
|
||||
fmap = self.enc_dec.decode(fmap)
|
||||
|
||||
return fmap
|
||||
if not return_indices_and_loss:
|
||||
return fmap
|
||||
|
||||
return fmap, indices, commit_loss
|
||||
|
||||
def forward(
|
||||
self,
|
||||
img,
|
||||
return_loss = False,
|
||||
return_discr_loss = False,
|
||||
return_recons = False
|
||||
return_recons = False,
|
||||
add_gradient_penalty = True
|
||||
):
|
||||
batch, channels, height, width, device = *img.shape, img.device
|
||||
assert height == self.image_size and width == self.image_size, 'height and width of input image must be equal to {self.image_size}'
|
||||
assert channels == self.channels, 'number of channels on image or sketch is not equal to the channels set on this VQGanVAE'
|
||||
|
||||
fmap, indices, commit_loss = self.encode(img)
|
||||
fmap = self.encode(img)
|
||||
|
||||
fmap = self.decode(fmap)
|
||||
fmap, indices, commit_loss = self.decode(fmap, return_indices_and_loss = True)
|
||||
|
||||
if not return_loss and not return_discr_loss:
|
||||
return fmap
|
||||
@@ -472,11 +695,11 @@ class VQGanVAE(nn.Module):
|
||||
|
||||
fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))
|
||||
|
||||
gp = gradient_penalty(img, img_discr_logits)
|
||||
|
||||
discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits)
|
||||
|
||||
loss = discr_loss + gp
|
||||
if add_gradient_penalty:
|
||||
gp = gradient_penalty(img, img_discr_logits)
|
||||
loss = discr_loss + gp
|
||||
|
||||
if return_recons:
|
||||
return loss, fmap
|
||||
|
||||
Reference in New Issue
Block a user