Compare commits

..

47 Commits

Author SHA1 Message Date
Phil Wang
fb662a62f3 fix another bug thanks to @xiankgx 2022-04-29 07:38:32 -07:00
Phil Wang
587c8c9b44 optimize for clarity 2022-04-28 21:59:13 -07:00
Phil Wang
aa900213e7 force first unet in the cascade to be conditioned on image embeds 2022-04-28 20:53:15 -07:00
Phil Wang
cb26187450 vqgan-vae codebook dims should be 256 or smaller 2022-04-28 08:59:03 -07:00
Phil Wang
625ce23f6b 🐛 2022-04-28 07:21:18 -07:00
Phil Wang
dbf4a281f1 make sure another CLIP can actually be passed in, as long as it is wrapped in an adapter extended from BaseClipAdapter 2022-04-27 20:45:27 -07:00
Phil Wang
4ab527e779 some extra asserts for text encoding of diffusion prior and decoder 2022-04-27 20:11:43 -07:00
Phil Wang
d0cdeb3247 add ability for DALL-E2 to return PIL images with return_pil_images = True on forward, for those who have no clue about deep learning 2022-04-27 19:58:06 -07:00
Phil Wang
8c610aad9a only pass text encodings conditioning in diffusion prior if specified on initialization 2022-04-27 19:48:16 -07:00
Phil Wang
6700381a37 prepare for ability to integrate other clips other than x-clip 2022-04-27 19:35:05 -07:00
Phil Wang
20377f889a todo 2022-04-27 17:22:14 -07:00
Phil Wang
6edb1c5dd0 fix issue with ema class 2022-04-27 16:40:02 -07:00
Phil Wang
b093f92182 inform what is possible 2022-04-27 08:25:16 -07:00
Phil Wang
fa3bb6ba5c make sure cpu-only still works 2022-04-27 08:02:10 -07:00
Phil Wang
2705e7c9b0 attention-based upsampling claims unsupported by local experiments, removing 2022-04-27 07:51:04 -07:00
Phil Wang
77141882c8 complete vit-vqgan from https://arxiv.org/abs/2110.04627 2022-04-26 17:20:47 -07:00
Phil Wang
4075d02139 nevermind, it could be working, but only when i stabilize it with the feedforward layer + tanh as proposed in vit-vqgan paper (which will be built into the repository later for the latent diffusion) 2022-04-26 12:43:31 -07:00
Phil Wang
de0296106b be able to turn off warning for use of LazyLinear by passing in text embedding dimension for unet 2022-04-26 11:42:46 -07:00
Phil Wang
eafb136214 suppress a warning 2022-04-26 11:40:45 -07:00
Phil Wang
bfbcc283a3 DRY a tiny bit for gaussian diffusion related logic 2022-04-26 11:39:12 -07:00
Phil Wang
c30544b73a no CLIP altogether for training DiffusionPrior 2022-04-26 10:23:41 -07:00
Phil Wang
bdf5e9c009 todo 2022-04-26 09:56:54 -07:00
Phil Wang
9878be760b have researcher explicitly state upfront whether to condition with text encodings in cascading ddpm decoder, have DALLE-2 class take care of passing in text if feature turned on 2022-04-26 09:47:09 -07:00
Phil Wang
7ba6357c05 allow for training the Prior network with precomputed CLIP embeddings (or text encodings) 2022-04-26 09:29:51 -07:00
Phil Wang
76e063e8b7 refactor so that the causal transformer in the diffusion prior network can be conditioned without text encodings (for Laions parallel efforts, although it seems from the paper it is needed) 2022-04-26 09:00:11 -07:00
Phil Wang
4d25976f33 make sure non-latent diffusion still works 2022-04-26 08:36:00 -07:00
Phil Wang
0b28ee0d01 revert back to old upsampling, paper does not work 2022-04-26 07:39:04 -07:00
Phil Wang
45262a4bb7 bring in the exponential moving average wrapper, to get ready for training 2022-04-25 19:24:13 -07:00
Phil Wang
13a58a78c4 scratch off todo 2022-04-25 19:01:30 -07:00
Phil Wang
f75d49c781 start a file for all attention-related modules, use attention-based upsampling in the unets in dalle-2 2022-04-25 18:59:10 -07:00
Phil Wang
3b520dfa85 bring in attention-based upsampling to strengthen vqgan-vae, seems to work as advertised in initial experiments in GAN 2022-04-25 17:27:45 -07:00
Phil Wang
79198c6ae4 keep readme simple for reader 2022-04-25 17:21:45 -07:00
Phil Wang
77a246b1b9 todo 2022-04-25 08:48:28 -07:00
Phil Wang
f93a3f6ed8 reprioritize 2022-04-25 08:44:27 -07:00
Phil Wang
8f2a0c7e00 better naming 2022-04-25 07:44:33 -07:00
Phil Wang
863f4ef243 just take care of the logic for setting all latent diffusion to predict x0, if needed 2022-04-24 10:06:42 -07:00
Phil Wang
fb8a66a2de just in case latent diffusion performs better with prediction of x0 instead of epsilon, open up the research avenue 2022-04-24 10:04:22 -07:00
Phil Wang
579d4b42dd does not seem right to clip for the prior diffusion part 2022-04-24 09:51:18 -07:00
Phil Wang
473808850a some outlines to the eventual CLI endpoint 2022-04-24 09:27:15 -07:00
Phil Wang
d5318aef4f todo 2022-04-23 08:23:08 -07:00
Phil Wang
f82917e1fd prepare for turning off gradient penalty, as shown in GAN literature, GP needs to be only applied 1 out of 4 iterations 2022-04-23 07:52:10 -07:00
Phil Wang
05b74be69a use null container pattern to cleanup some conditionals, save more cleanup for next week 2022-04-22 15:23:18 -07:00
Phil Wang
a8b5d5d753 last tweak of readme 2022-04-22 14:16:43 -07:00
Phil Wang
976ef7f87c project management 2022-04-22 14:15:42 -07:00
Phil Wang
fd175bcc0e readme 2022-04-22 14:13:33 -07:00
Phil Wang
76b32f18b3 first pass at complete DALL-E2 + Latent Diffusion integration, latent diffusion on any layer(s) of the cascading ddpm in the decoder. 2022-04-22 13:53:13 -07:00
Phil Wang
f2d5b87677 todo 2022-04-22 11:39:58 -07:00
6 changed files with 1076 additions and 368 deletions

