mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
75 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fb8a66a2de | ||
|
|
579d4b42dd | ||
|
|
473808850a | ||
|
|
d5318aef4f | ||
|
|
f82917e1fd | ||
|
|
05b74be69a | ||
|
|
a8b5d5d753 | ||
|
|
976ef7f87c | ||
|
|
fd175bcc0e | ||
|
|
76b32f18b3 | ||
|
|
f2d5b87677 | ||
|
|
461347c171 | ||
|
|
46cef31c86 | ||
|
|
59b1a77d4d | ||
|
|
7f338319fd | ||
|
|
2c6c91829d | ||
|
|
ad17c69ab6 | ||
|
|
0b4ec34efb | ||
|
|
f027b82e38 | ||
|
|
8cc9016cb0 | ||
|
|
1d8f37befe | ||
|
|
faebf4c8b8 | ||
|
|
b8e8d3c164 | ||
|
|
8e2416b49b | ||
|
|
f37c26e856 | ||
|
|
27a33e1b20 | ||
|
|
6f941a219a | ||
|
|
ddde8ca1bf | ||
|
|
c26b77ad20 | ||
|
|
c5b4aab8e5 | ||
|
|
a35c309b5f | ||
|
|
55bdcb98b9 | ||
|
|
82328f16cd | ||
|
|
6fee4fce6e | ||
|
|
a54e309269 | ||
|
|
c6bfd7fdc8 | ||
|
|
960a79857b | ||
|
|
7214df472d | ||
|
|
00ae50999b | ||
|
|
6cddefad26 | ||
|
|
0332eaa6ff | ||
|
|
1cce4225eb | ||
|
|
5ab0700bab | ||
|
|
b0f2fbaa95 | ||
|
|
51361c2d15 | ||
|
|
42d6e47387 | ||
|
|
1e939153fb | ||
|
|
1abeb8918e | ||
|
|
b423855483 | ||
|
|
c400d8758c | ||
|
|
bece206699 | ||
|
|
5b4ee09625 | ||
|
|
6e27f617f1 | ||
|
|
9f55c24db6 | ||
|
|
69e822b7f8 | ||
|
|
23c401a5d5 | ||
|
|
68e9883f59 | ||
|
|
95b018374a | ||
|
|
8b5c2385b0 | ||
|
|
f2c52d8239 | ||
|
|
97e951221b | ||
|
|
e1b0c140f1 | ||
|
|
5989569a44 | ||
|
|
82464d7bd3 | ||
|
|
7fb3f695d5 | ||
|
|
7e93b9d3c8 | ||
|
|
4c827ba94f | ||
|
|
cb3923a90f | ||
|
|
cc30676a3f | ||
|
|
c7fb327618 | ||
|
|
14ddbc159c | ||
|
|
0692f1699f | ||
|
|
26c4534bc3 | ||
|
|
5e06cde4cb | ||
|
|
a1a8a78f21 |
310
README.md
310
README.md
@@ -1,20 +1,18 @@
|
||||
<img src="./dalle2.png" width="450px"></img>
|
||||
|
||||
## DALL-E 2 - Pytorch (wip)
|
||||
## DALL-E 2 - Pytorch
|
||||
|
||||
Implementation of <a href="https://openai.com/dall-e-2/">DALL-E 2</a>, OpenAI's updated text-to-image synthesis neural network, in Pytorch. <a href="https://youtu.be/RJwPN4qNi_Y?t=555">Yannic Kilcher summary</a>
|
||||
Implementation of <a href="https://openai.com/dall-e-2/">DALL-E 2</a>, OpenAI's updated text-to-image synthesis neural network, in Pytorch.
|
||||
|
||||
<a href="https://youtu.be/RJwPN4qNi_Y?t=555">Yannic Kilcher summary</a> | <a href="https://www.youtube.com/watch?v=F1X4fHzF4mQ">AssemblyAI explainer</a>
|
||||
|
||||
The main novelty seems to be an extra layer of indirection with the prior network (whether it is an autoregressive transformer or a diffusion network), which predicts an image embedding based on the text embedding from CLIP. Specifically, this repository will only build out the diffusion prior network, as it is the best performing variant (but which incidentally involves a causal transformer as the denoising network 😂)
|
||||
|
||||
This model is SOTA for text-to-image for now.
|
||||
|
||||
It may also explore an extension of using <a href="https://huggingface.co/spaces/multimodalart/latentdiffusion">latent diffusion</a> in the decoder from Rombach et al.
|
||||
|
||||
Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication
|
||||
|
||||
Do let me know if anyone is interested in a Jax version https://github.com/lucidrains/DALLE2-pytorch/discussions/8
|
||||
|
||||
For all of you emailing me (there is a lot), the best way to contribute is through pull requests. Everything is open sourced after all. All my thoughts are public. This is your moment to participate.
|
||||
There was enough interest for a Jax version. It will be completed after the Pytorch version shows signs of life on my toy tasks. <a href="https://github.com/lucidrains/dalle2-jax">Placeholder repository</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
|
||||
|
||||
## Install
|
||||
|
||||
@@ -22,19 +20,11 @@ For all of you emailing me (there is a lot), the best way to contribute is throu
|
||||
$ pip install dalle2-pytorch
|
||||
```
|
||||
|
||||
## CLI Usage (work in progress)
|
||||
|
||||
```bash
|
||||
$ dream 'sharing a sunset at the summit of mount everest with my dog'
|
||||
```
|
||||
|
||||
Once built, images will be saved to the same directory the command is invoked
|
||||
|
||||
## Training (for deep learning practitioners)
|
||||
## Usage
|
||||
|
||||
To train DALLE-2 is a 3 step process, with the training of CLIP being the most important
|
||||
|
||||
To train CLIP, you can either use `x-clip` package, or join the LAION discord, where a lot of replication efforts are already underway.
|
||||
To train CLIP, you can either use <a href="https://github.com/lucidrains/x-clip">x-clip</a> package, or join the LAION discord, where a lot of replication efforts are already <a href="https://github.com/mlfoundations/open_clip">underway</a>.
|
||||
|
||||
This repository will demonstrate integration with `x-clip` for starters
|
||||
|
||||
@@ -109,7 +99,7 @@ clip = CLIP(
|
||||
unet = Unet(
|
||||
dim = 128,
|
||||
image_embed_dim = 512,
|
||||
time_dim = 128,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
dim_mults=(1, 2, 4, 8)
|
||||
).cuda()
|
||||
@@ -117,7 +107,7 @@ unet = Unet(
|
||||
# decoder, which contains the unet and clip
|
||||
|
||||
decoder = Decoder(
|
||||
net = unet,
|
||||
unet = unet,
|
||||
clip = clip,
|
||||
timesteps = 100,
|
||||
cond_drop_prob = 0.2
|
||||
@@ -136,12 +126,14 @@ loss.backward()
|
||||
# then it will learn to generate images based on the CLIP image embeddings
|
||||
```
|
||||
|
||||
Finally, the main contribution of the paper. The repository offers the diffusion prior network. It takes the CLIP text embeddings and tries to generate the CLIP image embeddings. Again, you will need the trained CLIP fron the first step
|
||||
Finally, the main contribution of the paper. The repository offers the diffusion prior network. It takes the CLIP text embeddings and tries to generate the CLIP image embeddings. Again, you will need the trained CLIP from the first step
|
||||
|
||||
```python
|
||||
import torch
|
||||
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, CLIP
|
||||
|
||||
# get trained CLIP from step one
|
||||
|
||||
clip = CLIP(
|
||||
dim_text = 512,
|
||||
dim_image = 512,
|
||||
@@ -160,7 +152,6 @@ clip = CLIP(
|
||||
|
||||
prior_network = DiffusionPriorNetwork(
|
||||
dim = 512,
|
||||
num_timesteps = 100,
|
||||
depth = 6,
|
||||
dim_head = 64,
|
||||
heads = 8
|
||||
@@ -189,7 +180,82 @@ loss.backward()
|
||||
# now the diffusion prior can generate image embeddings from the text embeddings
|
||||
```
|
||||
|
||||
Finally, to generate the DALL-E2 images from text. Insert the trained `DiffusionPrior` as well as the `Decoder` (which both contains `CLIP`, a unet, and a causal transformer)
|
||||
In the paper, they actually used a <a href="https://cascaded-diffusion.github.io/">recently discovered technique</a>, from <a href="http://www.jonathanho.me/">Jonathan Ho</a> himself (original author of DDPMs, the core technique used in DALL-E v2) for high resolution image synthesis.
|
||||
|
||||
This can easily be used within this framework as so
|
||||
|
||||
```python
|
||||
import torch
|
||||
from dalle2_pytorch import Unet, Decoder, CLIP
|
||||
|
||||
# trained clip from step 1
|
||||
|
||||
clip = CLIP(
|
||||
dim_text = 512,
|
||||
dim_image = 512,
|
||||
dim_latent = 512,
|
||||
num_text_tokens = 49408,
|
||||
text_enc_depth = 6,
|
||||
text_seq_len = 256,
|
||||
text_heads = 8,
|
||||
visual_enc_depth = 6,
|
||||
visual_image_size = 256,
|
||||
visual_patch_size = 32,
|
||||
visual_heads = 8
|
||||
).cuda()
|
||||
|
||||
# 2 unets for the decoder (a la cascading DDPM)
|
||||
|
||||
unet1 = Unet(
|
||||
dim = 32,
|
||||
image_embed_dim = 512,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
dim_mults = (1, 2, 4, 8)
|
||||
).cuda()
|
||||
|
||||
unet2 = Unet(
|
||||
dim = 32,
|
||||
image_embed_dim = 512,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
dim_mults = (1, 2, 4, 8, 16)
|
||||
).cuda()
|
||||
|
||||
# decoder, which contains the unet(s) and clip
|
||||
|
||||
decoder = Decoder(
|
||||
clip = clip,
|
||||
unet = (unet1, unet2), # insert both unets in order of low resolution to highest resolution (you can have as many stages as you want here)
|
||||
image_sizes = (256, 512), # resolutions, 256 for first unet, 512 for second. these must be unique and in ascending order (matches with the unets passed in)
|
||||
timesteps = 1000,
|
||||
cond_drop_prob = 0.2
|
||||
).cuda()
|
||||
|
||||
# mock images (get a lot of this)
|
||||
|
||||
images = torch.randn(4, 3, 512, 512).cuda()
|
||||
|
||||
# feed images into decoder, specifying which unet you want to train
|
||||
# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme
|
||||
|
||||
loss = decoder(images, unet_number = 1)
|
||||
loss.backward()
|
||||
|
||||
loss = decoder(images, unet_number = 2)
|
||||
loss.backward()
|
||||
|
||||
# do the above for many steps for both unets
|
||||
|
||||
# then it will learn to generate images based on the CLIP image embeddings
|
||||
|
||||
# chaining the unets from lowest resolution to highest resolution (thus cascading)
|
||||
|
||||
mock_image_embed = torch.randn(1, 512).cuda()
|
||||
images = decoder.sample(mock_image_embed) # (1, 3, 512, 512)
|
||||
```
|
||||
|
||||
Finally, to generate the DALL-E2 images from text. Insert the trained `DiffusionPrior` as well as the `Decoder` (which wraps `CLIP`, the causal transformer, and unet(s))
|
||||
|
||||
```python
|
||||
from dalle2_pytorch import DALLE2
|
||||
@@ -199,7 +265,7 @@ dalle2 = DALLE2(
|
||||
decoder = decoder
|
||||
)
|
||||
|
||||
# send the text as a string if you want to use the simple tokenizer from DALL-E1
|
||||
# send the text as a string if you want to use the simple tokenizer from DALLE v1
|
||||
# or you can do it as token ids, if you have your own tokenizer
|
||||
|
||||
texts = ['glistening morning dew on a flower petal']
|
||||
@@ -212,10 +278,7 @@ Let's see the whole script below
|
||||
|
||||
```python
|
||||
import torch
|
||||
from dalle2_pytorch.dalle2_pytorch import DALLE2
|
||||
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, CLIP
|
||||
|
||||
import torch
|
||||
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, CLIP
|
||||
|
||||
clip = CLIP(
|
||||
dim_text = 512,
|
||||
@@ -252,7 +315,6 @@ loss.backward()
|
||||
|
||||
prior_network = DiffusionPriorNetwork(
|
||||
dim = 512,
|
||||
num_timesteps = 100,
|
||||
depth = 6,
|
||||
dim_head = 64,
|
||||
heads = 8
|
||||
@@ -272,23 +334,33 @@ loss.backward()
|
||||
|
||||
# decoder (with unet)
|
||||
|
||||
unet = Unet(
|
||||
unet1 = Unet(
|
||||
dim = 128,
|
||||
image_embed_dim = 512,
|
||||
time_dim = 128,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
dim_mults=(1, 2, 4, 8)
|
||||
).cuda()
|
||||
|
||||
unet2 = Unet(
|
||||
dim = 16,
|
||||
image_embed_dim = 512,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
dim_mults = (1, 2, 4, 8, 16)
|
||||
).cuda()
|
||||
|
||||
decoder = Decoder(
|
||||
net = unet,
|
||||
unet = (unet1, unet2),
|
||||
image_sizes = (128, 256),
|
||||
clip = clip,
|
||||
timesteps = 100,
|
||||
cond_drop_prob = 0.2
|
||||
).cuda()
|
||||
|
||||
loss = decoder(images)
|
||||
loss.backward()
|
||||
for unet_number in (1, 2):
|
||||
loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
|
||||
loss.backward()
|
||||
|
||||
# do above for many steps
|
||||
|
||||
@@ -297,13 +369,150 @@ dalle2 = DALLE2(
|
||||
decoder = decoder
|
||||
)
|
||||
|
||||
images = dalle2(['cute puppy chasing after a squirrel'])
|
||||
images = dalle2(
|
||||
['cute puppy chasing after a squirrel'],
|
||||
cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)
|
||||
)
|
||||
|
||||
# save your image
|
||||
# save your image (in this example, of size 256x256)
|
||||
```
|
||||
|
||||
Everything in this readme should run without error
|
||||
|
||||
You can also train the decoder on images of greater than the size (say 512x512) at which CLIP was trained (256x256). The images will be resized to CLIP image resolution for the image embeddings
|
||||
|
||||
For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.
|
||||
|
||||
## Experimental
|
||||
|
||||
### DALL-E2 with Latent Diffusion
|
||||
|
||||
This repository decides to take the next step and offer DALL-E2 combined with <a href="https://huggingface.co/spaces/multimodalart/latentdiffusion">latent diffusion</a>, from Rombach et al.
|
||||
|
||||
You can use it as follows. Latent diffusion can be limited to just the first U-Net in the cascade, or to any number you wish.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from dalle2_pytorch import Unet, Decoder, CLIP, VQGanVAE
|
||||
|
||||
# trained clip from step 1
|
||||
|
||||
clip = CLIP(
|
||||
dim_text = 512,
|
||||
dim_image = 512,
|
||||
dim_latent = 512,
|
||||
num_text_tokens = 49408,
|
||||
text_enc_depth = 1,
|
||||
text_seq_len = 256,
|
||||
text_heads = 8,
|
||||
visual_enc_depth = 1,
|
||||
visual_image_size = 256,
|
||||
visual_patch_size = 32,
|
||||
visual_heads = 8
|
||||
)
|
||||
|
||||
# 3 unets for the decoder (a la cascading DDPM)
|
||||
|
||||
# first two unets are doing latent diffusion
|
||||
# vqgan-vae must be trained before hand
|
||||
|
||||
vae1 = VQGanVAE(
|
||||
dim = 32,
|
||||
image_size = 256,
|
||||
layers = 3,
|
||||
layer_mults = (1, 2, 4)
|
||||
)
|
||||
|
||||
vae2 = VQGanVAE(
|
||||
dim = 32,
|
||||
image_size = 512,
|
||||
layers = 3,
|
||||
layer_mults = (1, 2, 4)
|
||||
)
|
||||
|
||||
unet1 = Unet(
|
||||
dim = 32,
|
||||
image_embed_dim = 512,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
sparse_attn = True,
|
||||
sparse_attn_window = 2,
|
||||
dim_mults = (1, 2, 4, 8)
|
||||
)
|
||||
|
||||
unet2 = Unet(
|
||||
dim = 32,
|
||||
image_embed_dim = 512,
|
||||
channels = 3,
|
||||
dim_mults = (1, 2, 4, 8, 16),
|
||||
cond_on_image_embeds = True,
|
||||
cond_on_text_encodings = False
|
||||
)
|
||||
|
||||
unet3 = Unet(
|
||||
dim = 32,
|
||||
image_embed_dim = 512,
|
||||
channels = 3,
|
||||
dim_mults = (1, 2, 4, 8, 16),
|
||||
cond_on_image_embeds = True,
|
||||
cond_on_text_encodings = False,
|
||||
attend_at_middle = False
|
||||
)
|
||||
|
||||
# decoder, which contains the unet(s) and clip
|
||||
|
||||
decoder = Decoder(
|
||||
clip = clip,
|
||||
vae = (vae1, vae2), # latent diffusion for unet1 (vae1) and unet2 (vae2), but not for the last unet3
|
||||
unet = (unet1, unet2, unet3), # insert unets in order of low resolution to highest resolution (you can have as many stages as you want here)
|
||||
image_sizes = (256, 512, 1024), # resolutions, 256 for first unet, 512 for second, 1024 for third
|
||||
timesteps = 100,
|
||||
cond_drop_prob = 0.2
|
||||
).cuda()
|
||||
|
||||
# mock images (get a lot of this)
|
||||
|
||||
images = torch.randn(1, 3, 1024, 1024).cuda()
|
||||
|
||||
# feed images into decoder, specifying which unet you want to train
|
||||
# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme
|
||||
|
||||
with decoder.one_unet_in_gpu(1):
|
||||
loss = decoder(images, unet_number = 1)
|
||||
loss.backward()
|
||||
|
||||
with decoder.one_unet_in_gpu(2):
|
||||
loss = decoder(images, unet_number = 2)
|
||||
loss.backward()
|
||||
|
||||
with decoder.one_unet_in_gpu(3):
|
||||
loss = decoder(images, unet_number = 3)
|
||||
loss.backward()
|
||||
|
||||
# do the above for many steps for both unets
|
||||
|
||||
# then it will learn to generate images based on the CLIP image embeddings
|
||||
|
||||
# chaining the unets from lowest resolution to highest resolution (thus cascading)
|
||||
|
||||
mock_image_embed = torch.randn(1, 512).cuda()
|
||||
images = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)
|
||||
```
|
||||
|
||||
## Training wrapper (wip)
|
||||
|
||||
Offer training wrappers
|
||||
|
||||
## CLI (wip)
|
||||
|
||||
```bash
|
||||
$ dream 'sharing a sunset at the summit of mount everest with my dog'
|
||||
```
|
||||
|
||||
Once built, images will be saved to the same directory the command is invoked
|
||||
|
||||
<a href="https://github.com/lucidrains/big-sleep">template</a>
|
||||
|
||||
## Training CLI (wip)
|
||||
|
||||
<a href="https://github.com/lucidrains/stylegan2-pytorch">template</a>
|
||||
@@ -313,11 +522,21 @@ Everything in this readme should run without error
|
||||
- [x] finish off gaussian diffusion class for latent embedding - allow for prediction of epsilon
|
||||
- [x] add what was proposed in the paper, where DDPM objective for image latent embedding predicts x0 directly (reread vq-diffusion paper and get caught up on that line of work)
|
||||
- [x] make sure it works end to end to produce an output tensor, taking a single gradient step
|
||||
- [ ] augment unet so that it can also be conditioned on text encodings (although in paper they hinted this didn't make much a difference)
|
||||
- [ ] look into Jonathan Ho's cascading DDPM for the decoder, as that seems to be what they are using. get caught up on DDPM literature
|
||||
- [ ] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)
|
||||
- [x] augment unet so that it can also be conditioned on text encodings (although in paper they hinted this didn't make much a difference)
|
||||
- [x] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)
|
||||
- [x] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions
|
||||
- [x] add efficient attention in unet
|
||||
- [x] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning)
|
||||
- [x] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
|
||||
- [x] build out latent diffusion architecture, with the vq-reg variant (vqgan-vae), make it completely optional and compatible with cascading ddpms
|
||||
- [x] for decoder, allow ability to customize objective (predict epsilon vs x0), in case latent diffusion does better with prediction of x0
|
||||
- [ ] spend one day cleaning up tech debt in decoder
|
||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
|
||||
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
|
||||
- [ ] train on a toy task, offer in colab
|
||||
- [ ] add attention to unet - apply some personal tricks with efficient attention
|
||||
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
|
||||
- [ ] bring in tools to train vqgan-vae
|
||||
- [ ] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
|
||||
|
||||
## Citations
|
||||
|
||||
@@ -356,12 +575,11 @@ Everything in this readme should run without error
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{zhang2019root,
|
||||
title = {Root Mean Square Layer Normalization},
|
||||
author = {Biao Zhang and Rico Sennrich},
|
||||
year = {2019},
|
||||
eprint = {1910.07467},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.LG}
|
||||
@inproceedings{Tu2022MaxViTMV,
|
||||
title = {MaxViT: Multi-Axis Vision Transformer},
|
||||
author = {Zhe-Wei Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
|
||||
year = {2022}
|
||||
}
|
||||
```
|
||||
|
||||
*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's <a href="https://arxiv.org/abs/2011.13456">paper</a>
|
||||
|
||||
@@ -1,2 +1,4 @@
|
||||
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
|
||||
|
||||
from dalle2_pytorch.vqgan_vae import VQGanVAE
|
||||
from x_clip import CLIP
|
||||
|
||||
@@ -1,9 +1,51 @@
|
||||
import click
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from pathlib import Path
|
||||
|
||||
from dalle2_pytorch import DALLE2, Decoder, DiffusionPrior
|
||||
|
||||
def safeget(dictionary, keys, default = None):
|
||||
return reduce(lambda d, key: d.get(key, default) if isinstance(d, dict) else default, keys.split('.'), dictionary)
|
||||
|
||||
def simple_slugify(text, max_length = 255):
|
||||
return text.replace("-", "_").replace(",", "").replace(" ", "_").replace("|", "--").strip('-_')[:max_length]
|
||||
|
||||
def get_pkg_version():
|
||||
from pkg_resources import get_distribution
|
||||
return get_distribution('dalle2_pytorch').version
|
||||
|
||||
def main():
|
||||
pass
|
||||
|
||||
@click.command()
|
||||
@click.option('--model', default = './dalle2.pt', help = 'path to trained DALL-E2 model')
|
||||
@click.option('--cond_scale', default = 2, help = 'conditioning scale (classifier free guidance) in decoder')
|
||||
@click.argument('text')
|
||||
def dream(text):
|
||||
return image
|
||||
def dream(
|
||||
model,
|
||||
cond_scale,
|
||||
text
|
||||
):
|
||||
model_path = Path(model)
|
||||
full_model_path = str(model_path.resolve())
|
||||
assert model_path.exists(), f'model not found at {full_model_path}'
|
||||
loaded = torch.load(str(model_path))
|
||||
|
||||
version = safeget(loaded, 'version')
|
||||
print(f'loading DALL-E2 from {full_model_path}, saved at version {version} - current package version is {get_pkg_version()}')
|
||||
|
||||
prior_init_params = safeget(loaded, 'init_params.prior')
|
||||
decoder_init_params = safeget(loaded, 'init_params.decoder')
|
||||
model_params = safeget(loaded, 'model_params')
|
||||
|
||||
prior = DiffusionPrior(**prior_init_params)
|
||||
decoder = Decoder(**decoder_init_params)
|
||||
|
||||
dalle2 = DALLE2(prior, decoder)
|
||||
dalle2.load_state_dict(model_params)
|
||||
|
||||
image = dalle2(text, cond_scale = cond_scale)
|
||||
|
||||
pil_image = T.ToPILImage()(image)
|
||||
return pil_image.save(f'./{simple_slugify(text)}.png')
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
563
dalle2_pytorch/vqgan_vae.py
Normal file
563
dalle2_pytorch/vqgan_vae.py
Normal file
@@ -0,0 +1,563 @@
|
||||
import copy
|
||||
import math
|
||||
from math import sqrt
|
||||
from functools import partial, wraps
|
||||
|
||||
from vector_quantize_pytorch import VectorQuantize as VQ
|
||||
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import grad as torch_grad
|
||||
import torchvision
|
||||
|
||||
from einops import rearrange, reduce, repeat
|
||||
|
||||
# constants
|
||||
|
||||
MList = nn.ModuleList
|
||||
|
||||
# helper functions
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
# decorators
|
||||
|
||||
def eval_decorator(fn):
|
||||
def inner(model, *args, **kwargs):
|
||||
was_training = model.training
|
||||
model.eval()
|
||||
out = fn(model, *args, **kwargs)
|
||||
model.train(was_training)
|
||||
return out
|
||||
return inner
|
||||
|
||||
def remove_vgg(fn):
|
||||
@wraps(fn)
|
||||
def inner(self, *args, **kwargs):
|
||||
has_vgg = hasattr(self, 'vgg')
|
||||
if has_vgg:
|
||||
vgg = self.vgg
|
||||
delattr(self, 'vgg')
|
||||
|
||||
out = fn(self, *args, **kwargs)
|
||||
|
||||
if has_vgg:
|
||||
self.vgg = vgg
|
||||
|
||||
return out
|
||||
return inner
|
||||
|
||||
# keyword argument helpers
|
||||
|
||||
def pick_and_pop(keys, d):
|
||||
values = list(map(lambda key: d.pop(key), keys))
|
||||
return dict(zip(keys, values))
|
||||
|
||||
def group_dict_by_key(cond, d):
|
||||
return_val = [dict(),dict()]
|
||||
for key in d.keys():
|
||||
match = bool(cond(key))
|
||||
ind = int(not match)
|
||||
return_val[ind][key] = d[key]
|
||||
return (*return_val,)
|
||||
|
||||
def string_begins_with(prefix, str):
|
||||
return str.startswith(prefix)
|
||||
|
||||
def group_by_key_prefix(prefix, d):
|
||||
return group_dict_by_key(partial(string_begins_with, prefix), d)
|
||||
|
||||
def groupby_prefix_and_trim(prefix, d):
|
||||
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
|
||||
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
|
||||
return kwargs_without_prefix, kwargs
|
||||
|
||||
# tensor helper functions
|
||||
|
||||
def log(t, eps = 1e-10):
|
||||
return torch.log(t + eps)
|
||||
|
||||
def gradient_penalty(images, output, weight = 10):
|
||||
batch_size = images.shape[0]
|
||||
gradients = torch_grad(outputs = output, inputs = images,
|
||||
grad_outputs = torch.ones(output.size(), device = images.device),
|
||||
create_graph = True, retain_graph = True, only_inputs = True)[0]
|
||||
|
||||
gradients = rearrange(gradients, 'b ... -> b (...)')
|
||||
return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean()
|
||||
|
||||
def l2norm(t):
|
||||
return F.normalize(t, dim = -1)
|
||||
|
||||
def leaky_relu(p = 0.1):
|
||||
return nn.LeakyReLU(0.1)
|
||||
|
||||
def stable_softmax(t, dim = -1, alpha = 32 ** 2):
|
||||
t = t / alpha
|
||||
t = t - torch.amax(t, dim = dim, keepdim = True).detach()
|
||||
return (t * alpha).softmax(dim = dim)
|
||||
|
||||
def safe_div(numer, denom, eps = 1e-8):
|
||||
return numer / (denom + eps)
|
||||
|
||||
# gan losses
|
||||
|
||||
def hinge_discr_loss(fake, real):
|
||||
return (F.relu(1 + fake) + F.relu(1 - real)).mean()
|
||||
|
||||
def hinge_gen_loss(fake):
|
||||
return -fake.mean()
|
||||
|
||||
def bce_discr_loss(fake, real):
|
||||
return (-log(1 - torch.sigmoid(fake)) - log(torch.sigmoid(real))).mean()
|
||||
|
||||
def bce_gen_loss(fake):
|
||||
return -log(torch.sigmoid(fake)).mean()
|
||||
|
||||
def grad_layer_wrt_loss(loss, layer):
|
||||
return torch_grad(
|
||||
outputs = loss,
|
||||
inputs = layer,
|
||||
grad_outputs = torch.ones_like(loss),
|
||||
retain_graph = True
|
||||
)[0].detach()
|
||||
|
||||
# vqgan vae
|
||||
|
||||
class LayerNormChan(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
eps = 1e-5
|
||||
):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (var + self.eps).sqrt() * self.gamma
|
||||
|
||||
class Discriminator(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dims,
|
||||
channels = 3,
|
||||
groups = 16,
|
||||
init_kernel_size = 5
|
||||
):
|
||||
super().__init__()
|
||||
dim_pairs = zip(dims[:-1], dims[1:])
|
||||
|
||||
self.layers = MList([nn.Sequential(nn.Conv2d(channels, dims[0], init_kernel_size, padding = init_kernel_size // 2), leaky_relu())])
|
||||
|
||||
for dim_in, dim_out in dim_pairs:
|
||||
self.layers.append(nn.Sequential(
|
||||
nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1),
|
||||
nn.GroupNorm(groups, dim_out),
|
||||
leaky_relu()
|
||||
))
|
||||
|
||||
dim = dims[-1]
|
||||
self.to_logits = nn.Sequential( # return 5 x 5, for PatchGAN-esque training
|
||||
nn.Conv2d(dim, dim, 1),
|
||||
leaky_relu(),
|
||||
nn.Conv2d(dim, 1, 4)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for net in self.layers:
|
||||
x = net(x)
|
||||
|
||||
return self.to_logits(x)
|
||||
|
||||
class ContinuousPositionBias(nn.Module):
|
||||
""" from https://arxiv.org/abs/2111.09883 """
|
||||
|
||||
def __init__(self, *, dim, heads, layers = 2):
|
||||
super().__init__()
|
||||
self.net = MList([])
|
||||
self.net.append(nn.Sequential(nn.Linear(2, dim), leaky_relu()))
|
||||
|
||||
for _ in range(layers - 1):
|
||||
self.net.append(nn.Sequential(nn.Linear(dim, dim), leaky_relu()))
|
||||
|
||||
self.net.append(nn.Linear(dim, heads))
|
||||
self.register_buffer('rel_pos', None, persistent = False)
|
||||
|
||||
def forward(self, x):
|
||||
n, device = x.shape[-1], x.device
|
||||
fmap_size = int(sqrt(n))
|
||||
|
||||
if not exists(self.rel_pos):
|
||||
pos = torch.arange(fmap_size, device = device)
|
||||
grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
|
||||
grid = rearrange(grid, 'c i j -> (i j) c')
|
||||
rel_pos = rearrange(grid, 'i c -> i 1 c') - rearrange(grid, 'j c -> 1 j c')
|
||||
rel_pos = torch.sign(rel_pos) * torch.log(rel_pos.abs() + 1)
|
||||
self.register_buffer('rel_pos', rel_pos, persistent = False)
|
||||
|
||||
rel_pos = self.rel_pos.float()
|
||||
|
||||
for layer in self.net:
|
||||
rel_pos = layer(rel_pos)
|
||||
|
||||
bias = rearrange(rel_pos, 'i j h -> h i j')
|
||||
return x + bias
|
||||
|
||||
class GLUResBlock(nn.Module):
|
||||
def __init__(self, chan, groups = 16):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(chan, chan * 2, 3, padding = 1),
|
||||
nn.GLU(dim = 1),
|
||||
nn.GroupNorm(groups, chan),
|
||||
nn.Conv2d(chan, chan * 2, 3, padding = 1),
|
||||
nn.GLU(dim = 1),
|
||||
nn.GroupNorm(groups, chan),
|
||||
nn.Conv2d(chan, chan, 1)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x) + x
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(self, chan, groups = 16):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(chan, chan, 3, padding = 1),
|
||||
nn.GroupNorm(groups, chan),
|
||||
leaky_relu(),
|
||||
nn.Conv2d(chan, chan, 3, padding = 1),
|
||||
nn.GroupNorm(groups, chan),
|
||||
leaky_relu(),
|
||||
nn.Conv2d(chan, chan, 1)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x) + x
|
||||
|
||||
class VQGanAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dim,
|
||||
dim_head = 64,
|
||||
heads = 8,
|
||||
dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
inner_dim = heads * dim_head
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.pre_norm = LayerNormChan(dim)
|
||||
|
||||
self.cpb = ContinuousPositionBias(dim = dim // 4, heads = heads)
|
||||
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
|
||||
self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)
|
||||
|
||||
def forward(self, x):
|
||||
h = self.heads
|
||||
height, width, residual = *x.shape[-2:], x.clone()
|
||||
|
||||
x = self.pre_norm(x)
|
||||
|
||||
q, k, v = self.to_qkv(x).chunk(3, dim = 1)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = h), (q, k, v))
|
||||
|
||||
sim = einsum('b h c i, b h c j -> b h i j', q, k) * self.scale
|
||||
|
||||
sim = self.cpb(sim)
|
||||
|
||||
attn = stable_softmax(sim, dim = -1)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = einsum('b h i j, b h c j -> b h c i', attn, v)
|
||||
out = rearrange(out, 'b h c (x y) -> b (h c) x y', x = height, y = width)
|
||||
out = self.to_out(out)
|
||||
|
||||
return out + residual
|
||||
|
||||
class NullVQGanVAE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
channels
|
||||
):
|
||||
super().__init__()
|
||||
self.encoded_dim = channels
|
||||
self.layers = 0
|
||||
|
||||
def get_encoded_fmap_size(self, size):
|
||||
return size
|
||||
|
||||
def copy_for_eval(self):
|
||||
return self
|
||||
|
||||
def encode(self, x):
|
||||
return x
|
||||
|
||||
def decode(self, x):
|
||||
return x
|
||||
|
||||
class VQGanVAE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dim,
|
||||
image_size,
|
||||
channels = 3,
|
||||
layers = 4,
|
||||
layer_mults = None,
|
||||
l2_recon_loss = False,
|
||||
use_hinge_loss = True,
|
||||
num_resnet_blocks = 1,
|
||||
vgg = None,
|
||||
vq_codebook_size = 512,
|
||||
vq_decay = 0.8,
|
||||
vq_commitment_weight = 1.,
|
||||
vq_kmeans_init = True,
|
||||
vq_use_cosine_sim = True,
|
||||
use_attn = True,
|
||||
attn_dim_head = 64,
|
||||
attn_heads = 8,
|
||||
resnet_groups = 16,
|
||||
attn_dropout = 0.,
|
||||
first_conv_kernel_size = 5,
|
||||
use_vgg_and_gan = True,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
assert dim % resnet_groups == 0, f'dimension {dim} must be divisible by {resnet_groups} (groups for the groupnorm)'
|
||||
|
||||
vq_kwargs, kwargs = groupby_prefix_and_trim('vq_', kwargs)
|
||||
|
||||
self.image_size = image_size
|
||||
self.channels = channels
|
||||
self.layers = layers
|
||||
self.fmap_size = image_size // (layers ** 2)
|
||||
self.codebook_size = vq_codebook_size
|
||||
|
||||
self.encoders = MList([])
|
||||
self.decoders = MList([])
|
||||
|
||||
layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
|
||||
assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'
|
||||
|
||||
layer_dims = [dim * mult for mult in layer_mults]
|
||||
dims = (dim, *layer_dims)
|
||||
codebook_dim = layer_dims[-1]
|
||||
|
||||
self.encoded_dim = dims[-1]
|
||||
|
||||
dim_pairs = zip(dims[:-1], dims[1:])
|
||||
|
||||
append = lambda arr, t: arr.append(t)
|
||||
prepend = lambda arr, t: arr.insert(0, t)
|
||||
|
||||
if not isinstance(num_resnet_blocks, tuple):
|
||||
num_resnet_blocks = (*((0,) * (layers - 1)), num_resnet_blocks)
|
||||
|
||||
if not isinstance(use_attn, tuple):
|
||||
use_attn = (*((False,) * (layers - 1)), use_attn)
|
||||
|
||||
assert len(num_resnet_blocks) == layers, 'number of resnet blocks config must be equal to number of layers'
|
||||
assert len(use_attn) == layers
|
||||
|
||||
for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_resnet_blocks, use_attn):
|
||||
append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
|
||||
prepend(self.decoders, nn.Sequential(nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = False), nn.Conv2d(dim_out, dim_in, 3, padding = 1), leaky_relu()))
|
||||
|
||||
if layer_use_attn:
|
||||
prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
|
||||
|
||||
for _ in range(layer_num_resnet_blocks):
|
||||
append(self.encoders, ResBlock(dim_out, groups = resnet_groups))
|
||||
prepend(self.decoders, GLUResBlock(dim_out, groups = resnet_groups))
|
||||
|
||||
if layer_use_attn:
|
||||
append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
|
||||
|
||||
prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
|
||||
append(self.decoders, nn.Conv2d(dim, channels, 1))
|
||||
|
||||
self.vq = VQ(
|
||||
dim = codebook_dim,
|
||||
codebook_size = vq_codebook_size,
|
||||
decay = vq_decay,
|
||||
commitment_weight = vq_commitment_weight,
|
||||
accept_image_fmap = True,
|
||||
kmeans_init = vq_kmeans_init,
|
||||
use_cosine_sim = vq_use_cosine_sim,
|
||||
**vq_kwargs
|
||||
)
|
||||
|
||||
# reconstruction loss
|
||||
|
||||
self.recon_loss_fn = F.mse_loss if l2_recon_loss else F.l1_loss
|
||||
|
||||
# turn off GAN and perceptual loss if grayscale
|
||||
|
||||
self.vgg = None
|
||||
self.discr = None
|
||||
self.use_vgg_and_gan = use_vgg_and_gan
|
||||
|
||||
if not use_vgg_and_gan:
|
||||
return
|
||||
|
||||
# preceptual loss
|
||||
|
||||
if exists(vgg):
|
||||
self.vgg = vgg
|
||||
else:
|
||||
self.vgg = torchvision.models.vgg16(pretrained = True)
|
||||
self.vgg.classifier = nn.Sequential(*self.vgg.classifier[:-2])
|
||||
|
||||
# gan related losses
|
||||
|
||||
self.discr = Discriminator(dims = dims, channels = channels)
|
||||
|
||||
self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
|
||||
self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss
|
||||
|
||||
def get_encoded_fmap_size(self, image_size):
|
||||
return image_size // (2 ** self.layers)
|
||||
|
||||
def copy_for_eval(self):
|
||||
device = next(self.parameters()).device
|
||||
vae_copy = copy.deepcopy(self.cpu())
|
||||
|
||||
if vae_copy.use_vgg_and_gan:
|
||||
del vae_copy.discr
|
||||
del vae_copy.vgg
|
||||
|
||||
vae_copy.eval()
|
||||
return vae_copy.to(device)
|
||||
|
||||
@remove_vgg
|
||||
def state_dict(self, *args, **kwargs):
|
||||
return super().state_dict(*args, **kwargs)
|
||||
|
||||
@remove_vgg
|
||||
def load_state_dict(self, *args, **kwargs):
|
||||
return super().load_state_dict(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def codebook(self):
|
||||
return self.vq.codebook
|
||||
|
||||
def encode(self, fmap):
|
||||
for enc in self.encoders:
|
||||
fmap = enc(fmap)
|
||||
|
||||
return fmap
|
||||
|
||||
def decode(self, fmap, return_indices_and_loss = False):
|
||||
fmap, indices, commit_loss = self.vq(fmap)
|
||||
|
||||
for dec in self.decoders:
|
||||
fmap = dec(fmap)
|
||||
|
||||
if not return_indices_and_loss:
|
||||
return fmap
|
||||
|
||||
return fmap, indices, commit_loss
|
||||
|
||||
def forward(
|
||||
self,
|
||||
img,
|
||||
return_loss = False,
|
||||
return_discr_loss = False,
|
||||
return_recons = False,
|
||||
add_gradient_penalty = True
|
||||
):
|
||||
batch, channels, height, width, device = *img.shape, img.device
|
||||
assert height == self.image_size and width == self.image_size, 'height and width of input image must be equal to {self.image_size}'
|
||||
assert channels == self.channels, 'number of channels on image or sketch is not equal to the channels set on this VQGanVAE'
|
||||
|
||||
fmap = self.encode(img)
|
||||
|
||||
fmap, indices, commit_loss = self.decode(fmap, return_indices_and_loss = True)
|
||||
|
||||
if not return_loss and not return_discr_loss:
|
||||
return fmap
|
||||
|
||||
assert return_loss ^ return_discr_loss, 'you should either return autoencoder loss or discriminator loss, but not both'
|
||||
|
||||
# whether to return discriminator loss
|
||||
|
||||
if return_discr_loss:
|
||||
assert exists(self.discr), 'discriminator must exist to train it'
|
||||
|
||||
fmap.detach_()
|
||||
img.requires_grad_()
|
||||
|
||||
fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))
|
||||
|
||||
discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits)
|
||||
|
||||
if add_gradient_penalty:
|
||||
gp = gradient_penalty(img, img_discr_logits)
|
||||
loss = discr_loss + gp
|
||||
|
||||
if return_recons:
|
||||
return loss, fmap
|
||||
|
||||
return loss
|
||||
|
||||
# reconstruction loss
|
||||
|
||||
recon_loss = self.recon_loss_fn(fmap, img)
|
||||
|
||||
# early return if training on grayscale
|
||||
|
||||
if not self.use_vgg_and_gan:
|
||||
if return_recons:
|
||||
return recon_loss, fmap
|
||||
|
||||
return recon_loss
|
||||
|
||||
# perceptual loss
|
||||
|
||||
img_vgg_input = img
|
||||
fmap_vgg_input = fmap
|
||||
|
||||
if img.shape[1] == 1:
|
||||
# handle grayscale for vgg
|
||||
img_vgg_input, fmap_vgg_input = map(lambda t: repeat(t, 'b 1 ... -> b c ...', c = 3), (img_vgg_input, fmap_vgg_input))
|
||||
|
||||
img_vgg_feats = self.vgg(img_vgg_input)
|
||||
recon_vgg_feats = self.vgg(fmap_vgg_input)
|
||||
perceptual_loss = F.mse_loss(img_vgg_feats, recon_vgg_feats)
|
||||
|
||||
# generator loss
|
||||
|
||||
gen_loss = self.gen_loss(self.discr(fmap))
|
||||
|
||||
# calculate adaptive weight
|
||||
|
||||
last_dec_layer = self.decoders[-1].weight
|
||||
|
||||
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2)
|
||||
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)
|
||||
|
||||
adaptive_weight = safe_div(norm_grad_wrt_perceptual_loss, norm_grad_wrt_gen_loss)
|
||||
adaptive_weight.clamp_(max = 1e4)
|
||||
|
||||
# combine losses
|
||||
|
||||
loss = recon_loss + perceptual_loss + commit_loss + adaptive_weight * gen_loss
|
||||
|
||||
if return_recons:
|
||||
return loss, fmap
|
||||
|
||||
return loss
|
||||
4
setup.py
4
setup.py
@@ -10,7 +10,7 @@ setup(
|
||||
'dream = dalle2_pytorch.cli:dream'
|
||||
],
|
||||
},
|
||||
version = '0.0.4',
|
||||
version = '0.0.41',
|
||||
license='MIT',
|
||||
description = 'DALL-E 2',
|
||||
author = 'Phil Wang',
|
||||
@@ -25,10 +25,12 @@ setup(
|
||||
'click',
|
||||
'einops>=0.4',
|
||||
'einops-exts>=0.0.3',
|
||||
'kornia>=0.5.4',
|
||||
'pillow',
|
||||
'torch>=1.10',
|
||||
'torchvision',
|
||||
'tqdm',
|
||||
'vector-quantize-pytorch',
|
||||
'x-clip>=0.4.4',
|
||||
'youtokentome'
|
||||
],
|
||||
|
||||
Reference in New Issue
Block a user