mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
57 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9359ad2e91 | ||
|
|
9ff228188b | ||
|
|
2d9963d30e | ||
|
|
58d9b422f3 | ||
|
|
44b319cb57 | ||
|
|
c30f380689 | ||
|
|
e4e884bb8b | ||
|
|
803ad9c17d | ||
|
|
a88dd6a9c0 | ||
|
|
72c16b496e | ||
|
|
81d83dd7f2 | ||
|
|
fa66f7e1e9 | ||
|
|
aa8d135245 | ||
|
|
70282de23b | ||
|
|
83f761847e | ||
|
|
11469dc0c6 | ||
|
|
2d25c89f35 | ||
|
|
3fe96c208a | ||
|
|
0fc6c9cdf3 | ||
|
|
7ee0ecc388 | ||
|
|
1924c7cc3d | ||
|
|
f7df3caaf3 | ||
|
|
fc954ee788 | ||
|
|
c1db2753f5 | ||
|
|
ad87bfe28f | ||
|
|
76c767b1ce | ||
|
|
d991b8c39c | ||
|
|
902693e271 | ||
|
|
35cd63982d | ||
|
|
53ce6dfdf6 | ||
|
|
ad8d7a368b | ||
|
|
b8cf1e5c20 | ||
|
|
94aaa08d97 | ||
|
|
8b9bbec7d1 | ||
|
|
1bb9fc9829 | ||
|
|
5e421bd5bb | ||
|
|
67fcab1122 | ||
|
|
5bfbccda22 | ||
|
|
989275ff59 | ||
|
|
56408f4a40 | ||
|
|
d1a697ac23 | ||
|
|
ebe01749ed | ||
|
|
63195cc2cb | ||
|
|
a2ef69af66 | ||
|
|
5fff22834e | ||
|
|
a9421f49ec | ||
|
|
77fa34eae9 | ||
|
|
1c1e508369 | ||
|
|
f19c99ecb0 | ||
|
|
721a444686 | ||
|
|
63450b466d | ||
|
|
20e7eb5a9b | ||
|
|
e2f9615afa | ||
|
|
0d1c07c803 | ||
|
|
a389f81138 | ||
|
|
0283556608 | ||
|
|
5063d192b6 |
227
README.md
227
README.md
@@ -47,7 +47,7 @@ clip = CLIP(
|
||||
use_all_token_embeds = True, # whether to use fine-grained contrastive learning (FILIP)
|
||||
decoupled_contrastive_learning = True, # use decoupled contrastive learning (DCL) objective function, removing positive pairs from the denominator of the InfoNCE loss (CLOOB + DCL)
|
||||
extra_latent_projection = True, # whether to use separate projections for text-to-image vs image-to-text comparisons (CLOOB)
|
||||
use_visual_ssl = True, # whether to do self supervised learning on iages
|
||||
use_visual_ssl = True, # whether to do self supervised learning on images
|
||||
visual_ssl_type = 'simclr', # can be either 'simclr' or 'simsiam', depending on using DeCLIP or SLIP
|
||||
use_mlm = False, # use masked language learning (MLM) on text (DeCLIP)
|
||||
text_ssl_loss_weight = 0.05, # weight for text MLM loss
|
||||
@@ -110,7 +110,8 @@ decoder = Decoder(
|
||||
unet = unet,
|
||||
clip = clip,
|
||||
timesteps = 100,
|
||||
cond_drop_prob = 0.2
|
||||
image_cond_drop_prob = 0.1,
|
||||
text_cond_drop_prob = 0.5
|
||||
).cuda()
|
||||
|
||||
# mock images (get a lot of this)
|
||||
@@ -229,7 +230,8 @@ decoder = Decoder(
|
||||
unet = (unet1, unet2), # insert both unets in order of low resolution to highest resolution (you can have as many stages as you want here)
|
||||
image_sizes = (256, 512), # resolutions, 256 for first unet, 512 for second. these must be unique and in ascending order (matches with the unets passed in)
|
||||
timesteps = 1000,
|
||||
cond_drop_prob = 0.2
|
||||
image_cond_drop_prob = 0.1,
|
||||
text_cond_drop_prob = 0.5
|
||||
).cuda()
|
||||
|
||||
# mock images (get a lot of this)
|
||||
@@ -348,7 +350,8 @@ decoder = Decoder(
|
||||
image_sizes = (128, 256),
|
||||
clip = clip,
|
||||
timesteps = 100,
|
||||
cond_drop_prob = 0.2,
|
||||
image_cond_drop_prob = 0.1,
|
||||
text_cond_drop_prob = 0.5,
|
||||
condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
|
||||
).cuda()
|
||||
|
||||
@@ -430,8 +433,8 @@ images = torch.randn(4, 3, 256, 256).cuda()
|
||||
# precompute the text and image embeddings
|
||||
# here using the diffusion prior class, but could be done with CLIP alone
|
||||
|
||||
clip_image_embeds = diffusion_prior.get_image_embed(images)
|
||||
clip_text_embeds = diffusion_prior.get_text_cond(text).get('text_embed')
|
||||
clip_image_embeds = diffusion_prior.clip.embed_image(images).image_embed
|
||||
clip_text_embeds = diffusion_prior.clip.embed_text(text).text_embed
|
||||
|
||||
# feed text and images into diffusion prior network
|
||||
|
||||
@@ -495,6 +498,95 @@ loss.backward()
|
||||
# now the diffusion prior can generate image embeddings from the text embeddings
|
||||
```
|
||||
|
||||
## OpenAI CLIP
|
||||
|
||||
Although there is the possibility they are using an unreleased, more powerful CLIP, you can use one of the released ones, if you do not wish to train your own CLIP from scratch. This will also allow the community to more quickly validate the conclusions of the paper.
|
||||
|
||||
To use a pretrained OpenAI CLIP, simply import `OpenAIClipAdapter` and pass it into the `DiffusionPrior` or `Decoder` like so
|
||||
|
||||
```python
|
||||
import torch
|
||||
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, OpenAIClipAdapter
|
||||
|
||||
# openai pretrained clip - defaults to ViT/B-32
|
||||
|
||||
clip = OpenAIClipAdapter()
|
||||
|
||||
# mock data
|
||||
|
||||
text = torch.randint(0, 49408, (4, 256)).cuda()
|
||||
images = torch.randn(4, 3, 256, 256).cuda()
|
||||
|
||||
# prior networks (with transformer)
|
||||
|
||||
prior_network = DiffusionPriorNetwork(
|
||||
dim = 512,
|
||||
depth = 6,
|
||||
dim_head = 64,
|
||||
heads = 8
|
||||
).cuda()
|
||||
|
||||
diffusion_prior = DiffusionPrior(
|
||||
net = prior_network,
|
||||
clip = clip,
|
||||
timesteps = 100,
|
||||
cond_drop_prob = 0.2
|
||||
).cuda()
|
||||
|
||||
loss = diffusion_prior(text, images)
|
||||
loss.backward()
|
||||
|
||||
# do above for many steps ...
|
||||
|
||||
# decoder (with unet)
|
||||
|
||||
unet1 = Unet(
|
||||
dim = 128,
|
||||
image_embed_dim = 512,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
dim_mults=(1, 2, 4, 8)
|
||||
).cuda()
|
||||
|
||||
unet2 = Unet(
|
||||
dim = 16,
|
||||
image_embed_dim = 512,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
dim_mults = (1, 2, 4, 8, 16)
|
||||
).cuda()
|
||||
|
||||
decoder = Decoder(
|
||||
unet = (unet1, unet2),
|
||||
image_sizes = (128, 256),
|
||||
clip = clip,
|
||||
timesteps = 100,
|
||||
image_cond_drop_prob = 0.1,
|
||||
text_cond_drop_prob = 0.5,
|
||||
condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
|
||||
).cuda()
|
||||
|
||||
for unet_number in (1, 2):
|
||||
loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
|
||||
loss.backward()
|
||||
|
||||
# do above for many steps
|
||||
|
||||
dalle2 = DALLE2(
|
||||
prior = diffusion_prior,
|
||||
decoder = decoder
|
||||
)
|
||||
|
||||
images = dalle2(
|
||||
['a butterfly trying to escape a tornado'],
|
||||
cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)
|
||||
)
|
||||
|
||||
# save your image (in this example, of size 256x256)
|
||||
```
|
||||
|
||||
Now you'll just have to worry about training the Prior and the Decoder!
|
||||
|
||||
## Experimental
|
||||
|
||||
### DALL-E2 with Latent Diffusion
|
||||
@@ -528,7 +620,7 @@ clip = CLIP(
|
||||
# 3 unets for the decoder (a la cascading DDPM)
|
||||
|
||||
# first two unets are doing latent diffusion
|
||||
# vqgan-vae must be trained before hand
|
||||
# vqgan-vae must be trained beforehand
|
||||
|
||||
vae1 = VQGanVAE(
|
||||
dim = 32,
|
||||
@@ -581,7 +673,8 @@ decoder = Decoder(
|
||||
unet = (unet1, unet2, unet3), # insert unets in order of low resolution to highest resolution (you can have as many stages as you want here)
|
||||
image_sizes = (256, 512, 1024), # resolutions, 256 for first unet, 512 for second, 1024 for third
|
||||
timesteps = 100,
|
||||
cond_drop_prob = 0.2
|
||||
image_cond_drop_prob = 0.1,
|
||||
text_cond_drop_prob = 0.5
|
||||
).cuda()
|
||||
|
||||
# mock images (get a lot of this)
|
||||
@@ -615,7 +708,83 @@ images = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)
|
||||
|
||||
## Training wrapper (wip)
|
||||
|
||||
Offer training wrappers
|
||||
### Decoder Training
|
||||
|
||||
Training the `Decoder` may be confusing, as one needs to keep track of an optimizer for each of the `Unet`(s) separately. Each `Unet` will also need its own corresponding exponential moving average. The `DecoderTrainer` hopes to make this simple, as shown below
|
||||
|
||||
```python
|
||||
import torch
|
||||
from dalle2_pytorch import DALLE2, Unet, Decoder, CLIP, DecoderTrainer
|
||||
|
||||
clip = CLIP(
|
||||
dim_text = 512,
|
||||
dim_image = 512,
|
||||
dim_latent = 512,
|
||||
num_text_tokens = 49408,
|
||||
text_enc_depth = 6,
|
||||
text_seq_len = 256,
|
||||
text_heads = 8,
|
||||
visual_enc_depth = 6,
|
||||
visual_image_size = 256,
|
||||
visual_patch_size = 32,
|
||||
visual_heads = 8
|
||||
).cuda()
|
||||
|
||||
# mock data
|
||||
|
||||
text = torch.randint(0, 49408, (4, 256)).cuda()
|
||||
images = torch.randn(4, 3, 256, 256).cuda()
|
||||
|
||||
# decoder (with unet)
|
||||
|
||||
unet1 = Unet(
|
||||
dim = 128,
|
||||
image_embed_dim = 512,
|
||||
text_embed_dim = 512,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
dim_mults=(1, 2, 4, 8)
|
||||
).cuda()
|
||||
|
||||
unet2 = Unet(
|
||||
dim = 16,
|
||||
image_embed_dim = 512,
|
||||
text_embed_dim = 512,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
dim_mults = (1, 2, 4, 8, 16),
|
||||
cond_on_text_encodings = True
|
||||
).cuda()
|
||||
|
||||
decoder = Decoder(
|
||||
unet = (unet1, unet2),
|
||||
image_sizes = (128, 256),
|
||||
clip = clip,
|
||||
timesteps = 1000,
|
||||
condition_on_text_encodings = True
|
||||
).cuda()
|
||||
|
||||
decoder_trainer = DecoderTrainer(
|
||||
decoder,
|
||||
lr = 3e-4,
|
||||
wd = 1e-2,
|
||||
ema_beta = 0.99,
|
||||
ema_update_after_step = 1000,
|
||||
ema_update_every = 10,
|
||||
)
|
||||
|
||||
for unet_number in (1, 2):
|
||||
loss = decoder_trainer(images, text = text, unet_number = unet_number) # use the decoder_trainer forward
|
||||
loss.backward()
|
||||
|
||||
decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
|
||||
|
||||
# after much training
|
||||
# you can sample from the exponentially moving averaged unets as so
|
||||
|
||||
mock_image_embed = torch.randn(4, 512).cuda()
|
||||
images = decoder_trainer.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
|
||||
```
|
||||
|
||||
## CLI (wip)
|
||||
|
||||
@@ -648,13 +817,25 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
- [x] use inheritance just this once for sharing logic between decoder and prior network ddpms
|
||||
- [x] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
|
||||
- [x] abstract interface for CLIP adapter class, so other CLIPs can be brought in
|
||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
|
||||
- [x] take care of mixed precision as well as gradient accumulation within decoder trainer
|
||||
- [x] just take care of the training for the decoder in a wrapper class, as each unet in the cascade will need its own optimizer
|
||||
- [x] bring in tools to train vqgan-vae
|
||||
- [x] add convnext backbone for vqgan-vae (in addition to vit [vit-vqgan] + resnet)
|
||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo)
|
||||
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
|
||||
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
|
||||
- [ ] pull logic for training diffusion prior into a class DiffusionPriorTrainer, for eventual script based + CLI based training
|
||||
- [ ] train on a toy task, offer in colab
|
||||
- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
|
||||
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
|
||||
- [ ] bring in tools to train vqgan-vae
|
||||
- [ ] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
|
||||
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
|
||||
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
|
||||
- [ ] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
|
||||
- [ ] make sure for the latter unets in the cascade, one can train on crops for learning super resolution (constrain the unet to be only convolutions in that case, or allow conv-like attention with rel pos bias)
|
||||
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
|
||||
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
|
||||
- [ ] make sure DDPMs can be run with traditional resnet blocks (but leave convnext as an option for experimentation)
|
||||
|
||||
## Citations
|
||||
|
||||
@@ -686,12 +867,22 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
|
||||
```bibtex
|
||||
@inproceedings{Liu2022ACF,
|
||||
title = {A ConvNet for the 2020https://arxiv.org/abs/2112.11435s},
|
||||
title = {A ConvNet for the 2020s},
|
||||
author = {Zhuang Liu and Hanzi Mao and Chaozheng Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie},
|
||||
year = {2022}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{shen2019efficient,
|
||||
author = {Zhuoran Shen and Mingyuan Zhang and Haiyu Zhao and Shuai Yi and Hongsheng Li},
|
||||
title = {Efficient Attention: Attention with Linear Complexities},
|
||||
journal = {CoRR},
|
||||
year = {2018},
|
||||
url = {http://arxiv.org/abs/1812.01243},
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@inproceedings{Tu2022MaxViTMV,
|
||||
title = {MaxViT: Multi-Axis Vision Transformer},
|
||||
@@ -710,4 +901,14 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
}
|
||||
```
|
||||
|
||||
*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's <a href="https://arxiv.org/abs/2011.13456">paper</a>
|
||||
```bibtex
|
||||
@article{Shleifer2021NormFormerIT,
|
||||
title = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
|
||||
author = {Sam Shleifer and Jason Weston and Myle Ott},
|
||||
journal = {ArXiv},
|
||||
year = {2021},
|
||||
volume = {abs/2110.09456}
|
||||
}
|
||||
```
|
||||
|
||||
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
|
||||
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
|
||||
from dalle2_pytorch.train import DecoderTrainer
|
||||
|
||||
from dalle2_pytorch.vqgan_vae import VQGanVAE
|
||||
from x_clip import CLIP
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import click
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from functools import reduce
|
||||
from pathlib import Path
|
||||
|
||||
from dalle2_pytorch import DALLE2, Decoder, DiffusionPrior
|
||||
|
||||
@@ -3,6 +3,7 @@ from tqdm import tqdm
|
||||
from inspect import isfunction
|
||||
from functools import partial
|
||||
from contextlib import contextmanager
|
||||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -28,6 +29,9 @@ from x_clip import CLIP
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def identity(t, *args, **kwargs):
|
||||
return t
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
@@ -90,8 +94,21 @@ def resize_image_to(t, image_size, mode = 'bilinear'): # take a look at https://
|
||||
|
||||
return F.interpolate(t, size = shape, mode = mode, align_corners = False)
|
||||
|
||||
# image normalization functions
|
||||
# ddpms expect images to be in the range of -1 to 1
|
||||
# but CLIP may otherwise
|
||||
|
||||
def normalize_img(img):
|
||||
return img * 2 - 1
|
||||
|
||||
def unnormalize_img(normed_img):
|
||||
return (normed_img + 1) * 0.5
|
||||
|
||||
# clip related adapters
|
||||
|
||||
EmbeddedText = namedtuple('EmbedTextReturn', ['text_embed', 'text_encodings', 'text_mask'])
|
||||
EmbeddedImage = namedtuple('EmbedImageReturn', ['image_embed', 'image_encodings'])
|
||||
|
||||
class BaseClipAdapter(nn.Module):
|
||||
def __init__(self, clip):
|
||||
super().__init__()
|
||||
@@ -109,6 +126,10 @@ class BaseClipAdapter(nn.Module):
|
||||
def image_channels(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def max_text_len(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def embed_text(self, text):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -128,12 +149,18 @@ class XClipAdapter(BaseClipAdapter):
|
||||
def image_channels(self):
|
||||
return self.clip.image_channels
|
||||
|
||||
@property
|
||||
def max_text_len(self):
|
||||
return self.clip.text_seq_len
|
||||
|
||||
@torch.no_grad()
|
||||
def embed_text(self, text):
|
||||
text = text[..., :self.max_text_len]
|
||||
text_mask = text != 0
|
||||
encoder_output = self.clip.text_transformer(text)
|
||||
text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:]
|
||||
text_embed = self.clip.to_text_latent(text_cls)
|
||||
return l2norm(text_embed), text_encodings
|
||||
return EmbeddedText(l2norm(text_embed), text_encodings, text_mask)
|
||||
|
||||
@torch.no_grad()
|
||||
def embed_image(self, image):
|
||||
@@ -141,7 +168,69 @@ class XClipAdapter(BaseClipAdapter):
|
||||
encoder_output = self.clip.visual_transformer(image)
|
||||
image_cls, image_encodings = encoder_output[:, 0], encoder_output[:, 1:]
|
||||
image_embed = self.clip.to_visual_latent(image_cls)
|
||||
return l2norm(image_embed), image_encodings
|
||||
return EmbeddedImage(l2norm(image_embed), image_encodings)
|
||||
|
||||
class OpenAIClipAdapter(BaseClipAdapter):
|
||||
def __init__(
|
||||
self,
|
||||
name = 'ViT-B/32'
|
||||
):
|
||||
import clip
|
||||
openai_clip, preprocess = clip.load(name)
|
||||
super().__init__(openai_clip)
|
||||
|
||||
text_attention_final = self.find_layer('ln_final')
|
||||
self.handle = text_attention_final.register_forward_hook(self._hook)
|
||||
self.clip_normalize = preprocess.transforms[-1]
|
||||
self.cleared = False
|
||||
|
||||
def find_layer(self, layer):
|
||||
modules = dict([*self.clip.named_modules()])
|
||||
return modules.get(layer, None)
|
||||
|
||||
def clear(self):
|
||||
if self.cleared:
|
||||
return
|
||||
|
||||
self.handle()
|
||||
|
||||
def _hook(self, _, inputs, outputs):
|
||||
self.text_encodings = outputs
|
||||
|
||||
@property
|
||||
def dim_latent(self):
|
||||
return 512
|
||||
|
||||
@property
|
||||
def image_size(self):
|
||||
return self.clip.visual.input_resolution
|
||||
|
||||
@property
|
||||
def image_channels(self):
|
||||
return 3
|
||||
|
||||
@property
|
||||
def max_text_len(self):
|
||||
return self.clip.context_length
|
||||
|
||||
@torch.no_grad()
|
||||
def embed_text(self, text):
|
||||
text = text[..., :self.max_text_len]
|
||||
text_mask = text != 0
|
||||
assert not self.cleared
|
||||
|
||||
text_embed = self.clip.encode_text(text)
|
||||
text_encodings = self.text_encodings
|
||||
del self.text_encodings
|
||||
return EmbeddedText(text_embed.float(), text_encodings.float(), text_mask)
|
||||
|
||||
@torch.no_grad()
|
||||
def embed_image(self, image):
|
||||
assert not self.cleared
|
||||
image = resize_image_to(image, self.image_size)
|
||||
image = self.clip_normalize(unnormalize_img(image))
|
||||
image_embed = self.clip.encode_image(image)
|
||||
return EmbeddedImage(image_embed.float(), None)
|
||||
|
||||
# classifier free guidance functions
|
||||
|
||||
@@ -410,7 +499,12 @@ class SwiGLU(nn.Module):
|
||||
x, gate = x.chunk(2, dim = -1)
|
||||
return x * F.silu(gate)
|
||||
|
||||
def FeedForward(dim, mult = 4, dropout = 0., post_activation_norm = False):
|
||||
def FeedForward(
|
||||
dim,
|
||||
mult = 4,
|
||||
dropout = 0.,
|
||||
post_activation_norm = False
|
||||
):
|
||||
""" post-activation norm https://arxiv.org/abs/2110.09456 """
|
||||
|
||||
inner_dim = int(mult * dim)
|
||||
@@ -433,7 +527,8 @@ class Attention(nn.Module):
|
||||
dim_head = 64,
|
||||
heads = 8,
|
||||
dropout = 0.,
|
||||
causal = False
|
||||
causal = False,
|
||||
post_norm = False
|
||||
):
|
||||
super().__init__()
|
||||
self.scale = dim_head ** -0.5
|
||||
@@ -448,7 +543,11 @@ class Attention(nn.Module):
|
||||
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
||||
self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim, bias = False),
|
||||
LayerNorm(dim) if post_norm else nn.Identity()
|
||||
)
|
||||
|
||||
def forward(self, x, mask = None, attn_bias = None):
|
||||
b, n, device = *x.shape[:2], x.device
|
||||
@@ -510,10 +609,11 @@ class CausalTransformer(nn.Module):
|
||||
dim_head = 64,
|
||||
heads = 8,
|
||||
ff_mult = 4,
|
||||
norm_out = False,
|
||||
norm_out = True,
|
||||
attn_dropout = 0.,
|
||||
ff_dropout = 0.,
|
||||
final_proj = True
|
||||
final_proj = True,
|
||||
normformer = False
|
||||
):
|
||||
super().__init__()
|
||||
self.rel_pos_bias = RelPosBias(heads = heads)
|
||||
@@ -521,8 +621,8 @@ class CausalTransformer(nn.Module):
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout),
|
||||
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
|
||||
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer),
|
||||
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
|
||||
]))
|
||||
|
||||
self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
|
||||
@@ -549,12 +649,14 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
self,
|
||||
dim,
|
||||
num_timesteps = None,
|
||||
l2norm_output = False, # whether to restrict image embedding output with l2norm at the end (may make it easier to learn?)
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(Rearrange('b -> b 1'), MLP(1, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
|
||||
self.learned_query = nn.Parameter(torch.randn(dim))
|
||||
self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
|
||||
self.l2norm_output = l2norm_output
|
||||
|
||||
def forward_with_cond_scale(
|
||||
self,
|
||||
@@ -598,14 +700,14 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
|
||||
# classifier free guidance
|
||||
|
||||
cond_prob_mask = prob_mask_like((batch,), cond_drop_prob, device = device)
|
||||
cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1')
|
||||
keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = device)
|
||||
keep_mask = rearrange(keep_mask, 'b -> b 1')
|
||||
|
||||
mask &= cond_prob_mask
|
||||
mask &= keep_mask
|
||||
|
||||
# whether text embedding is masked or not depends on the classifier free guidance conditional masking
|
||||
|
||||
mask = torch.cat((mask, cond_prob_mask), dim = 1)
|
||||
mask = torch.cat((mask, keep_mask), dim = 1)
|
||||
|
||||
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
|
||||
# but let's just do it right
|
||||
@@ -633,7 +735,8 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
|
||||
pred_image_embed = tokens[..., -1, :]
|
||||
|
||||
return pred_image_embed
|
||||
output_fn = l2norm if self.l2norm_output else identity
|
||||
return output_fn(pred_image_embed)
|
||||
|
||||
class DiffusionPrior(BaseGaussianDiffusion):
|
||||
def __init__(
|
||||
@@ -650,6 +753,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
predict_x_start = True,
|
||||
beta_schedule = "cosine",
|
||||
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
|
||||
sampling_clamp_l2norm = False
|
||||
):
|
||||
super().__init__(
|
||||
beta_schedule = beta_schedule,
|
||||
@@ -678,6 +782,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
self.predict_x_start = predict_x_start
|
||||
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
|
||||
|
||||
# whether to force an l2norm, similar to clipping denoised, when sampling
|
||||
self.sampling_clamp_l2norm = sampling_clamp_l2norm
|
||||
|
||||
def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
|
||||
pred = self.net(x, t, **text_cond)
|
||||
|
||||
@@ -691,6 +798,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
if clip_denoised and not self.predict_x_start:
|
||||
x_recon.clamp_(-1., 1.)
|
||||
|
||||
if self.predict_x_start and self.sampling_clamp_l2norm:
|
||||
x_recon = l2norm(x_recon)
|
||||
|
||||
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
||||
return model_mean, posterior_variance, posterior_log_variance
|
||||
|
||||
@@ -741,12 +851,12 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
batch_size = text.shape[0]
|
||||
image_embed_dim = self.image_embed_dim
|
||||
|
||||
text_embed, text_encodings = self.clip.embed_text(text)
|
||||
text_embed, text_encodings, text_mask = self.clip.embed_text(text)
|
||||
|
||||
text_cond = dict(text_embed = text_embed)
|
||||
|
||||
if self.condition_on_text_encodings:
|
||||
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text != 0}
|
||||
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
|
||||
|
||||
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
|
||||
text_embeds = text_cond['text_embed']
|
||||
@@ -783,8 +893,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
# calculate text conditionings, based on what is passed in
|
||||
|
||||
if exists(text):
|
||||
text_embed, text_encodings = self.clip.embed_text(text)
|
||||
text_mask = text != 0
|
||||
text_embed, text_encodings, text_mask = self.clip.embed_text(text)
|
||||
|
||||
text_cond = dict(text_embed = text_embed)
|
||||
|
||||
@@ -821,6 +930,72 @@ class SinusoidalPosEmb(nn.Module):
|
||||
emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
|
||||
return torch.cat((emb.sin(), emb.cos()), dim = -1)
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim_out,
|
||||
groups = 8
|
||||
):
|
||||
super().__init__()
|
||||
self.block = nn.Sequential(
|
||||
nn.Conv2d(dim, dim_out, 3, padding = 1),
|
||||
nn.GroupNorm(groups, dim_out),
|
||||
nn.SiLU()
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim_out,
|
||||
*,
|
||||
cond_dim = None,
|
||||
time_cond_dim = None,
|
||||
groups = 8
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.time_mlp = None
|
||||
|
||||
if exists(time_cond_dim):
|
||||
self.time_mlp = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(time_cond_dim, dim_out)
|
||||
)
|
||||
|
||||
self.cross_attn = None
|
||||
|
||||
if exists(cond_dim):
|
||||
self.cross_attn = EinopsToAndFrom(
|
||||
'b c h w',
|
||||
'b (h w) c',
|
||||
CrossAttention(
|
||||
dim = dim_out,
|
||||
context_dim = cond_dim
|
||||
)
|
||||
)
|
||||
|
||||
self.block1 = Block(dim, dim_out, groups = groups)
|
||||
self.block2 = Block(dim_out, dim_out, groups = groups)
|
||||
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
|
||||
|
||||
def forward(self, x, cond = None, time_emb = None):
|
||||
h = self.block1(x)
|
||||
|
||||
if exists(self.time_mlp) and exists(time_emb):
|
||||
time_emb = self.time_mlp(time_emb)
|
||||
h = rearrange(time_emb, 'b c -> b c 1 1') + h
|
||||
|
||||
if exists(self.cross_attn):
|
||||
assert exists(cond)
|
||||
h = self.cross_attn(h, context = cond) + h
|
||||
|
||||
h = self.block2(h)
|
||||
return h + self.res_conv(x)
|
||||
|
||||
class ConvNextBlock(nn.Module):
|
||||
""" https://arxiv.org/abs/2201.03545 """
|
||||
|
||||
@@ -830,8 +1005,8 @@ class ConvNextBlock(nn.Module):
|
||||
dim_out,
|
||||
*,
|
||||
cond_dim = None,
|
||||
mult = 2,
|
||||
norm = True
|
||||
time_cond_dim = None,
|
||||
mult = 2
|
||||
):
|
||||
super().__init__()
|
||||
need_projection = dim != dim_out
|
||||
@@ -848,11 +1023,19 @@ class ConvNextBlock(nn.Module):
|
||||
)
|
||||
)
|
||||
|
||||
self.time_mlp = None
|
||||
|
||||
if exists(time_cond_dim):
|
||||
self.time_mlp = nn.Sequential(
|
||||
nn.GELU(),
|
||||
nn.Linear(time_cond_dim, dim)
|
||||
)
|
||||
|
||||
self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim)
|
||||
|
||||
inner_dim = int(dim_out * mult)
|
||||
self.net = nn.Sequential(
|
||||
ChanLayerNorm(dim) if norm else nn.Identity(),
|
||||
ChanLayerNorm(dim),
|
||||
nn.Conv2d(dim, inner_dim, 3, padding = 1),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(inner_dim, dim_out, 3, padding = 1)
|
||||
@@ -860,9 +1043,13 @@ class ConvNextBlock(nn.Module):
|
||||
|
||||
self.res_conv = nn.Conv2d(dim, dim_out, 1) if need_projection else nn.Identity()
|
||||
|
||||
def forward(self, x, cond = None):
|
||||
def forward(self, x, cond = None, time = None):
|
||||
h = self.ds_conv(x)
|
||||
|
||||
if exists(time) and exists(self.time_mlp):
|
||||
t = self.time_mlp(time)
|
||||
h = rearrange(t, 'b c -> b c 1 1') + h
|
||||
|
||||
if exists(self.cross_attn):
|
||||
assert exists(cond)
|
||||
h = self.cross_attn(h, context = cond) + h
|
||||
@@ -945,6 +1132,46 @@ class GridAttention(nn.Module):
|
||||
out = rearrange(out, '(b h w) (w1 w2) c -> b c (w1 h) (w2 w)', w1 = wsz, w2 = wsz, h = h // wsz, w = w // wsz)
|
||||
return out
|
||||
|
||||
class LinearAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim_head = 32,
|
||||
heads = 8
|
||||
):
|
||||
super().__init__()
|
||||
self.scale = dim_head ** -0.5
|
||||
self.heads = heads
|
||||
inner_dim = dim_head * heads
|
||||
self.norm = ChanLayerNorm(dim)
|
||||
|
||||
self.nonlin = nn.GELU()
|
||||
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Conv2d(inner_dim, dim, 1, bias = False),
|
||||
ChanLayerNorm(dim)
|
||||
)
|
||||
|
||||
def forward(self, fmap):
|
||||
h, x, y = self.heads, *fmap.shape[-2:]
|
||||
|
||||
fmap = self.norm(fmap)
|
||||
q, k, v = self.to_qkv(fmap).chunk(3, dim = 1)
|
||||
q, k, v = rearrange_many((q, k, v), 'b (h c) x y -> (b h) (x y) c', h = h)
|
||||
|
||||
q = q.softmax(dim = -1)
|
||||
k = k.softmax(dim = -2)
|
||||
|
||||
q = q * self.scale
|
||||
|
||||
context = einsum('b n d, b n e -> b d e', k, v)
|
||||
out = einsum('b n d, b d e -> b n e', q, context)
|
||||
out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)
|
||||
|
||||
out = self.nonlin(out)
|
||||
return self.to_out(out)
|
||||
|
||||
class Unet(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -959,14 +1186,17 @@ class Unet(nn.Module):
|
||||
dim_mults=(1, 2, 4, 8),
|
||||
channels = 3,
|
||||
attn_dim_head = 32,
|
||||
attn_heads = 8,
|
||||
attn_heads = 16,
|
||||
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
|
||||
sparse_attn = False,
|
||||
sparse_attn_window = 8, # window size for sparse attention
|
||||
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
|
||||
cond_on_text_encodings = False,
|
||||
max_text_len = 256,
|
||||
cond_on_image_embeds = False,
|
||||
init_dim = None,
|
||||
init_conv_kernel_size = 7,
|
||||
block_type = 'resnet',
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
# save locals to take care of some hyperparameters for cascading DDPM
|
||||
@@ -984,28 +1214,45 @@ class Unet(nn.Module):
|
||||
self.channels = channels
|
||||
|
||||
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
|
||||
init_dim = default(init_dim, dim // 2)
|
||||
|
||||
dims = [init_channels, *map(lambda m: dim * m, dim_mults)]
|
||||
assert (init_conv_kernel_size % 2) == 1
|
||||
self.init_conv = nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
|
||||
|
||||
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
|
||||
in_out = list(zip(dims[:-1], dims[1:]))
|
||||
|
||||
# time, image embeddings, and optional text encoding
|
||||
|
||||
cond_dim = default(cond_dim, dim)
|
||||
time_cond_dim = dim * 4
|
||||
|
||||
self.time_mlp = nn.Sequential(
|
||||
self.to_time_hiddens = nn.Sequential(
|
||||
SinusoidalPosEmb(dim),
|
||||
nn.Linear(dim, dim * 4),
|
||||
nn.GELU(),
|
||||
nn.Linear(dim * 4, cond_dim * num_time_tokens),
|
||||
nn.Linear(dim, time_cond_dim),
|
||||
nn.GELU()
|
||||
)
|
||||
|
||||
self.to_time_tokens = nn.Sequential(
|
||||
nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
|
||||
Rearrange('b (r d) -> b r d', r = num_time_tokens)
|
||||
)
|
||||
|
||||
self.to_time_cond = nn.Sequential(
|
||||
nn.Linear(time_cond_dim, time_cond_dim)
|
||||
)
|
||||
|
||||
self.image_to_cond = nn.Sequential(
|
||||
nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
|
||||
Rearrange('b (n d) -> b n d', n = num_image_tokens)
|
||||
) if image_embed_dim != cond_dim else nn.Identity()
|
||||
|
||||
self.text_to_cond = nn.LazyLinear(cond_dim) if not exists(text_embed_dim) else nn.Linear(text_embed_dim, cond_dim)
|
||||
# text encoding conditioning (optional)
|
||||
|
||||
self.text_to_cond = None
|
||||
|
||||
if cond_on_text_encodings:
|
||||
self.text_to_cond = nn.LazyLinear(cond_dim) if not exists(text_embed_dim) else nn.Linear(text_embed_dim, cond_dim)
|
||||
|
||||
# finer control over whether to condition on image embeddings and text encodings
|
||||
# so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
|
||||
@@ -1016,12 +1263,23 @@ class Unet(nn.Module):
|
||||
# for classifier free guidance
|
||||
|
||||
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
|
||||
|
||||
self.max_text_len = max_text_len
|
||||
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
|
||||
|
||||
# attention related params
|
||||
|
||||
attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
|
||||
|
||||
# whether to use resnet or the (improved?) convnext blocks
|
||||
|
||||
if block_type == 'resnet':
|
||||
block_klass = ResnetBlock
|
||||
elif block_type == 'convnext':
|
||||
block_klass = ConvNextBlock
|
||||
else:
|
||||
raise ValueError(f'unimplemented block type {block_type}')
|
||||
|
||||
# layers
|
||||
|
||||
self.downs = nn.ModuleList([])
|
||||
@@ -1034,32 +1292,32 @@ class Unet(nn.Module):
|
||||
layer_cond_dim = cond_dim if not is_first else None
|
||||
|
||||
self.downs.append(nn.ModuleList([
|
||||
ConvNextBlock(dim_in, dim_out, norm = ind != 0),
|
||||
Residual(GridAttention(dim_out, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||
ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim),
|
||||
block_klass(dim_in, dim_out, time_cond_dim = time_cond_dim),
|
||||
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||
block_klass(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
|
||||
Downsample(dim_out) if not is_last else nn.Identity()
|
||||
]))
|
||||
|
||||
mid_dim = dims[-1]
|
||||
|
||||
self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
|
||||
self.mid_block1 = block_klass(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
|
||||
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
|
||||
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
|
||||
self.mid_block2 = block_klass(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
|
||||
|
||||
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
|
||||
is_last = ind >= (num_resolutions - 2)
|
||||
layer_cond_dim = cond_dim if not is_last else None
|
||||
|
||||
self.ups.append(nn.ModuleList([
|
||||
ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim),
|
||||
Residual(GridAttention(dim_in, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||
ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim),
|
||||
block_klass(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
|
||||
Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||
block_klass(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
|
||||
Upsample(dim_in)
|
||||
]))
|
||||
|
||||
out_dim = default(out_dim, channels)
|
||||
self.final_conv = nn.Sequential(
|
||||
ConvNextBlock(dim, dim),
|
||||
block_klass(dim, dim),
|
||||
nn.Conv2d(dim, out_dim, 1)
|
||||
)
|
||||
|
||||
@@ -1089,7 +1347,7 @@ class Unet(nn.Module):
|
||||
if cond_scale == 1:
|
||||
return logits
|
||||
|
||||
null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
|
||||
null_logits = self.forward(*args, text_cond_drop_prob = 1., image_cond_drop_prob = 1., **kwargs)
|
||||
return null_logits + (logits - null_logits) * cond_scale
|
||||
|
||||
def forward(
|
||||
@@ -1100,7 +1358,9 @@ class Unet(nn.Module):
|
||||
image_embed,
|
||||
lowres_cond_img = None,
|
||||
text_encodings = None,
|
||||
cond_drop_prob = 0.,
|
||||
text_mask = None,
|
||||
image_cond_drop_prob = 0.,
|
||||
text_cond_drop_prob = 0.,
|
||||
blur_sigma = None,
|
||||
blur_kernel_size = None
|
||||
):
|
||||
@@ -1113,14 +1373,23 @@ class Unet(nn.Module):
|
||||
if exists(lowres_cond_img):
|
||||
x = torch.cat((x, lowres_cond_img), dim = 1)
|
||||
|
||||
# initial convolution
|
||||
|
||||
x = self.init_conv(x)
|
||||
|
||||
# time conditioning
|
||||
|
||||
time_tokens = self.time_mlp(time)
|
||||
time_hiddens = self.to_time_hiddens(time)
|
||||
|
||||
time_tokens = self.to_time_tokens(time_hiddens)
|
||||
t = self.to_time_cond(time_hiddens)
|
||||
|
||||
# conditional dropout
|
||||
|
||||
cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device)
|
||||
cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1 1')
|
||||
image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
|
||||
text_keep_mask = prob_mask_like((batch_size,), 1 - text_cond_drop_prob, device = device)
|
||||
|
||||
image_keep_mask, text_keep_mask = rearrange_many((image_keep_mask, text_keep_mask), 'b -> b 1 1')
|
||||
|
||||
# mask out image embedding depending on condition dropout
|
||||
# for classifier free guidance
|
||||
@@ -1131,7 +1400,7 @@ class Unet(nn.Module):
|
||||
image_tokens = self.image_to_cond(image_embed)
|
||||
|
||||
image_tokens = torch.where(
|
||||
cond_prob_mask,
|
||||
image_keep_mask,
|
||||
image_tokens,
|
||||
self.null_image_embed
|
||||
)
|
||||
@@ -1142,10 +1411,25 @@ class Unet(nn.Module):
|
||||
|
||||
if exists(text_encodings) and self.cond_on_text_encodings:
|
||||
text_tokens = self.text_to_cond(text_encodings)
|
||||
text_tokens = text_tokens[:, :self.max_text_len]
|
||||
|
||||
text_tokens_len = text_tokens.shape[1]
|
||||
remainder = self.max_text_len - text_tokens_len
|
||||
|
||||
if remainder > 0:
|
||||
text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))
|
||||
|
||||
if exists(text_mask):
|
||||
if remainder > 0:
|
||||
text_mask = F.pad(text_mask, (0, remainder), value = False)
|
||||
|
||||
text_mask = rearrange(text_mask, 'b n -> b n 1')
|
||||
text_keep_mask = text_mask & text_keep_mask
|
||||
|
||||
text_tokens = torch.where(
|
||||
cond_prob_mask,
|
||||
text_keep_mask,
|
||||
text_tokens,
|
||||
self.null_text_embed[:, :text_tokens.shape[1]]
|
||||
self.null_text_embed
|
||||
)
|
||||
|
||||
# main conditioning tokens (c)
|
||||
@@ -1164,25 +1448,25 @@ class Unet(nn.Module):
|
||||
|
||||
hiddens = []
|
||||
|
||||
for convnext, sparse_attn, convnext2, downsample in self.downs:
|
||||
x = convnext(x, c)
|
||||
for block1, sparse_attn, block2, downsample in self.downs:
|
||||
x = block1(x, c, t)
|
||||
x = sparse_attn(x)
|
||||
x = convnext2(x, c)
|
||||
x = block2(x, c, t)
|
||||
hiddens.append(x)
|
||||
x = downsample(x)
|
||||
|
||||
x = self.mid_block1(x, mid_c)
|
||||
x = self.mid_block1(x, mid_c, t)
|
||||
|
||||
if exists(self.mid_attn):
|
||||
x = self.mid_attn(x)
|
||||
|
||||
x = self.mid_block2(x, mid_c)
|
||||
x = self.mid_block2(x, mid_c, t)
|
||||
|
||||
for convnext, sparse_attn, convnext2, upsample in self.ups:
|
||||
for block1, sparse_attn, block2, upsample in self.ups:
|
||||
x = torch.cat((x, hiddens.pop()), dim=1)
|
||||
x = convnext(x, c)
|
||||
x = block1(x, c, t)
|
||||
x = sparse_attn(x)
|
||||
x = convnext2(x, c)
|
||||
x = block2(x, c, t)
|
||||
x = upsample(x)
|
||||
|
||||
return self.final_conv(x)
|
||||
@@ -1233,7 +1517,8 @@ class Decoder(BaseGaussianDiffusion):
|
||||
clip,
|
||||
vae = tuple(),
|
||||
timesteps = 1000,
|
||||
cond_drop_prob = 0.2,
|
||||
image_cond_drop_prob = 0.1,
|
||||
text_cond_drop_prob = 0.5,
|
||||
loss_type = 'l1',
|
||||
beta_schedule = 'cosine',
|
||||
predict_x_start = False,
|
||||
@@ -1244,6 +1529,8 @@ class Decoder(BaseGaussianDiffusion):
|
||||
blur_sigma = 0.1, # cascading ddpm - blur sigma
|
||||
blur_kernel_size = 3, # cascading ddpm - blur kernel size
|
||||
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
|
||||
clip_denoised = True,
|
||||
clip_x_start = True
|
||||
):
|
||||
super().__init__(
|
||||
beta_schedule = beta_schedule,
|
||||
@@ -1317,7 +1604,13 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
# classifier free guidance
|
||||
|
||||
self.cond_drop_prob = cond_drop_prob
|
||||
self.image_cond_drop_prob = image_cond_drop_prob
|
||||
self.text_cond_drop_prob = text_cond_drop_prob
|
||||
|
||||
# whether to clip when sampling
|
||||
|
||||
self.clip_denoised = clip_denoised
|
||||
self.clip_x_start = clip_x_start
|
||||
|
||||
def get_unet(self, unet_number):
|
||||
assert 0 < unet_number <= len(self.unets)
|
||||
@@ -1341,37 +1634,34 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
@torch.no_grad()
|
||||
def get_image_embed(self, image):
|
||||
image = resize_image_to(image, self.clip_image_size)
|
||||
image_encoding = self.clip.visual_transformer(image)
|
||||
image_cls = image_encoding[:, 0]
|
||||
image_embed = self.clip.to_visual_latent(image_cls)
|
||||
return l2norm(image_embed)
|
||||
image_embed, _ = self.clip.embed_image(image)
|
||||
return image_embed
|
||||
|
||||
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
|
||||
pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
|
||||
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
|
||||
pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
|
||||
|
||||
if predict_x_start:
|
||||
x_recon = pred
|
||||
else:
|
||||
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
|
||||
|
||||
if clip_denoised and not predict_x_start:
|
||||
if clip_denoised:
|
||||
x_recon.clamp_(-1., 1.)
|
||||
|
||||
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
||||
return model_mean, posterior_variance, posterior_log_variance
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample(self, unet, x, t, image_embed, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, clip_denoised = True, repeat_noise = False):
|
||||
def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, clip_denoised = True, repeat_noise = False):
|
||||
b, *_, device = *x.shape, x.device
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start)
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start)
|
||||
noise = noise_like(x.shape, device, repeat_noise)
|
||||
# no noise when t == 0
|
||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, lowres_cond_img = None, text_encodings = None, cond_scale = 1):
|
||||
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
|
||||
device = self.betas.device
|
||||
|
||||
b = shape[0]
|
||||
@@ -1384,14 +1674,16 @@ class Decoder(BaseGaussianDiffusion):
|
||||
torch.full((b,), i, device = device, dtype = torch.long),
|
||||
image_embed = image_embed,
|
||||
text_encodings = text_encodings,
|
||||
text_mask = text_mask,
|
||||
cond_scale = cond_scale,
|
||||
lowres_cond_img = lowres_cond_img,
|
||||
predict_x_start = predict_x_start
|
||||
predict_x_start = predict_x_start,
|
||||
clip_denoised = clip_denoised
|
||||
)
|
||||
|
||||
return img
|
||||
|
||||
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None):
|
||||
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None):
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
|
||||
x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)
|
||||
@@ -1401,8 +1693,10 @@ class Decoder(BaseGaussianDiffusion):
|
||||
times,
|
||||
image_embed = image_embed,
|
||||
text_encodings = text_encodings,
|
||||
text_mask = text_mask,
|
||||
lowres_cond_img = lowres_cond_img,
|
||||
cond_drop_prob = self.cond_drop_prob
|
||||
image_cond_drop_prob = self.image_cond_drop_prob,
|
||||
text_cond_drop_prob = self.text_cond_drop_prob,
|
||||
)
|
||||
|
||||
target = noise if not predict_x_start else x_start
|
||||
@@ -1412,19 +1706,25 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
@torch.no_grad()
|
||||
@eval_decorator
|
||||
def sample(self, image_embed, text = None, cond_scale = 1.):
|
||||
def sample(
|
||||
self,
|
||||
image_embed,
|
||||
text = None,
|
||||
cond_scale = 1.,
|
||||
stop_at_unet_number = None
|
||||
):
|
||||
batch_size = image_embed.shape[0]
|
||||
|
||||
text_encodings = None
|
||||
text_encodings = text_mask = None
|
||||
if exists(text):
|
||||
_, text_encodings = self.clip.embed_text(text)
|
||||
_, text_encodings, text_mask = self.clip.embed_text(text)
|
||||
|
||||
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
|
||||
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
|
||||
|
||||
img = None
|
||||
|
||||
for unet, vae, channel, image_size, predict_x_start in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
|
||||
for unet_number, unet, vae, channel, image_size, predict_x_start in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
|
||||
|
||||
context = self.one_unet_in_gpu(unet = unet) if image_embed.is_cuda else null_context()
|
||||
|
||||
@@ -1435,6 +1735,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
if unet.lowres_cond:
|
||||
lowres_cond_img = self.to_lowres_cond(img, target_image_size = image_size)
|
||||
|
||||
is_latent_diffusion = isinstance(vae, VQGanVAE)
|
||||
image_size = vae.get_encoded_fmap_size(image_size)
|
||||
shape = (batch_size, vae.encoded_dim, image_size, image_size)
|
||||
|
||||
@@ -1446,13 +1747,18 @@ class Decoder(BaseGaussianDiffusion):
|
||||
shape,
|
||||
image_embed = image_embed,
|
||||
text_encodings = text_encodings,
|
||||
text_mask = text_mask,
|
||||
cond_scale = cond_scale,
|
||||
predict_x_start = predict_x_start,
|
||||
clip_denoised = not is_latent_diffusion,
|
||||
lowres_cond_img = lowres_cond_img
|
||||
)
|
||||
|
||||
img = vae.decode(img)
|
||||
|
||||
if exists(stop_at_unet_number) and stop_at_unet_number == unet_number:
|
||||
break
|
||||
|
||||
return img
|
||||
|
||||
def forward(
|
||||
@@ -1483,9 +1789,9 @@ class Decoder(BaseGaussianDiffusion):
|
||||
if not exists(image_embed):
|
||||
image_embed, _ = self.clip.embed_image(image)
|
||||
|
||||
text_encodings = None
|
||||
text_encodings = text_mask = None
|
||||
if exists(text) and not exists(text_encodings):
|
||||
_, text_encodings = self.clip.embed_text(text)
|
||||
_, text_encodings, text_mask = self.clip.embed_text(text)
|
||||
|
||||
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
|
||||
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
|
||||
@@ -1500,7 +1806,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
if exists(lowres_cond_img):
|
||||
lowres_cond_img = vae.encode(lowres_cond_img)
|
||||
|
||||
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start)
|
||||
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start)
|
||||
|
||||
# main class
|
||||
|
||||
@@ -1550,4 +1856,3 @@ class DALLE2(nn.Module):
|
||||
return images[0]
|
||||
|
||||
return images
|
||||
|
||||
|
||||
29
dalle2_pytorch/optimizer.py
Normal file
29
dalle2_pytorch/optimizer.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from torch.optim import AdamW, Adam
|
||||
|
||||
def separate_weight_decayable_params(params):
|
||||
no_wd_params = set([param for param in params if param.ndim < 2])
|
||||
wd_params = set(params) - no_wd_params
|
||||
return wd_params, no_wd_params
|
||||
|
||||
def get_optimizer(
|
||||
params,
|
||||
lr = 3e-4,
|
||||
wd = 1e-2,
|
||||
betas = (0.9, 0.999),
|
||||
filter_by_requires_grad = False
|
||||
):
|
||||
if filter_by_requires_grad:
|
||||
params = list(filter(lambda t: t.requires_grad, params))
|
||||
|
||||
if wd == 0:
|
||||
return Adam(params, lr = lr, betas = betas)
|
||||
|
||||
params = set(params)
|
||||
wd_params, no_wd_params = separate_weight_decayable_params(params)
|
||||
|
||||
param_groups = [
|
||||
{'params': list(wd_params)},
|
||||
{'params': list(no_wd_params), 'weight_decay': 0},
|
||||
]
|
||||
|
||||
return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas)
|
||||
@@ -1,15 +1,43 @@
|
||||
import copy
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
|
||||
# image related normalizations
|
||||
# ddpms expect images to be in the range of -1 to 1
|
||||
from dalle2_pytorch.dalle2_pytorch import Decoder
|
||||
from dalle2_pytorch.optimizer import get_optimizer
|
||||
|
||||
def normalize_img(img):
|
||||
return img * 2 - 1
|
||||
# helper functions
|
||||
|
||||
def unnormalize_img(normed_img):
|
||||
return (normed_img + 1) * 0.5
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def cast_tuple(val, length = 1):
|
||||
return val if isinstance(val, tuple) else ((val,) * length)
|
||||
|
||||
def pick_and_pop(keys, d):
|
||||
values = list(map(lambda key: d.pop(key), keys))
|
||||
return dict(zip(keys, values))
|
||||
|
||||
def group_dict_by_key(cond, d):
|
||||
return_val = [dict(),dict()]
|
||||
for key in d.keys():
|
||||
match = bool(cond(key))
|
||||
ind = int(not match)
|
||||
return_val[ind][key] = d[key]
|
||||
return (*return_val,)
|
||||
|
||||
def string_begins_with(prefix, str):
|
||||
return str.startswith(prefix)
|
||||
|
||||
def group_by_key_prefix(prefix, d):
|
||||
return group_dict_by_key(partial(string_begins_with, prefix), d)
|
||||
|
||||
def groupby_prefix_and_trim(prefix, d):
|
||||
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
|
||||
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
|
||||
return kwargs_without_prefix, kwargs
|
||||
|
||||
# exponential moving average wrapper
|
||||
|
||||
@@ -18,16 +46,16 @@ class EMA(nn.Module):
|
||||
self,
|
||||
model,
|
||||
beta = 0.99,
|
||||
ema_update_after_step = 1000,
|
||||
ema_update_every = 10,
|
||||
update_after_step = 1000,
|
||||
update_every = 10,
|
||||
):
|
||||
super().__init__()
|
||||
self.beta = beta
|
||||
self.online_model = model
|
||||
self.ema_model = copy.deepcopy(model)
|
||||
|
||||
self.ema_update_after_step = ema_update_after_step # only start EMA after this step number, starting at 0
|
||||
self.ema_update_every = ema_update_every
|
||||
self.update_after_step = update_after_step # only start EMA after this step number, starting at 0
|
||||
self.update_every = update_every
|
||||
|
||||
self.register_buffer('initted', torch.Tensor([False]))
|
||||
self.register_buffer('step', torch.tensor([0.]))
|
||||
@@ -35,7 +63,7 @@ class EMA(nn.Module):
|
||||
def update(self):
|
||||
self.step += 1
|
||||
|
||||
if self.step <= self.ema_update_after_step or (self.step % self.ema_update_every) != 0:
|
||||
if self.step <= self.update_after_step or (self.step % self.update_every) != 0:
|
||||
return
|
||||
|
||||
if not self.initted:
|
||||
@@ -60,3 +88,112 @@ class EMA(nn.Module):
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.ema_model(*args, **kwargs)
|
||||
|
||||
# trainers
|
||||
|
||||
class DecoderTrainer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
decoder,
|
||||
use_ema = True,
|
||||
lr = 3e-4,
|
||||
wd = 1e-2,
|
||||
max_grad_norm = None,
|
||||
amp = False,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(decoder, Decoder)
|
||||
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
|
||||
|
||||
self.decoder = decoder
|
||||
self.num_unets = len(self.decoder.unets)
|
||||
|
||||
self.use_ema = use_ema
|
||||
|
||||
if use_ema:
|
||||
has_lazy_linear = any([type(module) == nn.LazyLinear for module in decoder.modules()])
|
||||
assert not has_lazy_linear, 'you must set the text_embed_dim on your u-nets if you plan on doing automatic exponential moving average'
|
||||
|
||||
self.ema_unets = nn.ModuleList([])
|
||||
|
||||
self.amp = amp
|
||||
|
||||
# be able to finely customize learning rate, weight decay
|
||||
# per unet
|
||||
|
||||
lr, wd = map(partial(cast_tuple, length = self.num_unets), (lr, wd))
|
||||
|
||||
for ind, (unet, unet_lr, unet_wd) in enumerate(zip(self.decoder.unets, lr, wd)):
|
||||
optimizer = get_optimizer(
|
||||
unet.parameters(),
|
||||
lr = unet_lr,
|
||||
wd = unet_wd,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
setattr(self, f'optim{ind}', optimizer) # cannot use pytorch ModuleList for some reason with optimizers
|
||||
|
||||
if self.use_ema:
|
||||
self.ema_unets.append(EMA(unet, **ema_kwargs))
|
||||
|
||||
scaler = GradScaler(enabled = amp)
|
||||
setattr(self, f'scaler{ind}', scaler)
|
||||
|
||||
# gradient clipping if needed
|
||||
|
||||
self.max_grad_norm = max_grad_norm
|
||||
|
||||
@property
|
||||
def unets(self):
|
||||
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
|
||||
|
||||
def scale(self, loss, *, unet_number):
|
||||
assert 1 <= unet_number <= self.num_unets
|
||||
index = unet_number - 1
|
||||
scaler = getattr(self, f'scaler{index}')
|
||||
return scaler.scale(loss)
|
||||
|
||||
def update(self, unet_number):
|
||||
assert 1 <= unet_number <= self.num_unets
|
||||
index = unet_number - 1
|
||||
unet = self.decoder.unets[index]
|
||||
|
||||
optimizer = getattr(self, f'optim{index}')
|
||||
scaler = getattr(self, f'scaler{index}')
|
||||
|
||||
if exists(self.max_grad_norm):
|
||||
scaler.unscale_(optimizer)
|
||||
nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
|
||||
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if self.use_ema:
|
||||
ema_unet = self.ema_unets[index]
|
||||
ema_unet.update()
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self, *args, **kwargs):
|
||||
if self.use_ema:
|
||||
trainable_unets = self.decoder.unets
|
||||
self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
|
||||
|
||||
output = self.decoder.sample(*args, **kwargs)
|
||||
|
||||
if self.use_ema:
|
||||
self.decoder.unets = trainable_unets # restore original training unets
|
||||
return output
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
*,
|
||||
unet_number,
|
||||
divisor = 1,
|
||||
**kwargs
|
||||
):
|
||||
with autocast(enabled = self.amp):
|
||||
loss = self.decoder(x, unet_number = unet_number, **kwargs)
|
||||
return self.scale(loss / divisor, unet_number = unet_number)
|
||||
|
||||
266
dalle2_pytorch/train_vqgan_vae.py
Normal file
266
dalle2_pytorch/train_vqgan_vae.py
Normal file
@@ -0,0 +1,266 @@
|
||||
from math import sqrt
|
||||
import copy
|
||||
from random import choice
|
||||
from pathlib import Path
|
||||
from shutil import rmtree
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from PIL import Image
|
||||
from torchvision.datasets import ImageFolder
|
||||
import torchvision.transforms as T
|
||||
from torch.utils.data import Dataset, DataLoader, random_split
|
||||
from torchvision.utils import make_grid, save_image
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from dalle2_pytorch.train import EMA
|
||||
from dalle2_pytorch.vqgan_vae import VQGanVAE
|
||||
from dalle2_pytorch.optimizer import get_optimizer
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def noop(*args, **kwargs):
|
||||
pass
|
||||
|
||||
def cycle(dl):
|
||||
while True:
|
||||
for data in dl:
|
||||
yield data
|
||||
|
||||
def cast_tuple(t):
|
||||
return t if isinstance(t, (tuple, list)) else (t,)
|
||||
|
||||
def yes_or_no(question):
|
||||
answer = input(f'{question} (y/n) ')
|
||||
return answer.lower() in ('yes', 'y')
|
||||
|
||||
def accum_log(log, new_logs):
|
||||
for key, new_value in new_logs.items():
|
||||
old_value = log.get(key, 0.)
|
||||
log[key] = old_value + new_value
|
||||
return log
|
||||
|
||||
# classes
|
||||
|
||||
class ImageDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
folder,
|
||||
image_size,
|
||||
exts = ['jpg', 'jpeg', 'png']
|
||||
):
|
||||
super().__init__()
|
||||
self.folder = folder
|
||||
self.image_size = image_size
|
||||
self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]
|
||||
|
||||
print(f'{len(self.paths)} training samples found at {folder}')
|
||||
|
||||
self.transform = T.Compose([
|
||||
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
|
||||
T.Resize(image_size),
|
||||
T.RandomHorizontalFlip(),
|
||||
T.CenterCrop(image_size),
|
||||
T.ToTensor()
|
||||
])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.paths)
|
||||
|
||||
def __getitem__(self, index):
|
||||
path = self.paths[index]
|
||||
img = Image.open(path)
|
||||
return self.transform(img)
|
||||
|
||||
# main trainer class
|
||||
|
||||
class VQGanVAETrainer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
vae,
|
||||
*,
|
||||
num_train_steps,
|
||||
lr,
|
||||
batch_size,
|
||||
folder,
|
||||
grad_accum_every,
|
||||
wd = 0.,
|
||||
save_results_every = 100,
|
||||
save_model_every = 1000,
|
||||
results_folder = './results',
|
||||
valid_frac = 0.05,
|
||||
random_split_seed = 42,
|
||||
ema_beta = 0.995,
|
||||
ema_update_after_step = 2000,
|
||||
ema_update_every = 10,
|
||||
apply_grad_penalty_every = 4,
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(vae, VQGanVAE), 'vae must be instance of VQGanVAE'
|
||||
image_size = vae.image_size
|
||||
|
||||
self.vae = vae
|
||||
self.ema_vae = EMA(vae, update_after_step = ema_update_after_step, update_every = ema_update_every)
|
||||
|
||||
self.register_buffer('steps', torch.Tensor([0]))
|
||||
|
||||
self.num_train_steps = num_train_steps
|
||||
self.batch_size = batch_size
|
||||
self.grad_accum_every = grad_accum_every
|
||||
|
||||
all_parameters = set(vae.parameters())
|
||||
discr_parameters = set(vae.discr.parameters())
|
||||
vae_parameters = all_parameters - discr_parameters
|
||||
|
||||
self.optim = get_optimizer(vae_parameters, lr = lr, wd = wd)
|
||||
self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd)
|
||||
|
||||
# create dataset
|
||||
|
||||
self.ds = ImageDataset(folder, image_size = image_size)
|
||||
|
||||
# split for validation
|
||||
|
||||
if valid_frac > 0:
|
||||
train_size = int((1 - valid_frac) * len(self.ds))
|
||||
valid_size = len(self.ds) - train_size
|
||||
self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
|
||||
print(f'training with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples')
|
||||
else:
|
||||
self.valid_ds = self.ds
|
||||
print(f'training with shared training and valid dataset of {len(self.ds)} samples')
|
||||
|
||||
# dataloader
|
||||
|
||||
self.dl = cycle(DataLoader(
|
||||
self.ds,
|
||||
batch_size = batch_size,
|
||||
shuffle = True
|
||||
))
|
||||
|
||||
self.valid_dl = cycle(DataLoader(
|
||||
self.valid_ds,
|
||||
batch_size = batch_size,
|
||||
shuffle = True
|
||||
))
|
||||
|
||||
self.save_model_every = save_model_every
|
||||
self.save_results_every = save_results_every
|
||||
|
||||
self.apply_grad_penalty_every = apply_grad_penalty_every
|
||||
|
||||
self.results_folder = Path(results_folder)
|
||||
|
||||
if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'):
|
||||
rmtree(str(self.results_folder))
|
||||
|
||||
self.results_folder.mkdir(parents = True, exist_ok = True)
|
||||
|
||||
def train_step(self):
|
||||
device = next(self.vae.parameters()).device
|
||||
steps = int(self.steps.item())
|
||||
apply_grad_penalty = not (steps % self.apply_grad_penalty_every)
|
||||
|
||||
self.vae.train()
|
||||
|
||||
# logs
|
||||
|
||||
logs = {}
|
||||
|
||||
# update vae (generator)
|
||||
|
||||
for _ in range(self.grad_accum_every):
|
||||
img = next(self.dl)
|
||||
img = img.to(device)
|
||||
|
||||
loss = self.vae(
|
||||
img,
|
||||
return_loss = True,
|
||||
apply_grad_penalty = apply_grad_penalty
|
||||
)
|
||||
|
||||
accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
|
||||
|
||||
(loss / self.grad_accum_every).backward()
|
||||
|
||||
self.optim.step()
|
||||
self.optim.zero_grad()
|
||||
|
||||
|
||||
# update discriminator
|
||||
|
||||
if exists(self.vae.discr):
|
||||
discr_loss = 0
|
||||
for _ in range(self.grad_accum_every):
|
||||
img = next(self.dl)
|
||||
img = img.to(device)
|
||||
|
||||
loss = self.vae(img, return_discr_loss = True)
|
||||
accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every})
|
||||
|
||||
(loss / self.grad_accum_every).backward()
|
||||
|
||||
self.discr_optim.step()
|
||||
self.discr_optim.zero_grad()
|
||||
|
||||
# log
|
||||
|
||||
print(f"{steps}: vae loss: {logs['loss']} - discr loss: {logs['discr_loss']}")
|
||||
|
||||
# update exponential moving averaged generator
|
||||
|
||||
self.ema_vae.update()
|
||||
|
||||
# sample results every so often
|
||||
|
||||
if not (steps % self.save_results_every):
|
||||
for model, filename in ((self.ema_vae.ema_model, f'{steps}.ema'), (self.vae, str(steps))):
|
||||
model.eval()
|
||||
|
||||
imgs = next(self.dl)
|
||||
imgs = imgs.to(device)
|
||||
|
||||
recons = model(imgs)
|
||||
nrows = int(sqrt(self.batch_size))
|
||||
|
||||
imgs_and_recons = torch.stack((imgs, recons), dim = 0)
|
||||
imgs_and_recons = rearrange(imgs_and_recons, 'r b ... -> (b r) ...')
|
||||
|
||||
imgs_and_recons = imgs_and_recons.detach().cpu().float().clamp(0., 1.)
|
||||
grid = make_grid(imgs_and_recons, nrow = 2, normalize = True, value_range = (0, 1))
|
||||
|
||||
logs['reconstructions'] = grid
|
||||
|
||||
save_image(grid, str(self.results_folder / f'{filename}.png'))
|
||||
|
||||
print(f'{steps}: saving to {str(self.results_folder)}')
|
||||
|
||||
# save model every so often
|
||||
|
||||
if not (steps % self.save_model_every):
|
||||
state_dict = self.vae.state_dict()
|
||||
model_path = str(self.results_folder / f'vae.{steps}.pt')
|
||||
torch.save(state_dict, model_path)
|
||||
|
||||
ema_state_dict = self.ema_vae.state_dict()
|
||||
model_path = str(self.results_folder / f'vae.{steps}.ema.pt')
|
||||
torch.save(ema_state_dict, model_path)
|
||||
|
||||
print(f'{steps}: saving model to {str(self.results_folder)}')
|
||||
|
||||
self.steps += 1
|
||||
return logs
|
||||
|
||||
def train(self, log_fn = noop):
|
||||
device = next(self.vae.parameters()).device
|
||||
|
||||
while self.steps < self.num_train_steps:
|
||||
logs = self.train_step()
|
||||
log_fn(logs)
|
||||
|
||||
print('training complete')
|
||||
@@ -285,6 +285,10 @@ class ResnetEncDec(nn.Module):
|
||||
def get_encoded_fmap_size(self, image_size):
|
||||
return image_size // (2 ** self.layers)
|
||||
|
||||
@property
|
||||
def last_dec_layer(self):
|
||||
return self.decoders[-1].weight
|
||||
|
||||
def encode(self, x):
|
||||
for enc in self.encoders:
|
||||
x = enc(x)
|
||||
@@ -327,6 +331,112 @@ class ResBlock(nn.Module):
|
||||
def forward(self, x):
|
||||
return self.net(x) + x
|
||||
|
||||
# convnext enc dec
|
||||
|
||||
class ChanLayerNorm(nn.Module):
|
||||
def __init__(self, dim, eps = 1e-5):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (var + self.eps).sqrt() * self.g
|
||||
|
||||
class ConvNext(nn.Module):
|
||||
def __init__(self, dim, mult = 4, kernel_size = 3, ds_kernel_size = 7):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(dim, dim, ds_kernel_size, padding = ds_kernel_size // 2, groups = dim),
|
||||
ChanLayerNorm(dim),
|
||||
nn.Conv2d(dim, inner_dim, kernel_size, padding = kernel_size // 2),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(inner_dim, dim, kernel_size, padding = kernel_size // 2)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x) + x
|
||||
|
||||
class ConvNextEncDec(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
*,
|
||||
channels = 3,
|
||||
layers = 4,
|
||||
layer_mults = None,
|
||||
num_blocks = 1,
|
||||
first_conv_kernel_size = 5,
|
||||
use_attn = True,
|
||||
attn_dim_head = 64,
|
||||
attn_heads = 8,
|
||||
attn_dropout = 0.,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.layers = layers
|
||||
|
||||
self.encoders = MList([])
|
||||
self.decoders = MList([])
|
||||
|
||||
layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
|
||||
assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'
|
||||
|
||||
layer_dims = [dim * mult for mult in layer_mults]
|
||||
dims = (dim, *layer_dims)
|
||||
|
||||
self.encoded_dim = dims[-1]
|
||||
|
||||
dim_pairs = zip(dims[:-1], dims[1:])
|
||||
|
||||
append = lambda arr, t: arr.append(t)
|
||||
prepend = lambda arr, t: arr.insert(0, t)
|
||||
|
||||
if not isinstance(num_blocks, tuple):
|
||||
num_blocks = (*((0,) * (layers - 1)), num_blocks)
|
||||
|
||||
if not isinstance(use_attn, tuple):
|
||||
use_attn = (*((False,) * (layers - 1)), use_attn)
|
||||
|
||||
assert len(num_blocks) == layers, 'number of blocks config must be equal to number of layers'
|
||||
assert len(use_attn) == layers
|
||||
|
||||
for layer_index, (dim_in, dim_out), layer_num_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_blocks, use_attn):
|
||||
append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
|
||||
prepend(self.decoders, nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu()))
|
||||
|
||||
if layer_use_attn:
|
||||
prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
|
||||
|
||||
for _ in range(layer_num_blocks):
|
||||
append(self.encoders, ConvNext(dim_out))
|
||||
prepend(self.decoders, ConvNext(dim_out))
|
||||
|
||||
if layer_use_attn:
|
||||
append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
|
||||
|
||||
prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
|
||||
append(self.decoders, nn.Conv2d(dim, channels, 1))
|
||||
|
||||
def get_encoded_fmap_size(self, image_size):
|
||||
return image_size // (2 ** self.layers)
|
||||
|
||||
@property
|
||||
def last_dec_layer(self):
|
||||
return self.decoders[-1].weight
|
||||
|
||||
def encode(self, x):
|
||||
for enc in self.encoders:
|
||||
x = enc(x)
|
||||
return x
|
||||
|
||||
def decode(self, x):
|
||||
for dec in self.decoders:
|
||||
x = dec(x)
|
||||
return x
|
||||
|
||||
# vqgan attention layer
|
||||
|
||||
class VQGanAttention(nn.Module):
|
||||
@@ -504,6 +614,10 @@ class ViTEncDec(nn.Module):
|
||||
def get_encoded_fmap_size(self, image_size):
|
||||
return image_size // self.patch_size
|
||||
|
||||
@property
|
||||
def last_dec_layer(self):
|
||||
return self.decoder[-3][-1].weight
|
||||
|
||||
def encode(self, x):
|
||||
return self.encoder(x)
|
||||
|
||||
@@ -568,6 +682,8 @@ class VQGanVAE(nn.Module):
|
||||
enc_dec_klass = ResnetEncDec
|
||||
elif vae_type == 'vit':
|
||||
enc_dec_klass = ViTEncDec
|
||||
elif vae_type == 'convnext':
|
||||
enc_dec_klass = ConvNextEncDec
|
||||
else:
|
||||
raise ValueError(f'{vae_type} not valid')
|
||||
|
||||
@@ -739,7 +855,7 @@ class VQGanVAE(nn.Module):
|
||||
|
||||
# calculate adaptive weight
|
||||
|
||||
last_dec_layer = self.decoders[-1].weight
|
||||
last_dec_layer = self.enc_dec.last_dec_layer
|
||||
|
||||
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2)
|
||||
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)
|
||||
|
||||
7
setup.py
7
setup.py
@@ -10,7 +10,7 @@ setup(
|
||||
'dream = dalle2_pytorch.cli:dream'
|
||||
],
|
||||
},
|
||||
version = '0.0.65',
|
||||
version = '0.0.95',
|
||||
license='MIT',
|
||||
description = 'DALL-E 2',
|
||||
author = 'Phil Wang',
|
||||
@@ -23,15 +23,18 @@ setup(
|
||||
],
|
||||
install_requires=[
|
||||
'click',
|
||||
'clip-anytorch',
|
||||
'einops>=0.4',
|
||||
'einops-exts>=0.0.3',
|
||||
'embedding-reader',
|
||||
'kornia>=0.5.4',
|
||||
'pillow',
|
||||
'torch>=1.10',
|
||||
'torchvision',
|
||||
'tqdm',
|
||||
'vector-quantize-pytorch',
|
||||
'x-clip>=0.4.4',
|
||||
'webdataset',
|
||||
'x-clip>=0.5.1',
|
||||
'youtokentome'
|
||||
],
|
||||
classifiers=[
|
||||
|
||||
291
train_diffusion_prior.py
Normal file
291
train_diffusion_prior.py
Normal file
@@ -0,0 +1,291 @@
|
||||
import os
|
||||
import math
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from embedding_reader import EmbeddingReader
|
||||
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
|
||||
from dalle2_pytorch.optimizer import get_optimizer
|
||||
from torch.cuda.amp import autocast,GradScaler
|
||||
|
||||
import time
|
||||
from tqdm import tqdm
|
||||
|
||||
import wandb
|
||||
os.environ["WANDB_SILENT"] = "true"
|
||||
NUM_TEST_EMBEDDINGS = 100 # for cosine similarity reporting during training
|
||||
REPORT_METRICS_EVERY = 100 # for cosine similarity and other metric reporting during training
|
||||
|
||||
|
||||
def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_type,phase="Validation"):
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
total_loss = 0.
|
||||
total_samples = 0.
|
||||
|
||||
for emb_images, emb_text in zip(image_reader(batch_size=batch_size, start=start, end=end),
|
||||
text_reader(batch_size=batch_size, start=start, end=end)):
|
||||
|
||||
emb_images_tensor = torch.tensor(emb_images[0]).to(device)
|
||||
emb_text_tensor = torch.tensor(emb_text[0]).to(device)
|
||||
|
||||
batches = emb_images_tensor.shape[0]
|
||||
|
||||
loss = model(text_embed = emb_text_tensor, image_embed = emb_images_tensor)
|
||||
|
||||
total_loss += loss.item() * batches
|
||||
total_samples += batches
|
||||
|
||||
avg_loss = (total_loss / total_samples)
|
||||
wandb.log({f'{phase} {loss_type}': avg_loss})
|
||||
|
||||
def save_model(save_path, state_dict):
|
||||
# Saving State Dict
|
||||
print("====================================== Saving checkpoint ======================================")
|
||||
torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
|
||||
|
||||
def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,val_set_size,NUM_TEST_EMBEDDINGS,device):
|
||||
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
|
||||
|
||||
tstart = train_set_size+val_set_size
|
||||
tend = train_set_size+val_set_size+NUM_TEST_EMBEDDINGS
|
||||
|
||||
for embt, embi in zip(text_reader(batch_size = NUM_TEST_EMBEDDINGS, start=tstart, end = tend),image_reader(batch_size = NUM_TEST_EMBEDDINGS, start=tstart, end = tend)):
|
||||
text_embed = torch.tensor(embt[0]).to(device)
|
||||
text_embed = text_embed / text_embed.norm(dim=1, keepdim=True)
|
||||
test_text_cond = dict(text_embed = text_embed)
|
||||
|
||||
test_image_embeddings = torch.tensor(embi[0]).to(device)
|
||||
test_image_embeddings = test_image_embeddings / test_image_embeddings.norm(dim=1, keepdim=True)
|
||||
|
||||
predicted_image_embeddings = diffusion_prior.p_sample_loop((NUM_TEST_EMBEDDINGS, 768), text_cond = test_text_cond)
|
||||
predicted_image_embeddings = predicted_image_embeddings / predicted_image_embeddings.norm(dim=1, keepdim=True)
|
||||
|
||||
original_similarity = cos(text_embed,test_image_embeddings).cpu().numpy()
|
||||
predicted_similarity = cos(text_embed,predicted_image_embeddings).cpu().numpy()
|
||||
|
||||
wandb.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity)})
|
||||
wandb.log({"CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity)})
|
||||
|
||||
return np.mean(predicted_similarity - original_similarity)
|
||||
|
||||
|
||||
|
||||
def train(image_embed_dim,
|
||||
image_embed_url,
|
||||
text_embed_url,
|
||||
batch_size,
|
||||
train_percent,
|
||||
val_percent,
|
||||
test_percent,
|
||||
num_epochs,
|
||||
dp_loss_type,
|
||||
clip,
|
||||
dp_condition_on_text_encodings,
|
||||
dp_timesteps,
|
||||
dp_l2norm_output,
|
||||
dp_normformer,
|
||||
dp_cond_drop_prob,
|
||||
dpn_depth,
|
||||
dpn_dim_head,
|
||||
dpn_heads,
|
||||
save_interval,
|
||||
save_path,
|
||||
device,
|
||||
learning_rate=0.001,
|
||||
max_grad_norm=0.5,
|
||||
weight_decay=0.01,
|
||||
amp=False):
|
||||
|
||||
# DiffusionPriorNetwork
|
||||
prior_network = DiffusionPriorNetwork(
|
||||
dim = image_embed_dim,
|
||||
depth = dpn_depth,
|
||||
dim_head = dpn_dim_head,
|
||||
heads = dpn_heads,
|
||||
normformer = dp_normformer,
|
||||
l2norm_output = dp_l2norm_output).to(device)
|
||||
|
||||
# DiffusionPrior with text embeddings and image embeddings pre-computed
|
||||
diffusion_prior = DiffusionPrior(
|
||||
net = prior_network,
|
||||
clip = clip,
|
||||
image_embed_dim = image_embed_dim,
|
||||
timesteps = dp_timesteps,
|
||||
cond_drop_prob = dp_cond_drop_prob,
|
||||
loss_type = dp_loss_type,
|
||||
condition_on_text_encodings = dp_condition_on_text_encodings).to(device)
|
||||
|
||||
# Get image and text embeddings from the servers
|
||||
print("==============Downloading embeddings - image and text====================")
|
||||
image_reader = EmbeddingReader(embeddings_folder=image_embed_url, file_format="npy")
|
||||
text_reader = EmbeddingReader(embeddings_folder=text_embed_url, file_format="npy")
|
||||
num_data_points = text_reader.count
|
||||
|
||||
# Create save_path if it doesn't exist
|
||||
if not os.path.exists(save_path):
|
||||
os.makedirs(save_path)
|
||||
|
||||
### Training code ###
|
||||
scaler = GradScaler(enabled=amp)
|
||||
optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
|
||||
epochs = num_epochs
|
||||
|
||||
step = 0
|
||||
t = time.time()
|
||||
|
||||
train_set_size = int(train_percent*num_data_points)
|
||||
val_set_size = int(val_percent*num_data_points)
|
||||
|
||||
for _ in range(epochs):
|
||||
diffusion_prior.train()
|
||||
|
||||
for emb_images,emb_text in zip(image_reader(batch_size=batch_size, start=0, end=train_set_size),
|
||||
text_reader(batch_size=batch_size, start=0, end=train_set_size)):
|
||||
emb_images_tensor = torch.tensor(emb_images[0]).to(device)
|
||||
emb_text_tensor = torch.tensor(emb_text[0]).to(device)
|
||||
|
||||
with autocast(enabled=amp):
|
||||
loss = diffusion_prior(text_embed = emb_text_tensor,image_embed = emb_images_tensor)
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
# Samples per second
|
||||
step+=1
|
||||
samples_per_sec = batch_size*step/(time.time()-t)
|
||||
# Save checkpoint every save_interval minutes
|
||||
if(int(time.time()-t) >= 60*save_interval):
|
||||
t = time.time()
|
||||
|
||||
save_model(
|
||||
save_path,
|
||||
dict(model=diffusion_prior.state_dict(), optimizer=optimizer.state_dict(), scaler=scaler.state_dict()))
|
||||
|
||||
# Log to wandb
|
||||
wandb.log({"Training loss": loss.item(),
|
||||
"Steps": step,
|
||||
"Samples per second": samples_per_sec})
|
||||
# Log cosineSim(text_embed,predicted_image_embed) - cosineSim(text_embed,image_embed)
|
||||
# Use NUM_TEST_EMBEDDINGS samples from the test set each time
|
||||
# Get embeddings from the most recently saved model
|
||||
if(step % REPORT_METRICS_EVERY) == 0:
|
||||
diff_cosine_sim = report_cosine_sims(diffusion_prior,
|
||||
image_reader,
|
||||
text_reader,
|
||||
train_set_size,
|
||||
val_set_size,
|
||||
NUM_TEST_EMBEDDINGS,
|
||||
device)
|
||||
wandb.log({"Cosine similarity difference": diff_cosine_sim})
|
||||
|
||||
scaler.unscale_(optimizer)
|
||||
nn.utils.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm)
|
||||
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
optimizer.zero_grad()
|
||||
|
||||
### Evaluate model(validation run) ###
|
||||
start = train_set_size
|
||||
end=start+val_set_size
|
||||
eval_model(diffusion_prior,device,image_reader,text_reader,start,end,batch_size,dp_loss_type,phase="Validation")
|
||||
|
||||
### Test run ###
|
||||
test_set_size = int(test_percent*train_set_size)
|
||||
start=train_set_size+val_set_size
|
||||
end=num_data_points
|
||||
eval_model(diffusion_prior,device,image_reader,text_reader,start,end,batch_size,dp_loss_type,phase="Test")
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
# Logging
|
||||
parser.add_argument("--wandb-entity", type=str, default="laion")
|
||||
parser.add_argument("--wandb-project", type=str, default="diffusion-prior")
|
||||
parser.add_argument("--wandb-name", type=str, default="laion-dprior")
|
||||
parser.add_argument("--wandb-dataset", type=str, default="LAION-5B")
|
||||
parser.add_argument("--wandb-arch", type=str, default="DiffusionPrior")
|
||||
# URLs for embeddings
|
||||
parser.add_argument("--image-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
|
||||
parser.add_argument("--text-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
|
||||
# Hyperparameters
|
||||
parser.add_argument("--learning-rate", type=float, default=1.1e-4)
|
||||
parser.add_argument("--weight-decay", type=float, default=6.02e-2)
|
||||
parser.add_argument("--max-grad-norm", type=float, default=0.5)
|
||||
parser.add_argument("--batch-size", type=int, default=10**4)
|
||||
parser.add_argument("--num-epochs", type=int, default=5)
|
||||
# Image embed dimension
|
||||
parser.add_argument("--image-embed-dim", type=int, default=768)
|
||||
# Train-test split
|
||||
parser.add_argument("--train-percent", type=float, default=0.7)
|
||||
parser.add_argument("--val-percent", type=float, default=0.2)
|
||||
parser.add_argument("--test-percent", type=float, default=0.1)
|
||||
# LAION training(pre-computed embeddings)
|
||||
# DiffusionPriorNetwork(dpn) parameters
|
||||
parser.add_argument("--dpn-depth", type=int, default=6)
|
||||
parser.add_argument("--dpn-dim-head", type=int, default=64)
|
||||
parser.add_argument("--dpn-heads", type=int, default=8)
|
||||
# DiffusionPrior(dp) parameters
|
||||
parser.add_argument("--dp-condition-on-text-encodings", type=bool, default=False)
|
||||
parser.add_argument("--dp-timesteps", type=int, default=100)
|
||||
parser.add_argument("--dp-l2norm-output", type=bool, default=False)
|
||||
parser.add_argument("--dp-normformer", type=bool, default=False)
|
||||
parser.add_argument("--dp-cond-drop-prob", type=float, default=0.1)
|
||||
parser.add_argument("--dp-loss-type", type=str, default="l2")
|
||||
parser.add_argument("--clip", type=str, default=None)
|
||||
parser.add_argument("--amp", type=bool, default=False)
|
||||
# Model checkpointing interval(minutes)
|
||||
parser.add_argument("--save-interval", type=int, default=30)
|
||||
parser.add_argument("--save-path", type=str, default="./diffusion_prior_checkpoints")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print("Setting up wandb logging... Please wait...")
|
||||
|
||||
wandb.init(
|
||||
entity=args.wandb_entity,
|
||||
project=args.wandb_project,
|
||||
config={
|
||||
"learning_rate": args.learning_rate,
|
||||
"architecture": args.wandb_arch,
|
||||
"dataset": args.wandb_dataset,
|
||||
"epochs": args.num_epochs,
|
||||
})
|
||||
|
||||
print("wandb logging setup done!")
|
||||
# Obtain the utilized device.
|
||||
|
||||
has_cuda = torch.cuda.is_available()
|
||||
if has_cuda:
|
||||
device = torch.device("cuda:0")
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
# Training loop
|
||||
train(args.image_embed_dim,
|
||||
args.image_embed_url,
|
||||
args.text_embed_url,
|
||||
args.batch_size,
|
||||
args.train_percent,
|
||||
args.val_percent,
|
||||
args.test_percent,
|
||||
args.num_epochs,
|
||||
args.dp_loss_type,
|
||||
args.clip,
|
||||
args.dp_condition_on_text_encodings,
|
||||
args.dp_timesteps,
|
||||
args.dp_l2norm_output,
|
||||
args.dp_normformer,
|
||||
args.dp_cond_drop_prob,
|
||||
args.dpn_depth,
|
||||
args.dpn_dim_head,
|
||||
args.dpn_heads,
|
||||
args.save_interval,
|
||||
args.save_path,
|
||||
device,
|
||||
args.learning_rate,
|
||||
args.max_grad_norm,
|
||||
args.weight_decay,
|
||||
args.amp)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user