293
README.md
View File

@@ -10,11 +10,9 @@ The main novelty seems to be an extra layer of indirection with the prior networ
This model is SOTA for text-to-image for now.
It may also explore an extension of using <a href="https://huggingface.co/spaces/multimodalart/latentdiffusion">latent diffusion</a> in the decoder from Rombach et al.
Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication
There was enough interest for a Jax version. It will be completed after the Pytorch version shows signs of life on my toy tasks. <a href="https://github.com/lucidrains/dalle2-jax">Placeholder repository</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
There was enough interest for a <a href="https://github.com/lucidrains/dalle2-jax">Jax version</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
## Install
@@ -248,13 +246,6 @@ loss = decoder(images, unet_number = 2)
loss.backward()
# do the above for many steps for both unets
# then it will learn to generate images based on the CLIP image embeddings
# chaining the unets from lowest resolution to highest resolution (thus cascading)
mock_image_embed = torch.randn(1, 512).cuda()
images = decoder.sample(mock_image_embed) # (1, 3, 512, 512)
```
Finally, to generate the DALL-E2 images from text. Insert the trained `DiffusionPrior` as well as the `Decoder` (which wraps `CLIP`, the causal transformer, and unet(s))
@@ -357,7 +348,8 @@ decoder = Decoder(
image_sizes = (128, 256),
clip = clip,
timesteps = 100,
cond_drop_prob = 0.2
cond_drop_prob = 0.2,
condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
).cuda()
for unet_number in (1, 2):
@@ -385,7 +377,247 @@ You can also train the decoder on images of greater than the size (say 512x512)
For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.
## CLI Usage (work in progress)
## Training on Preprocessed CLIP Embeddings
It is likely, when scaling up, that you would first preprocess your images and text into corresponding embeddings before training the prior network. You can do so easily by simply passing in `image_embed`, `text_embed`, and optionally `text_encodings` and `text_mask`
Working example below
```python
import torch
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, CLIP
# get trained CLIP from step one
clip = CLIP(
dim_text = 512,
dim_image = 512,
dim_latent = 512,
num_text_tokens = 49408,
text_enc_depth = 6,
text_seq_len = 256,
text_heads = 8,
visual_enc_depth = 6,
visual_image_size = 256,
visual_patch_size = 32,
visual_heads = 8,
).cuda()
# setup prior network, which contains an autoregressive transformer
prior_network = DiffusionPriorNetwork(
dim = 512,
depth = 6,
dim_head = 64,
heads = 8
).cuda()
# diffusion prior network, which contains the CLIP and network (with transformer) above
diffusion_prior = DiffusionPrior(
net = prior_network,
clip = clip,
timesteps = 100,
cond_drop_prob = 0.2,
condition_on_text_encodings = False # this probably should be true, but just to get Laion started
).cuda()
# mock data
text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
# precompute the text and image embeddings
# here using the diffusion prior class, but could be done with CLIP alone
clip_image_embeds = diffusion_prior.get_image_embed(images)
clip_text_embeds = diffusion_prior.get_text_cond(text).get('text_embed')
# feed text and images into diffusion prior network
loss = diffusion_prior(
text_embed = clip_text_embeds,
image_embed = clip_image_embeds
)
loss.backward()
# do the above for many many many steps
# now the diffusion prior can generate image embeddings from the text embeddings
```
You can also completely go `CLIP`-less, in which case you will need to pass in the `image_embed_dim` into the `DiffusionPrior` on initialization
```python
import torch
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior
# setup prior network, which contains an autoregressive transformer
prior_network = DiffusionPriorNetwork(
dim = 512,
depth = 6,
dim_head = 64,
heads = 8
).cuda()
# diffusion prior network, which contains the CLIP and network (with transformer) above
diffusion_prior = DiffusionPrior(
net = prior_network,
image_embed_dim = 512, # this needs to be set
timesteps = 100,
cond_drop_prob = 0.2,
condition_on_text_encodings = False # this probably should be true, but just to get Laion started
).cuda()
# mock data
text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
# precompute the text and image embeddings
# here using the diffusion prior class, but could be done with CLIP alone
clip_image_embeds = torch.randn(4, 512).cuda()
clip_text_embeds = torch.randn(4, 512).cuda()
# feed text and images into diffusion prior network
loss = diffusion_prior(
text_embed = clip_text_embeds,
image_embed = clip_image_embeds
)
loss.backward()
# do the above for many many many steps
# now the diffusion prior can generate image embeddings from the text embeddings
```
## Experimental
### DALL-E2 with Latent Diffusion
This repository decides to take the next step and offer DALL-E v2 combined with <a href="https://huggingface.co/spaces/multimodalart/latentdiffusion">latent diffusion</a>, from Rombach et al.
You can use it as follows. Latent diffusion can be limited to just the first U-Net in the cascade, or to any number you wish.
The repository also comes equipped with all the necessary settings to recreate `ViT-VQGan` from the <a href="https://arxiv.org/abs/2110.04627">Improved VQGans</a> paper. Furthermore, the <a href="https://github.com/lucidrains/vector-quantize-pytorch">vector quantization</a> library also comes equipped to do <a href="https://arxiv.org/abs/2203.01941">residual or multi-headed quantization</a>, which I believe will give an even further boost in performance to the autoencoder.
```python
import torch
from dalle2_pytorch import Unet, Decoder, CLIP, VQGanVAE
# trained clip from step 1
clip = CLIP(
dim_text = 512,
dim_image = 512,
dim_latent = 512,
num_text_tokens = 49408,
text_enc_depth = 1,
text_seq_len = 256,
text_heads = 8,
visual_enc_depth = 1,
visual_image_size = 256,
visual_patch_size = 32,
visual_heads = 8
)
# 3 unets for the decoder (a la cascading DDPM)
# first two unets are doing latent diffusion
# vqgan-vae must be trained before hand
vae1 = VQGanVAE(
dim = 32,
image_size = 256,
layers = 3,
layer_mults = (1, 2, 4)
)
vae2 = VQGanVAE(
dim = 32,
image_size = 512,
layers = 3,
layer_mults = (1, 2, 4)
)
unet1 = Unet(
dim = 32,
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
sparse_attn = True,
sparse_attn_window = 2,
dim_mults = (1, 2, 4, 8)
)
unet2 = Unet(
dim = 32,
image_embed_dim = 512,
channels = 3,
dim_mults = (1, 2, 4, 8, 16),
cond_on_image_embeds = True,
cond_on_text_encodings = False
)
unet3 = Unet(
dim = 32,
image_embed_dim = 512,
channels = 3,
dim_mults = (1, 2, 4, 8, 16),
cond_on_image_embeds = True,
cond_on_text_encodings = False,
attend_at_middle = False
)
# decoder, which contains the unet(s) and clip
decoder = Decoder(
clip = clip,
vae = (vae1, vae2), # latent diffusion for unet1 (vae1) and unet2 (vae2), but not for the last unet3
unet = (unet1, unet2, unet3), # insert unets in order of low resolution to highest resolution (you can have as many stages as you want here)
image_sizes = (256, 512, 1024), # resolutions, 256 for first unet, 512 for second, 1024 for third
timesteps = 100,
cond_drop_prob = 0.2
).cuda()
# mock images (get a lot of this)
images = torch.randn(1, 3, 1024, 1024).cuda()
# feed images into decoder, specifying which unet you want to train
# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme
with decoder.one_unet_in_gpu(1):
loss = decoder(images, unet_number = 1)
loss.backward()
with decoder.one_unet_in_gpu(2):
loss = decoder(images, unet_number = 2)
loss.backward()
with decoder.one_unet_in_gpu(3):
loss = decoder(images, unet_number = 3)
loss.backward()
# do the above for many steps for both unets
# then it will learn to generate images based on the CLIP image embeddings
# chaining the unets from lowest resolution to highest resolution (thus cascading)
mock_image_embed = torch.randn(1, 512).cuda()
images = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)
```
## Training wrapper (wip)
Offer training wrappers
## CLI (wip)
```bash
$ dream 'sharing a sunset at the summit of mount everest with my dog'
@@ -393,9 +625,7 @@ $ dream 'sharing a sunset at the summit of mount everest with my dog'
Once built, images will be saved to the same directory the command is invoked
## Training wrapper (wip)
Offer training wrappers
<a href="https://github.com/lucidrains/big-sleep">template</a>
## Training CLI (wip)
@@ -412,11 +642,19 @@ Offer training wrappers
- [x] add efficient attention in unet
- [x] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning)
- [x] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
- [ ] build out latent diffusion architecture, with the vq-reg variant (vqgan-vae), make it completely optional
- [x] build out latent diffusion architecture, with the vq-reg variant (vqgan-vae), make it completely optional and compatible with cascading ddpms
- [x] for decoder, allow ability to customize objective (predict epsilon vs x0), in case latent diffusion does better with prediction of x0
- [x] use attention-based upsampling https://arxiv.org/abs/2112.11435
- [x] use inheritance just this once for sharing logic between decoder and prior network ddpms
- [x] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
- [x] abstract interface for CLIP adapter class, so other CLIPs can be brought in
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
- [ ] train on a toy task, offer in colab
- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
- [ ] bring in tools to train vqgan-vae
## Citations
@@ -448,23 +686,12 @@ Offer training wrappers
```bibtex
@inproceedings{Liu2022ACF,
title = {A ConvNet for the 2020s},
title = {A ConvNet for the 2020https://arxiv.org/abs/2112.11435s},
author = {Zhuang Liu and Hanzi Mao and Chaozheng Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie},
year = {2022}
}
```
```bibtex
@misc{zhang2019root,
title = {Root Mean Square Layer Normalization},
author = {Biao Zhang and Rico Sennrich},
year = {2019},
eprint = {1910.07467},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
```
```bibtex
@inproceedings{Tu2022MaxViTMV,
title = {MaxViT: Multi-Axis Vision Transformer},
@@ -473,4 +700,14 @@ Offer training wrappers
}
```
```bibtex
@article{Yu2021VectorquantizedIM,
title = {Vector-quantized Image Modeling with Improved VQGAN},
author = {Jiahui Yu and Xin Li and Jing Yu Koh and Han Zhang and Ruoming Pang and James Qin and Alexander Ku and Yuanzhong Xu and Jason Baldridge and Yonghui Wu},
journal = {ArXiv},
year = {2021},
volume = {abs/2110.04627}
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's <a href="https://arxiv.org/abs/2011.13456">paper</a>

View File

@@ -1,9 +1,51 @@
import click
import torch
import torchvision.transforms as T
from pathlib import Path
from dalle2_pytorch import DALLE2, Decoder, DiffusionPrior
def safeget(dictionary, keys, default = None):
return reduce(lambda d, key: d.get(key, default) if isinstance(d, dict) else default, keys.split('.'), dictionary)
def simple_slugify(text, max_length = 255):
return text.replace("-", "_").replace(",", "").replace(" ", "_").replace("|", "--").strip('-_')[:max_length]
def get_pkg_version():
from pkg_resources import get_distribution
return get_distribution('dalle2_pytorch').version
def main():
pass
@click.command()
@click.option('--model', default = './dalle2.pt', help = 'path to trained DALL-E2 model')
@click.option('--cond_scale', default = 2, help = 'conditioning scale (classifier free guidance) in decoder')
@click.argument('text')
def dream(text):
return 'not ready yet'
def dream(
model,
cond_scale,
text
):
model_path = Path(model)
full_model_path = str(model_path.resolve())
assert model_path.exists(), f'model not found at {full_model_path}'
loaded = torch.load(str(model_path))
version = safeget(loaded, 'version')
print(f'loading DALL-E2 from {full_model_path}, saved at version {version} - current package version is {get_pkg_version()}')
prior_init_params = safeget(loaded, 'init_params.prior')
decoder_init_params = safeget(loaded, 'init_params.decoder')
model_params = safeget(loaded, 'model_params')
prior = DiffusionPrior(**prior_init_params)
decoder = Decoder(**decoder_init_params)
dalle2 = DALLE2(prior, decoder)
dalle2.load_state_dict(model_params)
image = dalle2(text, cond_scale = cond_scale)
pil_image = T.ToPILImage()(image)
return pil_image.save(f'./{simple_slugify(text)}.png')

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,53 @@
import copy
import torch
from torch import nn
# exponential moving average wrapper
class EMA(nn.Module):
def __init__(
self,
model,
beta = 0.99,
ema_update_after_step = 1000,
ema_update_every = 10,
):
super().__init__()
self.beta = beta
self.online_model = model
self.ema_model = copy.deepcopy(model)
self.ema_update_after_step = ema_update_after_step # only start EMA after this step number, starting at 0
self.ema_update_every = ema_update_every
self.register_buffer('initted', torch.Tensor([False]))
self.register_buffer('step', torch.tensor([0.]))
def update(self):
self.step += 1
if self.step <= self.ema_update_after_step or (self.step % self.ema_update_every) != 0:
return
if not self.initted:
self.ema_model.state_dict(self.online_model.state_dict())
self.initted.data.copy_(torch.Tensor([True]))
self.update_moving_average(self.ema_model, self.online_model)
def update_moving_average(self, ma_model, current_model):
def calculate_ema(beta, old, new):
if not exists(old):
return new
return old * beta + (1 - beta) * new
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
old_weight, up_weight = ma_params.data, current_params.data
ma_params.data = calculate_ema(self.beta, old_weight, up_weight)
for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer)
ma_buffer.copy_(new_buffer_value)
def __call__(self, *args, **kwargs):
return self.ema_model(*args, **kwargs)

View File

@@ -12,6 +12,8 @@ from torch.autograd import grad as torch_grad
import torchvision
from einops import rearrange, reduce, repeat
from einops_exts import rearrange_many
from einops.layers.torch import Rearrange
# constants
@@ -144,6 +146,8 @@ class LayerNormChan(nn.Module):
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.gamma
# discriminator
class Discriminator(nn.Module):
def __init__(
self,
@@ -177,6 +181,8 @@ class Discriminator(nn.Module):
return self.to_logits(x)
# positional encoding
class ContinuousPositionBias(nn.Module):
""" from https://arxiv.org/abs/2111.09883 """
@@ -211,6 +217,84 @@ class ContinuousPositionBias(nn.Module):
bias = rearrange(rel_pos, 'i j h -> h i j')
return x + bias
# resnet encoder / decoder
class ResnetEncDec(nn.Module):
def __init__(
self,
dim,
*,
channels = 3,
layers = 4,
layer_mults = None,
num_resnet_blocks = 1,
resnet_groups = 16,
first_conv_kernel_size = 5,
use_attn = True,
attn_dim_head = 64,
attn_heads = 8,
attn_dropout = 0.,
):
super().__init__()
assert dim % resnet_groups == 0, f'dimension {dim} must be divisible by {resnet_groups} (groups for the groupnorm)'
self.layers = layers
self.encoders = MList([])
self.decoders = MList([])
layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'
layer_dims = [dim * mult for mult in layer_mults]
dims = (dim, *layer_dims)
self.encoded_dim = dims[-1]
dim_pairs = zip(dims[:-1], dims[1:])
append = lambda arr, t: arr.append(t)
prepend = lambda arr, t: arr.insert(0, t)
if not isinstance(num_resnet_blocks, tuple):
num_resnet_blocks = (*((0,) * (layers - 1)), num_resnet_blocks)
if not isinstance(use_attn, tuple):
use_attn = (*((False,) * (layers - 1)), use_attn)
assert len(num_resnet_blocks) == layers, 'number of resnet blocks config must be equal to number of layers'
assert len(use_attn) == layers
for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_resnet_blocks, use_attn):
append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
prepend(self.decoders, nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu()))
if layer_use_attn:
prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
for _ in range(layer_num_resnet_blocks):
append(self.encoders, ResBlock(dim_out, groups = resnet_groups))
prepend(self.decoders, GLUResBlock(dim_out, groups = resnet_groups))
if layer_use_attn:
append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
append(self.decoders, nn.Conv2d(dim, channels, 1))
def get_encoded_fmap_size(self, image_size):
return image_size // (2 ** self.layers)
def encode(self, x):
for enc in self.encoders:
x = enc(x)
return x
def decode(self, x):
for dec in self.decoders:
x = dec(x)
return x
class GLUResBlock(nn.Module):
def __init__(self, chan, groups = 16):
super().__init__()
@@ -243,6 +327,8 @@ class ResBlock(nn.Module):
def forward(self, x):
return self.net(x) + x
# vqgan attention layer
class VQGanAttention(nn.Module):
def __init__(
self,
@@ -287,6 +373,167 @@ class VQGanAttention(nn.Module):
return out + residual
# ViT encoder / decoder
class RearrangeImage(nn.Module):
def forward(self, x):
n = x.shape[1]
w = h = int(sqrt(n))
return rearrange(x, 'b (h w) ... -> b h w ...', h = h, w = w)
class Attention(nn.Module):
def __init__(
self,
dim,
*,
heads = 8,
dim_head = 32
):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.heads = heads
self.scale = dim_head ** -0.5
inner_dim = dim_head * heads
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Linear(inner_dim, dim)
def forward(self, x):
h = self.heads
x = self.norm(x)
q, k, v = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = h)
q = q * self.scale
sim = einsum('b h i d, b h j d -> b h i j', q, k)
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
attn = sim.softmax(dim = -1)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
def FeedForward(dim, mult = 4):
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, dim * mult, bias = False),
nn.GELU(),
nn.Linear(dim * mult, dim, bias = False)
)
class Transformer(nn.Module):
def __init__(
self,
dim,
*,
layers,
dim_head = 32,
heads = 8,
ff_mult = 4
):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(layers):
self.layers.append(nn.ModuleList([
Attention(dim = dim, dim_head = dim_head, heads = heads),
FeedForward(dim = dim, mult = ff_mult)
]))
self.norm = nn.LayerNorm(dim)
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return self.norm(x)
class ViTEncDec(nn.Module):
def __init__(
self,
dim,
channels = 3,
layers = 4,
patch_size = 8,
dim_head = 32,
heads = 8,
ff_mult = 4
):
super().__init__()
self.encoded_dim = dim
self.patch_size = patch_size
input_dim = channels * (patch_size ** 2)
self.encoder = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
nn.Linear(input_dim, dim),
Transformer(
dim = dim,
dim_head = dim_head,
heads = heads,
ff_mult = ff_mult,
layers = layers
),
RearrangeImage(),
Rearrange('b h w c -> b c h w')
)
self.decoder = nn.Sequential(
Rearrange('b c h w -> b (h w) c'),
Transformer(
dim = dim,
dim_head = dim_head,
heads = heads,
ff_mult = ff_mult,
layers = layers
),
nn.Sequential(
nn.Linear(dim, dim * 4, bias = False),
nn.Tanh(),
nn.Linear(dim * 4, input_dim, bias = False),
),
RearrangeImage(),
Rearrange('b h w (p1 p2 c) -> b c (h p1) (w p2)', p1 = patch_size, p2 = patch_size)
)
def get_encoded_fmap_size(self, image_size):
return image_size // self.patch_size
def encode(self, x):
return self.encoder(x)
def decode(self, x):
return self.decoder(x)
# main vqgan-vae classes
class NullVQGanVAE(nn.Module):
def __init__(
self,
*,
channels
):
super().__init__()
self.encoded_dim = channels
self.layers = 0
def get_encoded_fmap_size(self, size):
return size
def copy_for_eval(self):
return self
def encode(self, x):
return x
def decode(self, x):
return x
class VQGanVAE(nn.Module):
def __init__(
self,
@@ -294,80 +541,46 @@ class VQGanVAE(nn.Module):
dim,
image_size,
channels = 3,
num_layers = 4,
layer_mults = None,
layers = 4,
l2_recon_loss = False,
use_hinge_loss = True,
num_resnet_blocks = 1,
vgg = None,
vq_codebook_dim = 256,
vq_codebook_size = 512,
vq_decay = 0.8,
vq_commitment_weight = 1.,
vq_kmeans_init = True,
vq_use_cosine_sim = True,
use_attn = True,
attn_dim_head = 64,
attn_heads = 8,
resnet_groups = 16,
attn_dropout = 0.,
first_conv_kernel_size = 5,
use_vgg_and_gan = True,
vae_type = 'resnet',
discr_layers = 4,
**kwargs
):
super().__init__()
assert dim % resnet_groups == 0, f'dimension {dim} must be divisible by {resnet_groups} (groups for the groupnorm)'
vq_kwargs, kwargs = groupby_prefix_and_trim('vq_', kwargs)
encdec_kwargs, kwargs = groupby_prefix_and_trim('encdec_', kwargs)
self.image_size = image_size
self.channels = channels
self.num_layers = num_layers
self.fmap_size = image_size // (num_layers ** 2)
self.codebook_size = vq_codebook_size
self.encoders = MList([])
self.decoders = MList([])
if vae_type == 'resnet':
enc_dec_klass = ResnetEncDec
elif vae_type == 'vit':
enc_dec_klass = ViTEncDec
else:
raise ValueError(f'{vae_type} not valid')
layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(num_layers))))
assert len(layer_mults) == num_layers, 'layer multipliers must be equal to designated number of layers'
layer_dims = [dim * mult for mult in layer_mults]
dims = (dim, *layer_dims)
codebook_dim = layer_dims[-1]
dim_pairs = zip(dims[:-1], dims[1:])
append = lambda arr, t: arr.append(t)
prepend = lambda arr, t: arr.insert(0, t)
if not isinstance(num_resnet_blocks, tuple):
num_resnet_blocks = (*((0,) * (num_layers - 1)), num_resnet_blocks)
if not isinstance(use_attn, tuple):
use_attn = (*((False,) * (num_layers - 1)), use_attn)
assert len(num_resnet_blocks) == num_layers, 'number of resnet blocks config must be equal to number of layers'
assert len(use_attn) == num_layers
for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(num_layers), dim_pairs, num_resnet_blocks, use_attn):
append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
prepend(self.decoders, nn.Sequential(nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = False), nn.Conv2d(dim_out, dim_in, 3, padding = 1), leaky_relu()))
if layer_use_attn:
prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
for _ in range(layer_num_resnet_blocks):
append(self.encoders, ResBlock(dim_out, groups = resnet_groups))
prepend(self.decoders, GLUResBlock(dim_out, groups = resnet_groups))
if layer_use_attn:
append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
append(self.decoders, nn.Conv2d(dim, channels, 1))
self.enc_dec = enc_dec_klass(
dim = dim,
channels = channels,
layers = layers,
**encdec_kwargs
)
self.vq = VQ(
dim = codebook_dim,
dim = self.enc_dec.encoded_dim,
codebook_dim = vq_codebook_dim,
codebook_size = vq_codebook_size,
decay = vq_decay,
commitment_weight = vq_commitment_weight,
@@ -400,11 +613,22 @@ class VQGanVAE(nn.Module):
# gan related losses
layer_mults = list(map(lambda t: 2 ** t, range(discr_layers)))
layer_dims = [dim * mult for mult in layer_mults]
dims = (dim, *layer_dims)
self.discr = Discriminator(dims = dims, channels = channels)
self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss
@property
def encoded_dim(self):
return self.enc_dec.encoded_dim
def get_encoded_fmap_size(self, image_size):
return self.enc_dec.get_encoded_fmap_size(image_size)
def copy_for_eval(self):
device = next(self.parameters()).device
vae_copy = copy.deepcopy(self.cpu())
@@ -429,16 +653,16 @@ class VQGanVAE(nn.Module):
return self.vq.codebook
def encode(self, fmap):
for enc in self.encoders:
fmap = enc(fmap)
fmap = self.enc_dec.encode(fmap)
return fmap
def decode(self, fmap):
def decode(self, fmap, return_indices_and_loss = False):
fmap, indices, commit_loss = self.vq(fmap)
for dec in self.decoders:
fmap = dec(fmap)
fmap = self.enc_dec.decode(fmap)
if not return_indices_and_loss:
return fmap
return fmap, indices, commit_loss
@@ -447,7 +671,8 @@ class VQGanVAE(nn.Module):
img,
return_loss = False,
return_discr_loss = False,
return_recons = False
return_recons = False,
add_gradient_penalty = True
):
batch, channels, height, width, device = *img.shape, img.device
assert height == self.image_size and width == self.image_size, 'height and width of input image must be equal to {self.image_size}'
@@ -455,7 +680,7 @@ class VQGanVAE(nn.Module):
fmap = self.encode(img)
fmap, indices, commit_loss = self.decode(fmap)
fmap, indices, commit_loss = self.decode(fmap, return_indices_and_loss = True)
if not return_loss and not return_discr_loss:
return fmap
@@ -472,11 +697,11 @@ class VQGanVAE(nn.Module):
fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))
gp = gradient_penalty(img, img_discr_logits)
discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits)
loss = discr_loss + gp
if add_gradient_penalty:
gp = gradient_penalty(img, img_discr_logits)
loss = discr_loss + gp
if return_recons:
return loss, fmap

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.36',
version = '0.0.65',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',