mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
60 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
430961cb97 | ||
|
|
721f9687c1 | ||
|
|
e0524a6aff | ||
|
|
c85e0d5c35 | ||
|
|
db0642c4cd | ||
|
|
bb86ab2404 | ||
|
|
ae056dd67c | ||
|
|
033d6b0ce8 | ||
|
|
c7ea8748db | ||
|
|
13382885d9 | ||
|
|
c3d4a7ffe4 | ||
|
|
164d9be444 | ||
|
|
5562ec6be2 | ||
|
|
89ff04cfe2 | ||
|
|
f4016f6302 | ||
|
|
1212f7058d | ||
|
|
dab106d4e5 | ||
|
|
bb151ca6b1 | ||
|
|
4a59dea4cf | ||
|
|
ecf9e8027d | ||
|
|
36c5079bd7 | ||
|
|
4a4c7ac9e6 | ||
|
|
fad7481479 | ||
|
|
123658d082 | ||
|
|
11d4e11f10 | ||
|
|
99778e12de | ||
|
|
0f0011caf0 | ||
|
|
7b7a62044a | ||
|
|
156fe5ed9f | ||
|
|
5ec34bebe1 | ||
|
|
8eaacf1ac1 | ||
|
|
e66c7b0249 | ||
|
|
f7cd4a0992 | ||
|
|
68e7d2f241 | ||
|
|
74f222596a | ||
|
|
aa6772dcff | ||
|
|
71d0c4edae | ||
|
|
f7eee09d8b | ||
|
|
89de5af63e | ||
|
|
4ec6d0ba81 | ||
|
|
aee92dba4a | ||
|
|
b0cd5f24b6 | ||
|
|
b494ed81d4 | ||
|
|
ff3474f05c | ||
|
|
d5293f19f1 | ||
|
|
e697183849 | ||
|
|
591d37e266 | ||
|
|
d1f02e8f49 | ||
|
|
9faab59b23 | ||
|
|
5d27029e98 | ||
|
|
3115fa17b3 | ||
|
|
124d8577c8 | ||
|
|
2db0c9794c | ||
|
|
2277b47ffd | ||
|
|
28b58e568c | ||
|
|
924455d97d | ||
|
|
6021945fc8 | ||
|
|
6f76652d11 | ||
|
|
3dda2570ed | ||
|
|
2f3c02dba8 |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -1,3 +1,6 @@
|
||||
# default experiment tracker data
|
||||
.tracker-data/
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
|
||||
154
README.md
154
README.md
@@ -14,6 +14,16 @@ Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord
|
||||
|
||||
There was enough interest for a <a href="https://github.com/lucidrains/dalle2-jax">Jax version</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
|
||||
|
||||
## Status
|
||||
|
||||
- A research group has used the code in this repository to train a functional diffusion prior for their CLIP generations. Will share their work once they release their preprint. This, and <a href="https://github.com/crowsonkb">Katherine's</a> own experiments, validate OpenAI's finding that the extra prior increases variety of generations.
|
||||
|
||||
- Decoder is now verified working for unconditional generation on my experimental setup for Oxford flowers. 2 researchers have also confirmed Decoder is working for them.
|
||||
|
||||
<img src="./samples/oxford.png" width="600px" />
|
||||
|
||||
*ongoing at 21k steps*
|
||||
|
||||
## Install
|
||||
|
||||
```bash
|
||||
@@ -508,7 +518,7 @@ To use a pretrained OpenAI CLIP, simply import `OpenAIClipAdapter` and pass it i
|
||||
import torch
|
||||
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, OpenAIClipAdapter
|
||||
|
||||
# openai pretrained clip - defaults to ViT/B-32
|
||||
# openai pretrained clip - defaults to ViT-B/32
|
||||
|
||||
clip = OpenAIClipAdapter()
|
||||
|
||||
@@ -706,7 +716,7 @@ mock_image_embed = torch.randn(1, 512).cuda()
|
||||
images = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)
|
||||
```
|
||||
|
||||
## Training wrapper (wip)
|
||||
## Training wrapper
|
||||
|
||||
### Decoder Training
|
||||
|
||||
@@ -732,8 +742,8 @@ clip = CLIP(
|
||||
|
||||
# mock data
|
||||
|
||||
text = torch.randint(0, 49408, (4, 256)).cuda()
|
||||
images = torch.randn(4, 3, 256, 256).cuda()
|
||||
text = torch.randint(0, 49408, (32, 256)).cuda()
|
||||
images = torch.randn(32, 3, 256, 256).cuda()
|
||||
|
||||
# decoder (with unet)
|
||||
|
||||
@@ -774,8 +784,12 @@ decoder_trainer = DecoderTrainer(
|
||||
)
|
||||
|
||||
for unet_number in (1, 2):
|
||||
loss = decoder_trainer(images, text = text, unet_number = unet_number) # use the decoder_trainer forward
|
||||
loss.backward()
|
||||
loss = decoder_trainer(
|
||||
images,
|
||||
text = text,
|
||||
unet_number = unet_number, # which unet to train on
|
||||
max_batch_size = 4 # gradient accumulation - this sets the maximum batch size in which to do forward and backwards pass - for this example 32 / 4 == 8 times
|
||||
)
|
||||
|
||||
decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
|
||||
|
||||
@@ -810,8 +824,8 @@ clip = CLIP(
|
||||
|
||||
# mock data
|
||||
|
||||
text = torch.randint(0, 49408, (4, 256)).cuda()
|
||||
images = torch.randn(4, 3, 256, 256).cuda()
|
||||
text = torch.randint(0, 49408, (512, 256)).cuda()
|
||||
images = torch.randn(512, 3, 256, 256).cuda()
|
||||
|
||||
# prior networks (with transformer)
|
||||
|
||||
@@ -838,16 +852,70 @@ diffusion_prior_trainer = DiffusionPriorTrainer(
|
||||
ema_update_every = 10,
|
||||
)
|
||||
|
||||
loss = diffusion_prior_trainer(text, images)
|
||||
loss.backward()
|
||||
loss = diffusion_prior_trainer(text, images, max_batch_size = 4)
|
||||
diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior
|
||||
|
||||
# after much of the above three lines in a loop
|
||||
# you can sample from the exponential moving average of the diffusion prior identically to how you do so for DiffusionPrior
|
||||
|
||||
image_embeds = diffusion_prior_trainer.sample(text) # (4, 512) - exponential moving averaged image embeddings
|
||||
image_embeds = diffusion_prior_trainer.sample(text, max_batch_size = 4) # (512, 512) - exponential moving averaged image embeddings
|
||||
```
|
||||
|
||||
## Bonus
|
||||
|
||||
### Unconditional Training
|
||||
|
||||
The repository also contains the means to train unconditional DDPM model, or even cascading DDPMs. You simply have to set `unconditional = True` in the `Decoder`
|
||||
|
||||
ex.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from dalle2_pytorch import Unet, Decoder, DecoderTrainer
|
||||
|
||||
# unet for the cascading ddpm
|
||||
|
||||
unet1 = Unet(
|
||||
dim = 128,
|
||||
dim_mults=(1, 2, 4, 8)
|
||||
).cuda()
|
||||
|
||||
unet2 = Unet(
|
||||
dim = 32,
|
||||
dim_mults = (1, 2, 4, 8, 16)
|
||||
).cuda()
|
||||
|
||||
# decoder, which contains the unets
|
||||
|
||||
decoder = Decoder(
|
||||
unet = (unet1, unet2),
|
||||
image_sizes = (256, 512), # first unet up to 256px, then second to 512px
|
||||
timesteps = 1000,
|
||||
unconditional = True
|
||||
).cuda()
|
||||
|
||||
# decoder trainer
|
||||
|
||||
decoder_trainer = DecoderTrainer(decoder)
|
||||
|
||||
# images (get a lot of this)
|
||||
|
||||
images = torch.randn(1, 3, 512, 512).cuda()
|
||||
|
||||
# feed images into decoder
|
||||
|
||||
for i in (1, 2):
|
||||
loss = decoder_trainer(images, unet_number = i)
|
||||
decoder_trainer.update(unet_number = i)
|
||||
|
||||
# do the above for many many many many images
|
||||
# then it will learn to generate images
|
||||
|
||||
images = decoder_trainer.sample(batch_size = 36, max_batch_size = 4) # (36, 3, 512, 512)
|
||||
```
|
||||
|
||||
## Dataloaders
|
||||
|
||||
### Decoder Dataloaders
|
||||
|
||||
In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.
|
||||
@@ -892,14 +960,14 @@ dataset = ImageEmbeddingDataset(
|
||||
)
|
||||
```
|
||||
|
||||
## Scripts
|
||||
### Scripts (wip)
|
||||
|
||||
### Using the `train_diffusion_prior.py` script
|
||||
#### `train_diffusion_prior.py`
|
||||
|
||||
This script allows training the DiffusionPrior on pre-computed text and image embeddings. The working example below elucidates this process.
|
||||
Please note that the script internally passes text_embed and image_embed to the DiffusionPrior, unlike the example below.
|
||||
|
||||
### Usage
|
||||
#### Usage
|
||||
|
||||
```bash
|
||||
$ python train_diffusion_prior.py
|
||||
@@ -907,58 +975,49 @@ $ python train_diffusion_prior.py
|
||||
|
||||
The most significant parameters for the script are as follows:
|
||||
|
||||
--image-embed-url, default = "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
|
||||
- `image-embed-url`, default = `"https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/"`
|
||||
|
||||
--text-embed-url, default = "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
|
||||
- `text-embed-url`, default = `"https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/"`
|
||||
|
||||
--image-embed-dim, default=768 - 768 corresponds to the ViT iL/14 embedding size,change it to what your chosen ViT generates
|
||||
- `image-embed-dim`, default = `768` - 768 corresponds to the ViT iL/14 embedding size,change it to what your chosen ViT generates
|
||||
|
||||
--learning-rate, default=1.1e-4
|
||||
- `learning-rate`, default = `1.1e-4`
|
||||
|
||||
--weight-decay, default=6.02e-2
|
||||
- `weight-decay`, default = `6.02e-2`
|
||||
|
||||
--max-grad-norm, default=0.5
|
||||
- `max-grad-norm`, default = `0.5`
|
||||
|
||||
--batch-size, default=10 ** 4
|
||||
- `batch-size`, default = `10 ** 4`
|
||||
|
||||
--num-epochs, default=5
|
||||
- `num-epochs`, default = `5`
|
||||
|
||||
--clip, default=None # Signals the prior to use pre-computed embeddings
|
||||
- `clip`, default = `None` # Signals the prior to use pre-computed embeddings
|
||||
|
||||
### Sample wandb run log
|
||||
|
||||
Please find a sample wandb run log at : https://wandb.ai/laion/diffusion-prior/runs/1blxu24j
|
||||
|
||||
### Loading and saving the Diffusion Prior model
|
||||
#### Loading and Saving the DiffusionPrior model
|
||||
|
||||
Two methods are provided, load_diffusion_model and save_diffusion_model, the names being self-explanatory.
|
||||
|
||||
## from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model
|
||||
```python
|
||||
from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model
|
||||
```
|
||||
|
||||
##### Loading
|
||||
|
||||
load_diffusion_model(dprior_path, device)
|
||||
|
||||
dprior_path : path to saved model(.pth)
|
||||
|
||||
device : the cuda device you're running on
|
||||
|
||||
##### Saving
|
||||
|
||||
save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim)
|
||||
|
||||
save_path : path to save at
|
||||
|
||||
model : object of Diffusion_Prior
|
||||
|
||||
optimizer : optimizer object - see train_diffusion_prior.py for how to create one.
|
||||
|
||||
e.g: optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
|
||||
|
||||
scaler : a GradScaler object.
|
||||
|
||||
e.g: scaler = GradScaler(enabled=amp)
|
||||
|
||||
config : config object created in train_diffusion_prior.py - see file for example.
|
||||
|
||||
image_embed_dim - the dimension of the image_embedding
|
||||
|
||||
e.g: 768
|
||||
|
||||
## CLI (wip)
|
||||
@@ -1004,6 +1063,7 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
- [x] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images
|
||||
- [x] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
|
||||
- [x] cross embed layers for downsampling, as an option
|
||||
- [x] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
|
||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
|
||||
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
|
||||
- [ ] train on a toy task, offer in colab
|
||||
@@ -1011,12 +1071,15 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
|
||||
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
|
||||
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
|
||||
- [ ] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
|
||||
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
|
||||
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
|
||||
- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
|
||||
- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
|
||||
- [ ] decoder needs one day worth of refactor for tech debt
|
||||
- [ ] allow for unet to be able to condition non-cross attention style as well
|
||||
- [ ] for all model classes with hyperparameters that changes the network architecture, make it requirement that they must expose a config property, and write a simple function that asserts that it restores the object correctly
|
||||
- [ ] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
|
||||
- [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89
|
||||
|
||||
## Citations
|
||||
|
||||
@@ -1105,4 +1168,13 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{ho2021cascaded,
|
||||
title = {Cascaded Diffusion Models for High Fidelity Image Generation},
|
||||
author = {Ho, Jonathan and Saharia, Chitwan and Chan, William and Fleet, David J and Norouzi, Mohammad and Salimans, Tim},
|
||||
journal = {arXiv preprint arXiv:2106.15282},
|
||||
year = {2021}
|
||||
}
|
||||
```
|
||||
|
||||
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
|
||||
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
|
||||
from dalle2_pytorch.train import DecoderTrainer, DiffusionPriorTrainer
|
||||
from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
|
||||
|
||||
from dalle2_pytorch.vqgan_vae import VQGanVAE
|
||||
from x_clip import CLIP
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import math
|
||||
from tqdm import tqdm
|
||||
from inspect import isfunction
|
||||
from functools import partial
|
||||
from functools import partial, wraps
|
||||
from contextlib import contextmanager
|
||||
from collections import namedtuple
|
||||
from pathlib import Path
|
||||
@@ -33,6 +33,10 @@ from rotary_embedding_torch import RotaryEmbedding
|
||||
from x_clip import CLIP
|
||||
from coca_pytorch import CoCa
|
||||
|
||||
# constants
|
||||
|
||||
NAT = 1. / math.log(2.)
|
||||
|
||||
# helper functions
|
||||
|
||||
def exists(val):
|
||||
@@ -41,6 +45,14 @@ def exists(val):
|
||||
def identity(t, *args, **kwargs):
|
||||
return t
|
||||
|
||||
def maybe(fn):
|
||||
@wraps(fn)
|
||||
def inner(x):
|
||||
if not exists(x):
|
||||
return x
|
||||
return fn(x)
|
||||
return inner
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
@@ -49,6 +61,9 @@ def default(val, d):
|
||||
def cast_tuple(val, length = 1):
|
||||
return val if isinstance(val, tuple) else ((val,) * length)
|
||||
|
||||
def module_device(module):
|
||||
return next(module.parameters()).device
|
||||
|
||||
@contextmanager
|
||||
def null_context(*args, **kwargs):
|
||||
yield
|
||||
@@ -91,6 +106,9 @@ def freeze_model_and_make_eval_(model):
|
||||
|
||||
# tensor helpers
|
||||
|
||||
def log(t, eps = 1e-12):
|
||||
return torch.log(t.clamp(min = eps))
|
||||
|
||||
def l2norm(t):
|
||||
return F.normalize(t, dim = -1)
|
||||
|
||||
@@ -107,10 +125,10 @@ def resize_image_to(image, target_image_size):
|
||||
# ddpms expect images to be in the range of -1 to 1
|
||||
# but CLIP may otherwise
|
||||
|
||||
def normalize_img(img):
|
||||
def normalize_neg_one_to_one(img):
|
||||
return img * 2 - 1
|
||||
|
||||
def unnormalize_img(normed_img):
|
||||
def unnormalize_zero_to_one(normed_img):
|
||||
return (normed_img + 1) * 0.5
|
||||
|
||||
# clip related adapters
|
||||
@@ -271,7 +289,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
|
||||
def embed_image(self, image):
|
||||
assert not self.cleared
|
||||
image = resize_image_to(image, self.image_size)
|
||||
image = self.clip_normalize(unnormalize_img(image))
|
||||
image = self.clip_normalize(image)
|
||||
image_embed = self.clip.encode_image(image)
|
||||
return EmbeddedImage(l2norm(image_embed.float()), None)
|
||||
|
||||
@@ -297,13 +315,43 @@ def noise_like(shape, device, repeat=False):
|
||||
noise = lambda: torch.randn(shape, device=device)
|
||||
return repeat_noise() if repeat else noise()
|
||||
|
||||
def meanflat(x):
|
||||
return x.mean(dim = tuple(range(1, len(x.shape))))
|
||||
|
||||
def normal_kl(mean1, logvar1, mean2, logvar2):
|
||||
return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2))
|
||||
|
||||
def approx_standard_normal_cdf(x):
|
||||
return 0.5 * (1.0 + torch.tanh(((2.0 / math.pi) ** 0.5) * (x + 0.044715 * (x ** 3))))
|
||||
|
||||
def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
|
||||
assert x.shape == means.shape == log_scales.shape
|
||||
|
||||
centered_x = x - means
|
||||
inv_stdv = torch.exp(-log_scales)
|
||||
plus_in = inv_stdv * (centered_x + 1. / 255.)
|
||||
cdf_plus = approx_standard_normal_cdf(plus_in)
|
||||
min_in = inv_stdv * (centered_x - 1. / 255.)
|
||||
cdf_min = approx_standard_normal_cdf(min_in)
|
||||
log_cdf_plus = log(cdf_plus)
|
||||
log_one_minus_cdf_min = log(1. - cdf_min)
|
||||
cdf_delta = cdf_plus - cdf_min
|
||||
|
||||
log_probs = torch.where(x < -thres,
|
||||
log_cdf_plus,
|
||||
torch.where(x > thres,
|
||||
log_one_minus_cdf_min,
|
||||
log(cdf_delta)))
|
||||
|
||||
return log_probs
|
||||
|
||||
def cosine_beta_schedule(timesteps, s = 0.008):
|
||||
"""
|
||||
cosine schedule
|
||||
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
|
||||
"""
|
||||
steps = timesteps + 1
|
||||
x = torch.linspace(0, timesteps, steps)
|
||||
x = torch.linspace(0, timesteps, steps, dtype = torch.float64)
|
||||
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
|
||||
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
||||
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
||||
@@ -314,21 +362,21 @@ def linear_beta_schedule(timesteps):
|
||||
scale = 1000 / timesteps
|
||||
beta_start = scale * 0.0001
|
||||
beta_end = scale * 0.02
|
||||
return torch.linspace(beta_start, beta_end, timesteps)
|
||||
return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)
|
||||
|
||||
|
||||
def quadratic_beta_schedule(timesteps):
|
||||
scale = 1000 / timesteps
|
||||
beta_start = scale * 0.0001
|
||||
beta_end = scale * 0.02
|
||||
return torch.linspace(beta_start**2, beta_end**2, timesteps) ** 2
|
||||
return torch.linspace(beta_start**2, beta_end**2, timesteps, dtype = torch.float64) ** 2
|
||||
|
||||
|
||||
def sigmoid_beta_schedule(timesteps):
|
||||
scale = 1000 / timesteps
|
||||
beta_start = scale * 0.0001
|
||||
beta_end = scale * 0.02
|
||||
betas = torch.linspace(-6, 6, timesteps)
|
||||
betas = torch.linspace(-6, 6, timesteps, dtype = torch.float64)
|
||||
return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
|
||||
|
||||
|
||||
@@ -368,17 +416,21 @@ class BaseGaussianDiffusion(nn.Module):
|
||||
self.loss_type = loss_type
|
||||
self.loss_fn = loss_fn
|
||||
|
||||
self.register_buffer('betas', betas)
|
||||
self.register_buffer('alphas_cumprod', alphas_cumprod)
|
||||
self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
|
||||
# register buffer helper function to cast double back to float
|
||||
|
||||
register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
|
||||
|
||||
register_buffer('betas', betas)
|
||||
register_buffer('alphas_cumprod', alphas_cumprod)
|
||||
register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
|
||||
self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
|
||||
self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
|
||||
self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
|
||||
self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
|
||||
self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
|
||||
register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
|
||||
register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
|
||||
register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
|
||||
register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
|
||||
register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
|
||||
|
||||
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
||||
|
||||
@@ -386,19 +438,13 @@ class BaseGaussianDiffusion(nn.Module):
|
||||
|
||||
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
||||
|
||||
self.register_buffer('posterior_variance', posterior_variance)
|
||||
register_buffer('posterior_variance', posterior_variance)
|
||||
|
||||
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
||||
|
||||
self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
|
||||
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
||||
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
|
||||
|
||||
def q_mean_variance(self, x_start, t):
|
||||
mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
||||
variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
|
||||
log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
|
||||
return mean, variance, log_variance
|
||||
register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
|
||||
register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
||||
register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
|
||||
|
||||
def q_posterior(self, x_start, x_t, t):
|
||||
posterior_mean = (
|
||||
@@ -571,7 +617,6 @@ class Attention(nn.Module):
|
||||
heads = 8,
|
||||
dropout = 0.,
|
||||
causal = False,
|
||||
post_norm = False,
|
||||
rotary_emb = None
|
||||
):
|
||||
super().__init__()
|
||||
@@ -581,7 +626,6 @@ class Attention(nn.Module):
|
||||
|
||||
self.causal = causal
|
||||
self.norm = LayerNorm(dim)
|
||||
self.post_norm = LayerNorm(dim) # sandwich norm from Coqview paper + Normformer
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
|
||||
@@ -592,7 +636,7 @@ class Attention(nn.Module):
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim, bias = False),
|
||||
LayerNorm(dim) if post_norm else nn.Identity()
|
||||
LayerNorm(dim)
|
||||
)
|
||||
|
||||
def forward(self, x, mask = None, attn_bias = None):
|
||||
@@ -649,8 +693,7 @@ class Attention(nn.Module):
|
||||
out = einsum('b h i j, b j d -> b h i d', attn, v)
|
||||
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
out = self.to_out(out)
|
||||
return self.post_norm(out)
|
||||
return self.to_out(out)
|
||||
|
||||
class CausalTransformer(nn.Module):
|
||||
def __init__(
|
||||
@@ -676,7 +719,7 @@ class CausalTransformer(nn.Module):
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer, rotary_emb = rotary_emb),
|
||||
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb),
|
||||
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
|
||||
]))
|
||||
|
||||
@@ -754,7 +797,7 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
text_embed,
|
||||
text_encodings = None,
|
||||
mask = None,
|
||||
cond_drop_prob = 0.2
|
||||
cond_drop_prob = 0.
|
||||
):
|
||||
batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
|
||||
|
||||
@@ -827,7 +870,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
image_channels = 3,
|
||||
timesteps = 1000,
|
||||
cond_drop_prob = 0.,
|
||||
loss_type = "l1",
|
||||
loss_type = "l2",
|
||||
predict_x_start = True,
|
||||
beta_schedule = "cosine",
|
||||
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
|
||||
@@ -861,6 +904,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
self.channels = default(image_channels, lambda: clip.image_channels)
|
||||
|
||||
self.cond_drop_prob = cond_drop_prob
|
||||
self.can_classifier_guidance = cond_drop_prob > 0.
|
||||
self.condition_on_text_encodings = condition_on_text_encodings
|
||||
|
||||
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
|
||||
@@ -874,8 +918,10 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
self.training_clamp_l2norm = training_clamp_l2norm
|
||||
self.init_image_embed_l2norm = init_image_embed_l2norm
|
||||
|
||||
def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
|
||||
pred = self.net(x, t, **text_cond)
|
||||
def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.):
|
||||
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
|
||||
|
||||
pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, **text_cond)
|
||||
|
||||
if self.predict_x_start:
|
||||
x_recon = pred
|
||||
@@ -893,17 +939,17 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
||||
return model_mean, posterior_variance, posterior_log_variance
|
||||
|
||||
@torch.inference_mode()
|
||||
def p_sample(self, x, t, text_cond = None, clip_denoised = True, repeat_noise = False):
|
||||
@torch.no_grad()
|
||||
def p_sample(self, x, t, text_cond = None, clip_denoised = True, repeat_noise = False, cond_scale = 1.):
|
||||
b, *_, device = *x.shape, x.device
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised)
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
|
||||
noise = noise_like(x.shape, device, repeat_noise)
|
||||
# no noise when t == 0
|
||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
|
||||
@torch.inference_mode()
|
||||
def p_sample_loop(self, shape, text_cond):
|
||||
@torch.no_grad()
|
||||
def p_sample_loop(self, shape, text_cond, cond_scale = 1.):
|
||||
device = self.betas.device
|
||||
|
||||
b = shape[0]
|
||||
@@ -914,7 +960,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
|
||||
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
|
||||
times = torch.full((b,), i, device = device, dtype = torch.long)
|
||||
image_embed = self.p_sample(image_embed, times, text_cond = text_cond)
|
||||
image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)
|
||||
|
||||
return image_embed
|
||||
|
||||
@@ -938,21 +984,21 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
loss = self.loss_fn(pred, target)
|
||||
return loss
|
||||
|
||||
@torch.inference_mode()
|
||||
@torch.no_grad()
|
||||
@eval_decorator
|
||||
def sample_batch_size(self, batch_size, text_cond):
|
||||
def sample_batch_size(self, batch_size, text_cond, cond_scale = 1.):
|
||||
device = self.betas.device
|
||||
shape = (batch_size, self.image_embed_dim)
|
||||
|
||||
img = torch.randn(shape, device = device)
|
||||
|
||||
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
|
||||
img = self.p_sample(img, torch.full((batch_size,), i, device = device, dtype = torch.long), text_cond = text_cond)
|
||||
img = self.p_sample(img, torch.full((batch_size,), i, device = device, dtype = torch.long), text_cond = text_cond, cond_scale = cond_scale)
|
||||
return img
|
||||
|
||||
@torch.inference_mode()
|
||||
@torch.no_grad()
|
||||
@eval_decorator
|
||||
def sample(self, text, num_samples_per_batch = 2):
|
||||
def sample(self, text, num_samples_per_batch = 2, cond_scale = 1.):
|
||||
# in the paper, what they did was
|
||||
# sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP
|
||||
text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch)
|
||||
@@ -967,7 +1013,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
if self.condition_on_text_encodings:
|
||||
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
|
||||
|
||||
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
|
||||
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond, cond_scale = cond_scale)
|
||||
|
||||
# retrieve original unscaled image embed
|
||||
|
||||
@@ -1123,6 +1169,7 @@ class CrossAttention(nn.Module):
|
||||
dim_head = 64,
|
||||
heads = 8,
|
||||
dropout = 0.,
|
||||
norm_context = False
|
||||
):
|
||||
super().__init__()
|
||||
self.scale = dim_head ** -0.5
|
||||
@@ -1132,13 +1179,17 @@ class CrossAttention(nn.Module):
|
||||
context_dim = default(context_dim, dim)
|
||||
|
||||
self.norm = LayerNorm(dim)
|
||||
self.norm_context = LayerNorm(context_dim)
|
||||
self.norm_context = LayerNorm(context_dim) if norm_context else nn.Identity()
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
||||
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim, bias = False),
|
||||
LayerNorm(dim)
|
||||
)
|
||||
|
||||
def forward(self, x, context, mask = None):
|
||||
b, n, device = *x.shape[:2], x.device
|
||||
@@ -1260,7 +1311,7 @@ class Unet(nn.Module):
|
||||
self,
|
||||
dim,
|
||||
*,
|
||||
image_embed_dim,
|
||||
image_embed_dim = None,
|
||||
text_embed_dim = None,
|
||||
cond_dim = None,
|
||||
num_image_tokens = 4,
|
||||
@@ -1268,6 +1319,7 @@ class Unet(nn.Module):
|
||||
out_dim = None,
|
||||
dim_mults=(1, 2, 4, 8),
|
||||
channels = 3,
|
||||
channels_out = None,
|
||||
attn_dim_head = 32,
|
||||
attn_heads = 16,
|
||||
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
|
||||
@@ -1298,6 +1350,7 @@ class Unet(nn.Module):
|
||||
# determine dimensions
|
||||
|
||||
self.channels = channels
|
||||
self.channels_out = default(channels_out, channels)
|
||||
|
||||
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
|
||||
init_dim = default(init_dim, dim // 3 * 2)
|
||||
@@ -1330,14 +1383,18 @@ class Unet(nn.Module):
|
||||
self.image_to_cond = nn.Sequential(
|
||||
nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
|
||||
Rearrange('b (n d) -> b n d', n = num_image_tokens)
|
||||
) if image_embed_dim != cond_dim else nn.Identity()
|
||||
) if cond_on_image_embeds and image_embed_dim != cond_dim else nn.Identity()
|
||||
|
||||
self.norm_cond = nn.LayerNorm(cond_dim)
|
||||
self.norm_mid_cond = nn.LayerNorm(cond_dim)
|
||||
|
||||
# text encoding conditioning (optional)
|
||||
|
||||
self.text_to_cond = None
|
||||
|
||||
if cond_on_text_encodings:
|
||||
self.text_to_cond = nn.LazyLinear(cond_dim) if not exists(text_embed_dim) else nn.Linear(text_embed_dim, cond_dim)
|
||||
assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text_encodings is True'
|
||||
self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)
|
||||
|
||||
# finer control over whether to condition on image embeddings and text encodings
|
||||
# so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
|
||||
@@ -1403,11 +1460,9 @@ class Unet(nn.Module):
|
||||
Upsample(dim_in)
|
||||
]))
|
||||
|
||||
out_dim = default(out_dim, channels)
|
||||
|
||||
self.final_conv = nn.Sequential(
|
||||
ResnetBlock(dim, dim, groups = resnet_groups[0]),
|
||||
nn.Conv2d(dim, out_dim, 1)
|
||||
nn.Conv2d(dim, self.channels_out, 1)
|
||||
)
|
||||
|
||||
# if the current settings for the unet are not correct
|
||||
@@ -1417,13 +1472,25 @@ class Unet(nn.Module):
|
||||
*,
|
||||
lowres_cond,
|
||||
channels,
|
||||
channels_out,
|
||||
cond_on_image_embeds,
|
||||
cond_on_text_encodings
|
||||
):
|
||||
if lowres_cond == self.lowres_cond and channels == self.channels and cond_on_image_embeds == self.cond_on_image_embeds and cond_on_text_encodings == self.cond_on_text_encodings:
|
||||
if lowres_cond == self.lowres_cond and \
|
||||
channels == self.channels and \
|
||||
cond_on_image_embeds == self.cond_on_image_embeds and \
|
||||
cond_on_text_encodings == self.cond_on_text_encodings and \
|
||||
channels_out == self.channels_out:
|
||||
return self
|
||||
|
||||
updated_kwargs = {'lowres_cond': lowres_cond, 'channels': channels, 'cond_on_image_embeds': cond_on_image_embeds, 'cond_on_text_encodings': cond_on_text_encodings}
|
||||
updated_kwargs = dict(
|
||||
lowres_cond = lowres_cond,
|
||||
channels = channels,
|
||||
channels_out = channels_out,
|
||||
cond_on_image_embeds = cond_on_image_embeds,
|
||||
cond_on_text_encodings = cond_on_text_encodings
|
||||
)
|
||||
|
||||
return self.__class__(**{**self._locals, **updated_kwargs})
|
||||
|
||||
def forward_with_cond_scale(
|
||||
@@ -1488,11 +1555,12 @@ class Unet(nn.Module):
|
||||
|
||||
if self.cond_on_image_embeds:
|
||||
image_tokens = self.image_to_cond(image_embed)
|
||||
null_image_embed = self.null_image_embed.to(image_tokens.dtype) # for some reason pytorch AMP not working
|
||||
|
||||
image_tokens = torch.where(
|
||||
image_keep_mask,
|
||||
image_tokens,
|
||||
self.null_image_embed
|
||||
null_image_embed
|
||||
)
|
||||
|
||||
# take care of text encodings (optional)
|
||||
@@ -1516,10 +1584,12 @@ class Unet(nn.Module):
|
||||
text_mask = rearrange(text_mask, 'b n -> b n 1')
|
||||
text_keep_mask = text_mask & text_keep_mask
|
||||
|
||||
null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working
|
||||
|
||||
text_tokens = torch.where(
|
||||
text_keep_mask,
|
||||
text_tokens,
|
||||
self.null_text_embed
|
||||
null_text_embed
|
||||
)
|
||||
|
||||
# main conditioning tokens (c)
|
||||
@@ -1534,6 +1604,11 @@ class Unet(nn.Module):
|
||||
|
||||
mid_c = c if not exists(text_tokens) else torch.cat((c, text_tokens), dim = -2)
|
||||
|
||||
# normalize conditioning tokens
|
||||
|
||||
c = self.norm_cond(c)
|
||||
mid_c = self.norm_mid_cond(mid_c)
|
||||
|
||||
# go through the layers of the unet, down and up
|
||||
|
||||
hiddens = []
|
||||
@@ -1607,7 +1682,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
timesteps = 1000,
|
||||
image_cond_drop_prob = 0.1,
|
||||
text_cond_drop_prob = 0.5,
|
||||
loss_type = 'l1',
|
||||
loss_type = 'l2',
|
||||
beta_schedule = 'cosine',
|
||||
predict_x_start = False,
|
||||
predict_x_start_for_latent_diffusion = False,
|
||||
@@ -1620,7 +1695,10 @@ class Decoder(BaseGaussianDiffusion):
|
||||
clip_denoised = True,
|
||||
clip_x_start = True,
|
||||
clip_adapter_overrides = dict(),
|
||||
unconditional = False
|
||||
learned_variance = True,
|
||||
vb_loss_weight = 0.001,
|
||||
unconditional = False,
|
||||
auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
|
||||
):
|
||||
super().__init__(
|
||||
beta_schedule = beta_schedule,
|
||||
@@ -1631,7 +1709,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
self.unconditional = unconditional
|
||||
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
|
||||
|
||||
assert exists(clip) ^ exists(image_size), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
|
||||
assert self.unconditional or (exists(clip) ^ exists(image_size)), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
|
||||
|
||||
self.clip = None
|
||||
if exists(clip):
|
||||
@@ -1658,10 +1736,18 @@ class Decoder(BaseGaussianDiffusion):
|
||||
unets = cast_tuple(unet)
|
||||
vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))
|
||||
|
||||
# whether to use learned variance, defaults to True for the first unet in the cascade, as in paper
|
||||
|
||||
learned_variance = pad_tuple_to_length(cast_tuple(learned_variance), len(unets), fillvalue = False)
|
||||
self.learned_variance = learned_variance
|
||||
self.vb_loss_weight = vb_loss_weight
|
||||
|
||||
# construct unets and vaes
|
||||
|
||||
self.unets = nn.ModuleList([])
|
||||
self.vaes = nn.ModuleList([])
|
||||
|
||||
for ind, (one_unet, one_vae) in enumerate(zip(unets, vaes)):
|
||||
for ind, (one_unet, one_vae, one_unet_learned_var) in enumerate(zip(unets, vaes, learned_variance)):
|
||||
assert isinstance(one_unet, Unet)
|
||||
assert isinstance(one_vae, (VQGanVAE, NullVQGanVAE))
|
||||
|
||||
@@ -1669,12 +1755,14 @@ class Decoder(BaseGaussianDiffusion):
|
||||
latent_dim = one_vae.encoded_dim if exists(one_vae) else None
|
||||
|
||||
unet_channels = default(latent_dim, self.channels)
|
||||
unet_channels_out = unet_channels * (1 if not one_unet_learned_var else 2)
|
||||
|
||||
one_unet = one_unet.cast_model_parameters(
|
||||
lowres_cond = not is_first,
|
||||
cond_on_image_embeds = is_first and not unconditional,
|
||||
cond_on_text_encodings = one_unet.cond_on_text_encodings and not unconditional,
|
||||
channels = unet_channels
|
||||
channels = unet_channels,
|
||||
channels_out = unet_channels_out
|
||||
)
|
||||
|
||||
self.unets.append(one_unet)
|
||||
@@ -1712,12 +1800,17 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
self.image_cond_drop_prob = image_cond_drop_prob
|
||||
self.text_cond_drop_prob = text_cond_drop_prob
|
||||
self.can_classifier_guidance = image_cond_drop_prob > 0. or text_cond_drop_prob > 0.
|
||||
|
||||
# whether to clip when sampling
|
||||
|
||||
self.clip_denoised = clip_denoised
|
||||
self.clip_x_start = clip_x_start
|
||||
|
||||
# normalize and unnormalize image functions
|
||||
self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
|
||||
self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
|
||||
|
||||
def get_unet(self, unet_number):
|
||||
assert 0 < unet_number <= len(self.unets)
|
||||
index = unet_number - 1
|
||||
@@ -1731,14 +1824,23 @@ class Decoder(BaseGaussianDiffusion):
|
||||
unet = self.get_unet(unet_number)
|
||||
|
||||
self.cuda()
|
||||
|
||||
devices = [module_device(unet) for unet in self.unets]
|
||||
self.unets.cpu()
|
||||
|
||||
unet.cuda()
|
||||
yield
|
||||
unet.cpu()
|
||||
|
||||
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
|
||||
pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
|
||||
yield
|
||||
|
||||
for unet, device in zip(self.unets, devices):
|
||||
unet.to(device)
|
||||
|
||||
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
|
||||
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
|
||||
|
||||
pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img))
|
||||
|
||||
if learned_variance:
|
||||
pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1)
|
||||
|
||||
if predict_x_start:
|
||||
x_recon = pred
|
||||
@@ -1749,24 +1851,39 @@ class Decoder(BaseGaussianDiffusion):
|
||||
x_recon.clamp_(-1., 1.)
|
||||
|
||||
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
||||
|
||||
if learned_variance:
|
||||
# if learned variance, posterio variance and posterior log variance are predicted by the network
|
||||
# by an interpolation of the max and min log beta values
|
||||
# eq 15 - https://arxiv.org/abs/2102.09672
|
||||
min_log = extract(self.posterior_log_variance_clipped, t, x.shape)
|
||||
max_log = extract(torch.log(self.betas), t, x.shape)
|
||||
var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)
|
||||
|
||||
posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
|
||||
posterior_variance = posterior_log_variance.exp()
|
||||
|
||||
return model_mean, posterior_variance, posterior_log_variance
|
||||
|
||||
@torch.inference_mode()
|
||||
def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, clip_denoised = True, repeat_noise = False):
|
||||
@torch.no_grad()
|
||||
def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True, repeat_noise = False):
|
||||
b, *_, device = *x.shape, x.device
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start)
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, learned_variance = learned_variance)
|
||||
noise = noise_like(x.shape, device, repeat_noise)
|
||||
# no noise when t == 0
|
||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
|
||||
@torch.inference_mode()
|
||||
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
|
||||
@torch.no_grad()
|
||||
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
|
||||
device = self.betas.device
|
||||
|
||||
b = shape[0]
|
||||
img = torch.randn(shape, device = device)
|
||||
|
||||
if not is_latent_diffusion:
|
||||
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
|
||||
|
||||
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
|
||||
img = self.p_sample(
|
||||
unet,
|
||||
@@ -1778,17 +1895,27 @@ class Decoder(BaseGaussianDiffusion):
|
||||
cond_scale = cond_scale,
|
||||
lowres_cond_img = lowres_cond_img,
|
||||
predict_x_start = predict_x_start,
|
||||
learned_variance = learned_variance,
|
||||
clip_denoised = clip_denoised
|
||||
)
|
||||
|
||||
return img
|
||||
unnormalize_img = self.unnormalize_img(img)
|
||||
return unnormalize_img
|
||||
|
||||
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None):
|
||||
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
|
||||
# normalize to [-1, 1]
|
||||
|
||||
if not is_latent_diffusion:
|
||||
x_start = self.normalize_img(x_start)
|
||||
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
|
||||
|
||||
# get x_t
|
||||
|
||||
x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)
|
||||
|
||||
pred = unet(
|
||||
model_output = unet(
|
||||
x_noisy,
|
||||
times,
|
||||
image_embed = image_embed,
|
||||
@@ -1799,17 +1926,57 @@ class Decoder(BaseGaussianDiffusion):
|
||||
text_cond_drop_prob = self.text_cond_drop_prob,
|
||||
)
|
||||
|
||||
if learned_variance:
|
||||
pred, _ = model_output.chunk(2, dim = 1)
|
||||
else:
|
||||
pred = model_output
|
||||
|
||||
target = noise if not predict_x_start else x_start
|
||||
|
||||
loss = self.loss_fn(pred, target)
|
||||
return loss
|
||||
|
||||
@torch.inference_mode()
|
||||
if not learned_variance:
|
||||
# return simple loss if not using learned variance
|
||||
return loss
|
||||
|
||||
# most of the code below is transcribed from
|
||||
# https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/diffusion_utils_2.py
|
||||
# the Improved DDPM paper then further modified it so that the mean is detached (shown a couple lines before), and weighted to be smaller than the l1 or l2 "simple" loss
|
||||
# it is questionable whether this is really needed, looking at some of the figures in the paper, but may as well stay faithful to their implementation
|
||||
|
||||
# if learning the variance, also include the extra weight kl loss
|
||||
|
||||
true_mean, _, true_log_variance_clipped = self.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
|
||||
|
||||
# kl loss with detached model predicted mean, for stability reasons as in paper
|
||||
|
||||
detached_model_mean = model_mean.detach()
|
||||
|
||||
kl = normal_kl(true_mean, true_log_variance_clipped, detached_model_mean, model_log_variance)
|
||||
kl = meanflat(kl) * NAT
|
||||
|
||||
decoder_nll = -discretized_gaussian_log_likelihood(x_start, means = detached_model_mean, log_scales = 0.5 * model_log_variance)
|
||||
decoder_nll = meanflat(decoder_nll) * NAT
|
||||
|
||||
# at the first timestep return the decoder NLL, otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
|
||||
|
||||
vb_losses = torch.where(times == 0, decoder_nll, kl)
|
||||
|
||||
# weight the vb loss smaller, for stability, as in the paper (recommended 0.001)
|
||||
|
||||
vb_loss = vb_losses.mean() * self.vb_loss_weight
|
||||
|
||||
return loss + vb_loss
|
||||
|
||||
@torch.no_grad()
|
||||
@eval_decorator
|
||||
def sample(
|
||||
self,
|
||||
image_embed = None,
|
||||
text = None,
|
||||
text_mask = None,
|
||||
text_encodings = None,
|
||||
batch_size = 1,
|
||||
cond_scale = 1.,
|
||||
stop_at_unet_number = None
|
||||
@@ -1819,8 +1986,8 @@ class Decoder(BaseGaussianDiffusion):
|
||||
if not self.unconditional:
|
||||
batch_size = image_embed.shape[0]
|
||||
|
||||
text_encodings = text_mask = None
|
||||
if exists(text):
|
||||
if exists(text) and not exists(text_encodings) and not self.unconditional:
|
||||
assert exists(self.clip)
|
||||
_, text_encodings, text_mask = self.clip.embed_text(text)
|
||||
|
||||
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
|
||||
@@ -1829,7 +1996,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
img = None
|
||||
is_cuda = next(self.parameters()).is_cuda
|
||||
|
||||
for unet_number, unet, vae, channel, image_size, predict_x_start in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
|
||||
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance)):
|
||||
|
||||
context = self.one_unet_in_gpu(unet = unet) if is_cuda else null_context()
|
||||
|
||||
@@ -1844,8 +2011,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
image_size = vae.get_encoded_fmap_size(image_size)
|
||||
shape = (batch_size, vae.encoded_dim, image_size, image_size)
|
||||
|
||||
if exists(lowres_cond_img):
|
||||
lowres_cond_img = vae.encode(lowres_cond_img)
|
||||
lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
|
||||
|
||||
img = self.p_sample_loop(
|
||||
unet,
|
||||
@@ -1855,8 +2021,10 @@ class Decoder(BaseGaussianDiffusion):
|
||||
text_mask = text_mask,
|
||||
cond_scale = cond_scale,
|
||||
predict_x_start = predict_x_start,
|
||||
learned_variance = learned_variance,
|
||||
clip_denoised = not is_latent_diffusion,
|
||||
lowres_cond_img = lowres_cond_img
|
||||
lowres_cond_img = lowres_cond_img,
|
||||
is_latent_diffusion = is_latent_diffusion
|
||||
)
|
||||
|
||||
img = vae.decode(img)
|
||||
@@ -1872,6 +2040,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
text = None,
|
||||
image_embed = None,
|
||||
text_encodings = None,
|
||||
text_mask = None,
|
||||
unet_number = None
|
||||
):
|
||||
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
|
||||
@@ -1884,6 +2053,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
target_image_size = self.image_sizes[unet_index]
|
||||
predict_x_start = self.predict_x_start[unet_index]
|
||||
random_crop_size = self.random_crop_sizes[unet_index]
|
||||
learned_variance = self.learned_variance[unet_index]
|
||||
b, c, h, w, device, = *image.shape, image.device
|
||||
|
||||
check_shape(image, 'b c h w', c = self.channels)
|
||||
@@ -1891,12 +2061,11 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
|
||||
|
||||
if not exists(image_embed):
|
||||
if not exists(image_embed) and not self.unconditional:
|
||||
assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
|
||||
image_embed, _ = self.clip.embed_image(image)
|
||||
|
||||
text_encodings = text_mask = None
|
||||
if exists(text) and not exists(text_encodings):
|
||||
if exists(text) and not exists(text_encodings) and not self.unconditional:
|
||||
assert exists(self.clip), 'if you are passing in raw text, you need to supply `clip` to the decoder'
|
||||
_, text_encodings, text_mask = self.clip.embed_text(text)
|
||||
|
||||
@@ -1914,14 +2083,14 @@ class Decoder(BaseGaussianDiffusion):
|
||||
image = aug(image)
|
||||
lowres_cond_img = aug(lowres_cond_img, params = aug._params)
|
||||
|
||||
is_latent_diffusion = not isinstance(vae, NullVQGanVAE)
|
||||
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
image = vae.encode(image)
|
||||
lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
|
||||
|
||||
if exists(lowres_cond_img):
|
||||
lowres_cond_img = vae.encode(lowres_cond_img)
|
||||
|
||||
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start)
|
||||
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion)
|
||||
|
||||
# main class
|
||||
|
||||
@@ -1944,22 +2113,23 @@ class DALLE2(nn.Module):
|
||||
|
||||
self.to_pil = T.ToPILImage()
|
||||
|
||||
@torch.inference_mode()
|
||||
@torch.no_grad()
|
||||
@eval_decorator
|
||||
def forward(
|
||||
self,
|
||||
text,
|
||||
cond_scale = 1.,
|
||||
prior_cond_scale = 1.,
|
||||
return_pil_images = False
|
||||
):
|
||||
device = next(self.parameters()).device
|
||||
device = module_device(self)
|
||||
one_text = isinstance(text, str) or (not is_list_str(text) and text.shape[0] == 1)
|
||||
|
||||
if isinstance(text, str) or is_list_str(text):
|
||||
text = [text] if not isinstance(text, (list, tuple)) else text
|
||||
text = tokenizer.tokenize(text).to(device)
|
||||
|
||||
image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
|
||||
image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples, cond_scale = prior_cond_scale)
|
||||
|
||||
text_cond = text if self.decoder_need_text_cond else None
|
||||
images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale)
|
||||
@@ -1971,4 +2141,3 @@ class DALLE2(nn.Module):
|
||||
return images[0]
|
||||
|
||||
return images
|
||||
|
||||
|
||||
41
dalle2_pytorch/dataloaders/README.md
Normal file
41
dalle2_pytorch/dataloaders/README.md
Normal file
@@ -0,0 +1,41 @@
|
||||
## Dataloaders
|
||||
In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.
|
||||
|
||||
### Decoder: Image Embedding Dataset
|
||||
When training the decoder (and up samplers if training together) in isolation, you will need to load images and corresponding image embeddings. This dataset can read two similar types of datasets. First, it can read a [webdataset](https://github.com/webdataset/webdataset) that contains `.jpg` and `.npy` files in the `.tar`s that contain the images and associated image embeddings respectively. Alternatively, you can also specify a source for the embeddings outside of the webdataset. In this case, the path to the embeddings should contain `.npy` files with the same shard numbers as the webdataset and there should be a correspondence between the filename of the `.jpg` and the index of the embedding in the `.npy`. So, for example, `0001.tar` from the webdataset with image `00010509.jpg` (the first 4 digits are the shard number and the last 4 are the index) in it should be paralleled by a `img_emb_0001.npy` which contains a NumPy array with the embedding at index 509.
|
||||
|
||||
Generating a dataset of this type:
|
||||
1. Use [img2dataset](https://github.com/rom1504/img2dataset) to generate a webdataset.
|
||||
2. Use [clip-retrieval](https://github.com/rom1504/clip-retrieval) to convert the images to embeddings.
|
||||
3. Use [embedding-dataset-reordering](https://github.com/Veldrovive/embedding-dataset-reordering) to reorder the embeddings into the expected format.
|
||||
|
||||
Usage:
|
||||
```python
|
||||
from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embedding_dataloader
|
||||
|
||||
# Create a dataloader directly.
|
||||
dataloader = create_image_embedding_dataloader(
|
||||
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses braket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
|
||||
embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
|
||||
num_workers=4,
|
||||
batch_size=32,
|
||||
shard_width=4, # If a file in the webdataset shard 3 is named 0003039.jpg, we know the shard width is 4 and the last three digits are the index
|
||||
shuffle_num=200, # Does a shuffle of the data with a buffer size of 200
|
||||
shuffle_shards=True, # Shuffle the order the shards are read in
|
||||
resample_shards=False, # Sample shards with replacement. If true, an epoch will be infinite unless stopped manually
|
||||
)
|
||||
for img, emb in dataloader:
|
||||
print(img.shape) # torch.Size([32, 3, 256, 256])
|
||||
print(emb.shape) # torch.Size([32, 512])
|
||||
# Train decoder only as shown above
|
||||
|
||||
# Or create a dataset without a loader so you can configure it manually
|
||||
dataset = ImageEmbeddingDataset(
|
||||
urls="/path/or/url/to/webdataset/{0000..9999}.tar",
|
||||
embedding_folder_url="path/or/url/to/embeddings/folder",
|
||||
shard_width=4,
|
||||
shuffle_shards=True,
|
||||
resample=False
|
||||
)
|
||||
```
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
from dalle2_pytorch.dataloaders.decoder_loader import ImageEmbeddingDataset, create_image_embedding_dataloader
|
||||
from dalle2_pytorch.dataloaders.decoder_loader import ImageEmbeddingDataset, create_image_embedding_dataloader
|
||||
from dalle2_pytorch.dataloaders.embedding_wrapper import make_splits
|
||||
|
||||
@@ -3,6 +3,7 @@ import webdataset as wds
|
||||
import torch
|
||||
import numpy as np
|
||||
import fsspec
|
||||
import shutil
|
||||
|
||||
def get_shard(filename):
|
||||
"""
|
||||
@@ -20,7 +21,7 @@ def get_example_file(fs, path, file_format):
|
||||
"""
|
||||
return fs.glob(os.path.join(path, f"*.{file_format}"))[0]
|
||||
|
||||
def embedding_inserter(samples, embeddings_url, shard_width, handler=wds.handlers.reraise_exception):
|
||||
def embedding_inserter(samples, embeddings_url, index_width, handler=wds.handlers.reraise_exception):
|
||||
"""Given a datum of {"__key__": str, "__url__": str, ...} adds the cooresponding embedding and yields"""
|
||||
previous_tar_url = None
|
||||
current_embeddings = None
|
||||
@@ -50,8 +51,12 @@ def embedding_inserter(samples, embeddings_url, shard_width, handler=wds.handler
|
||||
previous_tar_url = tar_url
|
||||
current_embeddings = load_corresponding_embeds(tar_url)
|
||||
|
||||
embedding_index = int(key[shard_width:])
|
||||
sample["npy"] = current_embeddings[embedding_index]
|
||||
embedding_index = int(key[-index_width:])
|
||||
embedding = current_embeddings[embedding_index]
|
||||
# We need to check if this sample is nonzero. If it is, this embedding is not valid and we should continue to the next loop
|
||||
if torch.count_nonzero(embedding) == 0:
|
||||
raise RuntimeError(f"Webdataset had a sample, but no embedding was found. ImgShard: {key[:-index_width]} - Index: {key[-index_width:]}")
|
||||
sample["npy"] = embedding
|
||||
yield sample
|
||||
except Exception as exn: # From wds implementation
|
||||
if handler(exn):
|
||||
@@ -60,6 +65,28 @@ def embedding_inserter(samples, embeddings_url, shard_width, handler=wds.handler
|
||||
break
|
||||
insert_embedding = wds.filters.pipelinefilter(embedding_inserter)
|
||||
|
||||
def unassociated_shard_skipper(tarfiles, embeddings_url, handler=wds.handlers.reraise_exception):
|
||||
"""Finds if the is a corresponding embedding for the tarfile at { url: [URL] }"""
|
||||
embeddings_fs, embeddings_path = fsspec.core.url_to_fs(embeddings_url)
|
||||
embedding_files = embeddings_fs.ls(embeddings_path)
|
||||
get_embedding_shard = lambda embedding_file: int(embedding_file.split("_")[-1].split(".")[0])
|
||||
embedding_shards = set([get_embedding_shard(filename) for filename in embedding_files]) # Sets have O(1) check for member
|
||||
|
||||
get_tar_shard = lambda tar_file: int(tar_file.split("/")[-1].split(".")[0])
|
||||
for tarfile in tarfiles:
|
||||
try:
|
||||
webdataset_shard = get_tar_shard(tarfile["url"])
|
||||
# If this shard has an associated embeddings file, we pass it through. Otherwise we iterate until we do have one
|
||||
if webdataset_shard in embedding_shards:
|
||||
yield tarfile
|
||||
except Exception as exn: # From wds implementation
|
||||
if handler(exn):
|
||||
continue
|
||||
else:
|
||||
break
|
||||
|
||||
skip_unassociated_shards = wds.filters.pipelinefilter(unassociated_shard_skipper)
|
||||
|
||||
def verify_keys(samples, handler=wds.handlers.reraise_exception):
|
||||
"""
|
||||
Requires that both the image and embedding are present in the sample
|
||||
@@ -86,7 +113,9 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
|
||||
self,
|
||||
urls,
|
||||
embedding_folder_url=None,
|
||||
shard_width=None,
|
||||
index_width=None,
|
||||
img_preproc=None,
|
||||
extra_keys=[],
|
||||
handler=wds.handlers.reraise_exception,
|
||||
resample=False,
|
||||
shuffle_shards=True
|
||||
@@ -97,13 +126,31 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
|
||||
:param urls: A url pointing to the tar files of the webdataset formatted as /path/to/webdataset/{0000..9999}.tar
|
||||
:param embedding_folder_url: Required if webdataset does not contain embeddings. A url pointing to the npy files of the embeddings. Should have the same number of shards as the webdataset.
|
||||
Webdataset image keys should align with the index of the embedding. This means missing image indices must have a corresponding embedding of all zeros.
|
||||
:param shard_width: The number of digits in the shard number. This is used to align the embedding index with the image index.
|
||||
For example, if a file in the webdataset shard 3 is named 0003039.jpg, we know the shard with this 4 and the last three digits are the index.
|
||||
:param index_width: The number of digits in the index. This is used to align the embedding index with the image index.
|
||||
For example, if a file in the webdataset shard 3 is named 0003039.jpg, we know the shard is 4 digits and the last 3 digits are the index_width.
|
||||
:param img_preproc: This function is run on the img before it is batched and returned. Useful for data augmentation or converting to torch tensor.
|
||||
:param handler: A webdataset handler.
|
||||
:param resample: If true, resample webdataset shards with replacement. You need to set your own epoch size if this is true since it will resample infinitely.
|
||||
:param shuffle_shards: If true, shuffle the shards before resampling. This cannot be true if resample is true.
|
||||
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
keys = ["jpg", "npy"] + extra_keys
|
||||
self.key_map = {key: i for i, key in enumerate(keys)}
|
||||
self.resampling = resample
|
||||
self.img_preproc = img_preproc
|
||||
# If s3, check if s3fs is installed and s3cmd is installed and check if the data is piped instead of straight up
|
||||
if (isinstance(urls, str) and "s3:" in urls) or (isinstance(urls, list) and any(["s3:" in url for url in urls])):
|
||||
# Then this has an s3 link for the webdataset and we need extra packages
|
||||
if shutil.which("s3cmd") is None:
|
||||
raise RuntimeError("s3cmd is required for s3 webdataset")
|
||||
if "s3:" in embedding_folder_url:
|
||||
# Then the embeddings are being loaded from s3 and fsspec requires s3fs
|
||||
try:
|
||||
import s3fs
|
||||
except ImportError:
|
||||
raise RuntimeError("s3fs is required to load embeddings from s3")
|
||||
# Add the shardList and randomize or resample if requested
|
||||
if resample:
|
||||
assert not shuffle_shards, "Cannot both resample and shuffle"
|
||||
@@ -112,28 +159,43 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
|
||||
self.append(wds.SimpleShardList(urls))
|
||||
if shuffle_shards:
|
||||
self.append(wds.filters.shuffle(1000))
|
||||
|
||||
if embedding_folder_url is not None:
|
||||
# There may be webdataset shards that do not have a embedding shard associated with it. If we do not skip these, they would cause issues.
|
||||
self.append(skip_unassociated_shards(embeddings_url=embedding_folder_url, handler=handler))
|
||||
|
||||
self.append(wds.split_by_node)
|
||||
self.append(wds.split_by_worker)
|
||||
|
||||
self.append(wds.tarfile_to_samples(handler=handler))
|
||||
self.append(wds.decode("torchrgb"))
|
||||
self.append(wds.decode("pilrgb", handler=handler))
|
||||
if embedding_folder_url is not None:
|
||||
assert shard_width is not None, "Reading embeddings separately requires shard length to be given"
|
||||
self.append(insert_embedding(embeddings_url=embedding_folder_url, shard_width=shard_width, handler=handler))
|
||||
# Then we are loading embeddings for a remote source
|
||||
assert index_width is not None, "Reading embeddings separately requires index width length to be given"
|
||||
self.append(insert_embedding(embeddings_url=embedding_folder_url, index_width=index_width, handler=handler))
|
||||
self.append(verify_keys)
|
||||
self.append(wds.to_tuple("jpg", "npy"))
|
||||
# Apply preprocessing
|
||||
self.append(wds.map(self.preproc))
|
||||
self.append(wds.to_tuple(*keys))
|
||||
|
||||
def preproc(self, sample):
|
||||
"""Applies the preprocessing for images"""
|
||||
if self.img_preproc is not None:
|
||||
sample["jpg"] = self.img_preproc(sample["jpg"])
|
||||
return sample
|
||||
|
||||
def create_image_embedding_dataloader(
|
||||
tar_url,
|
||||
num_workers,
|
||||
batch_size,
|
||||
embeddings_url=None,
|
||||
shard_width=None,
|
||||
index_width=None,
|
||||
shuffle_num = None,
|
||||
shuffle_shards = True,
|
||||
resample_shards = False,
|
||||
handler=wds.handlers.warn_and_continue
|
||||
img_preproc=None,
|
||||
extra_keys=[],
|
||||
handler=wds.handlers.reraise_exception#warn_and_continue
|
||||
):
|
||||
"""
|
||||
Convenience function to create an image embedding dataseta and dataloader in one line
|
||||
@@ -143,8 +205,8 @@ def create_image_embedding_dataloader(
|
||||
:param batch_size: The batch size to use for the dataloader
|
||||
:param embeddings_url: Required if webdataset does not contain embeddings. A url pointing to the npy files of the embeddings. Should have the same number of shards as the webdataset.
|
||||
Webdataset image keys should align with the index of the embedding. This means missing image indices must have a corresponding embedding of all zeros.
|
||||
:param shard_width: The number of digits in the shard number. This is used to align the embedding index with the image index.
|
||||
For example, if a file in the webdataset shard 3 is named 0003039.jpg, we know the shard width is 4 and the last three digits are the index.
|
||||
:param index_width: The number of digits in the index. This is used to align the embedding index with the image index.
|
||||
For example, if a file in the webdataset shard 3 is named 0003039.jpg, we know the shard is 4 digits and the last 3 digits are the index_width.
|
||||
:param shuffle_num: If not None, shuffle the dataset with this size buffer after sampling.
|
||||
:param shuffle_shards: If true, shuffle the shards before sampling. This cannot be true if resample is true.
|
||||
:param resample_shards: If true, resample webdataset shards with replacement. You need to set your own epoch size if this is true since it will resample infinitely.
|
||||
@@ -153,9 +215,11 @@ def create_image_embedding_dataloader(
|
||||
ds = ImageEmbeddingDataset(
|
||||
tar_url,
|
||||
embeddings_url,
|
||||
shard_width=shard_width,
|
||||
index_width=index_width,
|
||||
shuffle_shards=shuffle_shards,
|
||||
resample=resample_shards,
|
||||
extra_keys=extra_keys,
|
||||
img_preproc=img_preproc,
|
||||
handler=handler
|
||||
)
|
||||
if shuffle_num is not None and shuffle_num > 0:
|
||||
|
||||
180
dalle2_pytorch/dataloaders/embedding_wrapper.py
Normal file
180
dalle2_pytorch/dataloaders/embedding_wrapper.py
Normal file
@@ -0,0 +1,180 @@
|
||||
from torch.utils.data import IterableDataset
|
||||
from torch import from_numpy
|
||||
from clip import tokenize
|
||||
from embedding_reader import EmbeddingReader
|
||||
|
||||
|
||||
class PriorEmbeddingLoader(IterableDataset):
|
||||
def __init__(
|
||||
self,
|
||||
text_conditioned: bool,
|
||||
batch_size: int,
|
||||
start: int,
|
||||
stop: int,
|
||||
image_reader,
|
||||
text_reader: EmbeddingReader = None,
|
||||
device: str = "cpu",
|
||||
) -> None:
|
||||
super(PriorEmbeddingLoader).__init__()
|
||||
|
||||
self.text_conditioned = text_conditioned
|
||||
|
||||
if not self.text_conditioned:
|
||||
self.text_reader = text_reader
|
||||
|
||||
self.image_reader = image_reader
|
||||
self.batch_size = batch_size
|
||||
self.start = start
|
||||
self.stop = stop
|
||||
self.device = device
|
||||
|
||||
def __iter__(self):
|
||||
self.n = 0
|
||||
loader_args = dict(
|
||||
batch_size=self.batch_size,
|
||||
start=self.start,
|
||||
end=self.stop,
|
||||
show_progress=False,
|
||||
)
|
||||
if self.text_conditioned:
|
||||
self.loader = self.image_reader(**loader_args)
|
||||
else:
|
||||
self.loader = zip(
|
||||
self.image_reader(**loader_args), self.text_reader(**loader_args)
|
||||
)
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
try:
|
||||
return self.get_sample()
|
||||
except StopIteration:
|
||||
raise StopIteration
|
||||
|
||||
def get_sample(self):
|
||||
"""
|
||||
pre-proocess data from either reader into a common format
|
||||
"""
|
||||
self.n += 1
|
||||
|
||||
if self.text_conditioned:
|
||||
image_embedding, caption = next(self.loader)
|
||||
|
||||
image_embedding = from_numpy(image_embedding).to(self.device)
|
||||
tokenized_caption = tokenize(
|
||||
caption["caption"].to_list(), truncate=True
|
||||
).to(self.device)
|
||||
|
||||
return image_embedding, tokenized_caption
|
||||
|
||||
else:
|
||||
(image_embedding, _), (text_embedding, _) = next(self.loader)
|
||||
|
||||
image_embedding = from_numpy(image_embedding).to(self.device)
|
||||
text_embedding = from_numpy(text_embedding).to(self.device)
|
||||
|
||||
return image_embedding, text_embedding
|
||||
|
||||
|
||||
def make_splits(
|
||||
text_conditioned: bool,
|
||||
batch_size: int,
|
||||
num_data_points: int,
|
||||
train_split: float,
|
||||
eval_split: float,
|
||||
device: str,
|
||||
img_url: str,
|
||||
meta_url: str = None,
|
||||
txt_url: str = None,
|
||||
):
|
||||
|
||||
assert img_url is not None, "Must supply some image embeddings"
|
||||
|
||||
if text_conditioned:
|
||||
assert meta_url is not None, "Must supply metadata url if text-conditioning"
|
||||
image_reader = EmbeddingReader(
|
||||
embeddings_folder=img_url,
|
||||
file_format="parquet_npy",
|
||||
meta_columns=["caption"],
|
||||
metadata_folder=meta_url,
|
||||
)
|
||||
|
||||
# compute split points
|
||||
if num_data_points > image_reader.count:
|
||||
print("Specified point count is larger than the number of points available...defaulting to max length of reader.")
|
||||
num_data_points = image_reader.count
|
||||
|
||||
train_set_size = int(train_split * num_data_points)
|
||||
eval_set_size = int(eval_split * num_data_points)
|
||||
eval_stop = int(train_set_size + eval_set_size)
|
||||
|
||||
train_loader = PriorEmbeddingLoader(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
batch_size=batch_size,
|
||||
start=0,
|
||||
stop=train_set_size,
|
||||
device=device,
|
||||
)
|
||||
eval_loader = PriorEmbeddingLoader(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
batch_size=batch_size,
|
||||
start=train_set_size,
|
||||
stop=eval_stop,
|
||||
device=device,
|
||||
)
|
||||
test_loader = PriorEmbeddingLoader(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
batch_size=batch_size,
|
||||
start=eval_stop,
|
||||
stop=int(num_data_points),
|
||||
device=device,
|
||||
)
|
||||
|
||||
else:
|
||||
assert (
|
||||
txt_url is not None
|
||||
), "Must supply text embedding url if not text-conditioning"
|
||||
|
||||
image_reader = EmbeddingReader(img_url, file_format="npy")
|
||||
text_reader = EmbeddingReader(txt_url, file_format="npy")
|
||||
|
||||
# compute split points
|
||||
if num_data_points > image_reader.count:
|
||||
print("Specified point count is larger than the number of points available...defaulting to max length of reader.")
|
||||
num_data_points = image_reader.count
|
||||
|
||||
train_set_size = int(train_split * num_data_points)
|
||||
eval_set_size = int(eval_split * num_data_points)
|
||||
eval_stop = int(train_set_size + eval_set_size)
|
||||
|
||||
train_loader = PriorEmbeddingLoader(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
text_reader=text_reader,
|
||||
batch_size=batch_size,
|
||||
start=0,
|
||||
stop=train_set_size,
|
||||
device=device,
|
||||
)
|
||||
eval_loader = PriorEmbeddingLoader(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
text_reader=text_reader,
|
||||
batch_size=batch_size,
|
||||
start=train_set_size,
|
||||
stop=eval_stop,
|
||||
device=device,
|
||||
)
|
||||
test_loader = PriorEmbeddingLoader(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
text_reader=text_reader,
|
||||
batch_size=batch_size,
|
||||
start=eval_stop,
|
||||
stop=int(num_data_points),
|
||||
device=device,
|
||||
)
|
||||
|
||||
return train_loader, eval_loader, test_loader
|
||||
59
dalle2_pytorch/dataloaders/simple_image_only_dataloader.py
Normal file
59
dalle2_pytorch/dataloaders/simple_image_only_dataloader.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch.utils import data
|
||||
from torchvision import transforms, utils
|
||||
|
||||
from PIL import Image
|
||||
|
||||
# helpers functions
|
||||
|
||||
def cycle(dl):
|
||||
while True:
|
||||
for data in dl:
|
||||
yield data
|
||||
|
||||
# dataset and dataloader
|
||||
|
||||
class Dataset(data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
folder,
|
||||
image_size,
|
||||
exts = ['jpg', 'jpeg', 'png']
|
||||
):
|
||||
super().__init__()
|
||||
self.folder = folder
|
||||
self.image_size = image_size
|
||||
self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]
|
||||
|
||||
self.transform = transforms.Compose([
|
||||
transforms.Resize(image_size),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.CenterCrop(image_size),
|
||||
transforms.ToTensor()
|
||||
])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.paths)
|
||||
|
||||
def __getitem__(self, index):
|
||||
path = self.paths[index]
|
||||
img = Image.open(path)
|
||||
return self.transform(img)
|
||||
|
||||
def get_images_dataloader(
|
||||
folder,
|
||||
*,
|
||||
batch_size,
|
||||
image_size,
|
||||
shuffle = True,
|
||||
cycle_dl = True,
|
||||
pin_memory = True
|
||||
):
|
||||
ds = Dataset(folder, image_size)
|
||||
dl = data.DataLoader(ds, batch_size = batch_size, shuffle = shuffle, pin_memory = pin_memory)
|
||||
|
||||
if cycle_dl:
|
||||
dl = cycle(dl)
|
||||
return dl
|
||||
@@ -7,16 +7,17 @@ def separate_weight_decayable_params(params):
|
||||
|
||||
def get_optimizer(
|
||||
params,
|
||||
lr = 3e-4,
|
||||
lr = 1e-4,
|
||||
wd = 1e-2,
|
||||
betas = (0.9, 0.999),
|
||||
eps = 1e-8,
|
||||
filter_by_requires_grad = False
|
||||
):
|
||||
if filter_by_requires_grad:
|
||||
params = list(filter(lambda t: t.requires_grad, params))
|
||||
|
||||
if wd == 0:
|
||||
return Adam(params, lr = lr, betas = betas)
|
||||
return Adam(params, lr = lr, betas = betas, eps = eps)
|
||||
|
||||
params = set(params)
|
||||
wd_params, no_wd_params = separate_weight_decayable_params(params)
|
||||
@@ -26,4 +27,4 @@ def get_optimizer(
|
||||
{'params': list(no_wd_params), 'weight_decay': 0},
|
||||
]
|
||||
|
||||
return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas)
|
||||
return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas, eps = eps)
|
||||
|
||||
115
dalle2_pytorch/trackers.py
Normal file
115
dalle2_pytorch/trackers.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
import importlib
|
||||
from itertools import zip_longest
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
# constants
|
||||
|
||||
DEFAULT_DATA_PATH = './.tracker-data'
|
||||
|
||||
# helper functions
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def import_or_print_error(pkg_name, err_str = None):
|
||||
try:
|
||||
return importlib.import_module(pkg_name)
|
||||
except ModuleNotFoundError as e:
|
||||
if exists(err_str):
|
||||
print(err_str)
|
||||
exit()
|
||||
|
||||
# load state dict functions
|
||||
|
||||
def load_wandb_state_dict(run_path, file_path, **kwargs):
|
||||
wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function')
|
||||
file_reference = wandb.restore(file_path, run_path=run_path)
|
||||
return torch.load(file_reference.name)
|
||||
|
||||
def load_local_state_dict(file_path, **kwargs):
|
||||
return torch.load(file_path)
|
||||
|
||||
# base class
|
||||
|
||||
class BaseTracker(nn.Module):
|
||||
def __init__(self, data_path = DEFAULT_DATA_PATH):
|
||||
super().__init__()
|
||||
self.data_path = Path(data_path)
|
||||
self.data_path.mkdir(parents = True, exist_ok = True)
|
||||
|
||||
def init(self, config, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def log(self, log, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def log_images(self, images, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def save_state_dict(self, state_dict, relative_path, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def recall_state_dict(self, recall_source, *args, **kwargs):
|
||||
"""
|
||||
Loads a state dict from any source.
|
||||
Since a user may wish to load a model from a different source than their own tracker (i.e. tracking using wandb but recalling from disk),
|
||||
this should not be linked to any individual tracker.
|
||||
"""
|
||||
# TODO: Pull this into a dict or something similar so that we can add more sources without having a massive switch statement
|
||||
if recall_source == 'wandb':
|
||||
return load_wandb_state_dict(*args, **kwargs)
|
||||
elif recall_source == 'local':
|
||||
return load_local_state_dict(*args, **kwargs)
|
||||
else:
|
||||
raise ValueError('`recall_source` must be one of `wandb` or `local`')
|
||||
|
||||
|
||||
# basic stdout class
|
||||
|
||||
class ConsoleTracker(BaseTracker):
|
||||
def init(self, **config):
|
||||
print(config)
|
||||
|
||||
def log(self, log, **kwargs):
|
||||
print(log)
|
||||
|
||||
def log_images(self, images, **kwargs): # noop for logging images
|
||||
pass
|
||||
|
||||
def save_state_dict(self, state_dict, relative_path, **kwargs):
|
||||
torch.save(state_dict, str(self.data_path / relative_path))
|
||||
|
||||
# basic wandb class
|
||||
|
||||
class WandbTracker(BaseTracker):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb experiment tracker')
|
||||
os.environ["WANDB_SILENT"] = "true"
|
||||
|
||||
def init(self, **config):
|
||||
self.wandb.init(**config)
|
||||
|
||||
def log(self, log, verbose=False, **kwargs):
|
||||
if verbose:
|
||||
print(log)
|
||||
self.wandb.log(log, **kwargs)
|
||||
|
||||
def log_images(self, images, captions=[], image_section="images", **kwargs):
|
||||
"""
|
||||
Takes a tensor of images and a list of captions and logs them to wandb.
|
||||
"""
|
||||
wandb_images = [self.wandb.Image(image, caption=caption) for image, caption in zip_longest(images, captions)]
|
||||
self.log({ image_section: wandb_images }, **kwargs)
|
||||
|
||||
def save_state_dict(self, state_dict, relative_path, **kwargs):
|
||||
"""
|
||||
Saves a state_dict to disk and uploads it
|
||||
"""
|
||||
full_path = str(self.data_path / relative_path)
|
||||
torch.save(state_dict, full_path)
|
||||
self.wandb.save(full_path, base_path = str(self.data_path)) # Upload and keep relative to data_path
|
||||
@@ -1,6 +1,8 @@
|
||||
import time
|
||||
import copy
|
||||
from functools import partial
|
||||
from math import ceil
|
||||
from functools import partial, wraps
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -9,11 +11,16 @@ from torch.cuda.amp import autocast, GradScaler
|
||||
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
|
||||
from dalle2_pytorch.optimizer import get_optimizer
|
||||
|
||||
import numpy as np
|
||||
|
||||
# helper functions
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
def cast_tuple(val, length = 1):
|
||||
return val if isinstance(val, tuple) else ((val,) * length)
|
||||
|
||||
@@ -40,6 +47,87 @@ def groupby_prefix_and_trim(prefix, d):
|
||||
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
|
||||
return kwargs_without_prefix, kwargs
|
||||
|
||||
def num_to_groups(num, divisor):
|
||||
groups = num // divisor
|
||||
remainder = num % divisor
|
||||
arr = [divisor] * groups
|
||||
if remainder > 0:
|
||||
arr.append(remainder)
|
||||
return arr
|
||||
|
||||
# decorators
|
||||
|
||||
def cast_torch_tensor(fn):
|
||||
@wraps(fn)
|
||||
def inner(model, *args, **kwargs):
|
||||
device = kwargs.pop('_device', next(model.parameters()).device)
|
||||
cast_device = kwargs.pop('_cast_device', True)
|
||||
|
||||
kwargs_keys = kwargs.keys()
|
||||
all_args = (*args, *kwargs.values())
|
||||
split_kwargs_index = len(all_args) - len(kwargs_keys)
|
||||
all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args))
|
||||
|
||||
if cast_device:
|
||||
all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
|
||||
|
||||
args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
|
||||
kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))
|
||||
|
||||
out = fn(model, *args, **kwargs)
|
||||
return out
|
||||
return inner
|
||||
|
||||
# gradient accumulation functions
|
||||
|
||||
def split_iterable(it, split_size):
|
||||
accum = []
|
||||
for ind in range(ceil(len(it) / split_size)):
|
||||
start_index = ind * split_size
|
||||
accum.append(it[start_index: (start_index + split_size)])
|
||||
return accum
|
||||
|
||||
def split(t, split_size = None):
|
||||
if not exists(split_size):
|
||||
return t
|
||||
|
||||
if isinstance(t, torch.Tensor):
|
||||
return t.split(split_size, dim = 0)
|
||||
|
||||
if isinstance(t, Iterable):
|
||||
return split_iterable(t, split_size)
|
||||
|
||||
return TypeError
|
||||
|
||||
def find_first(cond, arr):
|
||||
for el in arr:
|
||||
if cond(el):
|
||||
return el
|
||||
return None
|
||||
|
||||
def split_args_and_kwargs(*args, split_size = None, **kwargs):
|
||||
all_args = (*args, *kwargs.values())
|
||||
len_all_args = len(all_args)
|
||||
first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args)
|
||||
assert exists(first_tensor)
|
||||
|
||||
batch_size = len(first_tensor)
|
||||
split_size = default(split_size, batch_size)
|
||||
num_chunks = ceil(batch_size / split_size)
|
||||
|
||||
dict_len = len(kwargs)
|
||||
dict_keys = kwargs.keys()
|
||||
split_kwargs_index = len_all_args - dict_len
|
||||
|
||||
split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args]
|
||||
chunk_sizes = tuple(map(len, split_all_args[0]))
|
||||
|
||||
for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
|
||||
chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:]
|
||||
chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
|
||||
chunk_size_frac = chunk_size / batch_size
|
||||
yield chunk_size_frac, (chunked_args, chunked_kwargs)
|
||||
|
||||
# print helpers
|
||||
|
||||
def print_ribbon(s, symbol = '=', repeat = 40):
|
||||
@@ -71,7 +159,7 @@ def load_diffusion_model(dprior_path, device):
|
||||
# Load state dict from saved model
|
||||
diffusion_prior.load_state_dict(loaded_obj['model'])
|
||||
|
||||
return diffusion_prior
|
||||
return diffusion_prior, loaded_obj
|
||||
|
||||
def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim):
|
||||
# Saving State Dict
|
||||
@@ -90,7 +178,7 @@ class EMA(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
beta = 0.99,
|
||||
beta = 0.9999,
|
||||
update_after_step = 1000,
|
||||
update_every = 10,
|
||||
):
|
||||
@@ -99,20 +187,31 @@ class EMA(nn.Module):
|
||||
self.online_model = model
|
||||
self.ema_model = copy.deepcopy(model)
|
||||
|
||||
self.update_after_step = update_after_step # only start EMA after this step number, starting at 0
|
||||
self.update_every = update_every
|
||||
self.update_after_step = update_after_step // update_every # only start EMA after this step number, starting at 0
|
||||
|
||||
self.register_buffer('initted', torch.Tensor([False]))
|
||||
self.register_buffer('step', torch.tensor([0.]))
|
||||
|
||||
def restore_ema_model_device(self):
|
||||
device = self.initted.device
|
||||
self.ema_model.to(device)
|
||||
|
||||
def copy_params_from_model_to_ema(self):
|
||||
self.ema_model.state_dict(self.online_model.state_dict())
|
||||
|
||||
def update(self):
|
||||
self.step += 1
|
||||
|
||||
if self.step <= self.update_after_step or (self.step % self.update_every) != 0:
|
||||
if (self.step % self.update_every) != 0:
|
||||
return
|
||||
|
||||
if self.step <= self.update_after_step:
|
||||
self.copy_params_from_model_to_ema()
|
||||
return
|
||||
|
||||
if not self.initted:
|
||||
self.ema_model.state_dict(self.online_model.state_dict())
|
||||
self.copy_params_from_model_to_ema()
|
||||
self.initted.data.copy_(torch.Tensor([True]))
|
||||
|
||||
self.update_moving_average(self.ema_model, self.online_model)
|
||||
@@ -136,6 +235,16 @@ class EMA(nn.Module):
|
||||
|
||||
# diffusion prior trainer
|
||||
|
||||
def prior_sample_in_chunks(fn):
|
||||
@wraps(fn)
|
||||
def inner(self, *args, max_batch_size = None, **kwargs):
|
||||
if not exists(max_batch_size):
|
||||
return fn(self, *args, **kwargs)
|
||||
|
||||
outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]
|
||||
return torch.cat(outputs, dim = 0)
|
||||
return inner
|
||||
|
||||
class DiffusionPriorTrainer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -143,6 +252,7 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
use_ema = True,
|
||||
lr = 3e-4,
|
||||
wd = 1e-2,
|
||||
eps = 1e-6,
|
||||
max_grad_norm = None,
|
||||
amp = False,
|
||||
**kwargs
|
||||
@@ -169,6 +279,7 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
diffusion_prior.parameters(),
|
||||
lr = lr,
|
||||
wd = wd,
|
||||
eps = eps,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@@ -176,6 +287,8 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
|
||||
self.max_grad_norm = max_grad_norm
|
||||
|
||||
self.register_buffer('step', torch.tensor([0.]))
|
||||
|
||||
def update(self):
|
||||
if exists(self.max_grad_norm):
|
||||
self.scaler.unscale_(self.optimizer)
|
||||
@@ -188,38 +301,72 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
if self.use_ema:
|
||||
self.ema_diffusion_prior.update()
|
||||
|
||||
@torch.inference_mode()
|
||||
self.step += 1
|
||||
|
||||
@torch.no_grad()
|
||||
@cast_torch_tensor
|
||||
@prior_sample_in_chunks
|
||||
def p_sample_loop(self, *args, **kwargs):
|
||||
return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
|
||||
|
||||
@torch.inference_mode()
|
||||
@torch.no_grad()
|
||||
@cast_torch_tensor
|
||||
@prior_sample_in_chunks
|
||||
def sample(self, *args, **kwargs):
|
||||
return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
|
||||
|
||||
@torch.inference_mode()
|
||||
@torch.no_grad()
|
||||
def sample_batch_size(self, *args, **kwargs):
|
||||
return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
|
||||
|
||||
@cast_torch_tensor
|
||||
def forward(
|
||||
self,
|
||||
*args,
|
||||
divisor = 1,
|
||||
max_batch_size = None,
|
||||
**kwargs
|
||||
):
|
||||
with autocast(enabled = self.amp):
|
||||
loss = self.diffusion_prior(*args, **kwargs)
|
||||
return self.scaler.scale(loss / divisor)
|
||||
total_loss = 0.
|
||||
|
||||
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
|
||||
with autocast(enabled = self.amp):
|
||||
loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
|
||||
loss = loss * chunk_size_frac
|
||||
|
||||
total_loss += loss.item()
|
||||
|
||||
if self.training:
|
||||
self.scaler.scale(loss).backward()
|
||||
|
||||
return total_loss
|
||||
|
||||
# decoder trainer
|
||||
|
||||
def decoder_sample_in_chunks(fn):
|
||||
@wraps(fn)
|
||||
def inner(self, *args, max_batch_size = None, **kwargs):
|
||||
if not exists(max_batch_size):
|
||||
return fn(self, *args, **kwargs)
|
||||
|
||||
if self.decoder.unconditional:
|
||||
batch_size = kwargs.get('batch_size')
|
||||
batch_sizes = num_to_groups(batch_size, max_batch_size)
|
||||
outputs = [fn(self, *args, **{**kwargs, 'batch_size': sub_batch_size}) for sub_batch_size in batch_sizes]
|
||||
else:
|
||||
outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]
|
||||
|
||||
return torch.cat(outputs, dim = 0)
|
||||
return inner
|
||||
|
||||
class DecoderTrainer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
decoder,
|
||||
use_ema = True,
|
||||
lr = 3e-4,
|
||||
lr = 1e-4,
|
||||
wd = 1e-2,
|
||||
max_grad_norm = None,
|
||||
eps = 1e-8,
|
||||
max_grad_norm = 0.5,
|
||||
amp = False,
|
||||
**kwargs
|
||||
):
|
||||
@@ -231,11 +378,6 @@ class DecoderTrainer(nn.Module):
|
||||
self.num_unets = len(self.decoder.unets)
|
||||
|
||||
self.use_ema = use_ema
|
||||
|
||||
if use_ema:
|
||||
has_lazy_linear = any([type(module) == nn.LazyLinear for module in decoder.modules()])
|
||||
assert not has_lazy_linear, 'you must set the text_embed_dim on your u-nets if you plan on doing automatic exponential moving average'
|
||||
|
||||
self.ema_unets = nn.ModuleList([])
|
||||
|
||||
self.amp = amp
|
||||
@@ -243,13 +385,14 @@ class DecoderTrainer(nn.Module):
|
||||
# be able to finely customize learning rate, weight decay
|
||||
# per unet
|
||||
|
||||
lr, wd = map(partial(cast_tuple, length = self.num_unets), (lr, wd))
|
||||
lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))
|
||||
|
||||
for ind, (unet, unet_lr, unet_wd) in enumerate(zip(self.decoder.unets, lr, wd)):
|
||||
for ind, (unet, unet_lr, unet_wd, unet_eps) in enumerate(zip(self.decoder.unets, lr, wd, eps)):
|
||||
optimizer = get_optimizer(
|
||||
unet.parameters(),
|
||||
lr = unet_lr,
|
||||
wd = unet_wd,
|
||||
eps = unet_eps,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@@ -265,6 +408,8 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
self.max_grad_norm = max_grad_norm
|
||||
|
||||
self.register_buffer('step', torch.tensor([0.]))
|
||||
|
||||
@property
|
||||
def unets(self):
|
||||
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
|
||||
@@ -275,8 +420,11 @@ class DecoderTrainer(nn.Module):
|
||||
scaler = getattr(self, f'scaler{index}')
|
||||
return scaler.scale(loss)
|
||||
|
||||
def update(self, unet_number):
|
||||
assert 1 <= unet_number <= self.num_unets
|
||||
def update(self, unet_number = None):
|
||||
if self.num_unets == 1:
|
||||
unet_number = default(unet_number, 1)
|
||||
|
||||
assert exists(unet_number) and 1 <= unet_number <= self.num_unets
|
||||
index = unet_number - 1
|
||||
unet = self.decoder.unets[index]
|
||||
|
||||
@@ -295,26 +443,49 @@ class DecoderTrainer(nn.Module):
|
||||
ema_unet = self.ema_unets[index]
|
||||
ema_unet.update()
|
||||
|
||||
self.step += 1
|
||||
|
||||
@torch.no_grad()
|
||||
@cast_torch_tensor
|
||||
@decoder_sample_in_chunks
|
||||
def sample(self, *args, **kwargs):
|
||||
if self.use_ema:
|
||||
trainable_unets = self.decoder.unets
|
||||
self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
|
||||
if kwargs.pop('use_non_ema', False) or not self.use_ema:
|
||||
return self.decoder.sample(*args, **kwargs)
|
||||
|
||||
trainable_unets = self.decoder.unets
|
||||
self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
|
||||
|
||||
output = self.decoder.sample(*args, **kwargs)
|
||||
|
||||
if self.use_ema:
|
||||
self.decoder.unets = trainable_unets # restore original training unets
|
||||
self.decoder.unets = trainable_unets # restore original training unets
|
||||
|
||||
# cast the ema_model unets back to original device
|
||||
for ema in self.ema_unets:
|
||||
ema.restore_ema_model_device()
|
||||
|
||||
return output
|
||||
|
||||
@cast_torch_tensor
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
*,
|
||||
unet_number,
|
||||
divisor = 1,
|
||||
*args,
|
||||
unet_number = None,
|
||||
max_batch_size = None,
|
||||
**kwargs
|
||||
):
|
||||
with autocast(enabled = self.amp):
|
||||
loss = self.decoder(x, unet_number = unet_number, **kwargs)
|
||||
return self.scale(loss / divisor, unet_number = unet_number)
|
||||
if self.num_unets == 1:
|
||||
unet_number = default(unet_number, 1)
|
||||
|
||||
total_loss = 0.
|
||||
|
||||
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
|
||||
with autocast(enabled = self.amp):
|
||||
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
|
||||
loss = loss * chunk_size_frac
|
||||
|
||||
total_loss += loss.item()
|
||||
|
||||
if self.training:
|
||||
self.scale(loss, unet_number = unet_number).backward()
|
||||
|
||||
return total_loss
|
||||
BIN
samples/oxford.png
Normal file
BIN
samples/oxford.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 985 KiB |
3
setup.py
3
setup.py
@@ -10,7 +10,7 @@ setup(
|
||||
'dream = dalle2_pytorch.cli:dream'
|
||||
],
|
||||
},
|
||||
version = '0.2.8',
|
||||
version = '0.3.5',
|
||||
license='MIT',
|
||||
description = 'DALL-E 2',
|
||||
author = 'Phil Wang',
|
||||
@@ -30,6 +30,7 @@ setup(
|
||||
'einops-exts>=0.0.3',
|
||||
'embedding-reader',
|
||||
'kornia>=0.5.4',
|
||||
'numpy',
|
||||
'pillow',
|
||||
'resize-right>=0.0.2',
|
||||
'rotary-embedding-torch',
|
||||
|
||||
@@ -1,354 +1,375 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
import click
|
||||
import math
|
||||
import argparse
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import clip
|
||||
from torch import nn
|
||||
from embedding_reader import EmbeddingReader
|
||||
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
|
||||
from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model, print_ribbon
|
||||
from dalle2_pytorch.optimizer import get_optimizer
|
||||
from torch.cuda.amp import autocast,GradScaler
|
||||
|
||||
import time
|
||||
from dalle2_pytorch.dataloaders import make_splits
|
||||
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
|
||||
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon
|
||||
|
||||
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
|
||||
|
||||
from embedding_reader import EmbeddingReader
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import wandb
|
||||
os.environ["WANDB_SILENT"] = "true"
|
||||
NUM_TEST_EMBEDDINGS = 100 # for cosine similarity reporting during training
|
||||
REPORT_METRICS_EVERY = 100 # for cosine similarity and other metric reporting during training
|
||||
# constants
|
||||
|
||||
REPORT_METRICS_EVERY = 250 # for cosine similarity and other metric reporting during training
|
||||
|
||||
def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_type,phase="Validation"):
|
||||
tracker = WandbTracker()
|
||||
|
||||
# helpers functions
|
||||
|
||||
def exists(val):
|
||||
val is not None
|
||||
|
||||
class Timer:
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.last_time = time.time()
|
||||
|
||||
def elapsed(self):
|
||||
return time.time() - self.last_time
|
||||
|
||||
# functions
|
||||
|
||||
def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation"):
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
total_loss = 0.
|
||||
total_samples = 0.
|
||||
|
||||
for emb_images, emb_text in zip(image_reader(batch_size=batch_size, start=start, end=end),
|
||||
text_reader(batch_size=batch_size, start=start, end=end)):
|
||||
for image_embeddings, text_data in tqdm(dataloader):
|
||||
|
||||
emb_images_tensor = torch.tensor(emb_images[0]).to(device)
|
||||
emb_text_tensor = torch.tensor(emb_text[0]).to(device)
|
||||
batches = image_embeddings.shape[0]
|
||||
|
||||
batches = emb_images_tensor.shape[0]
|
||||
input_args = dict(image_embed=image_embeddings)
|
||||
if text_conditioned:
|
||||
input_args = dict(**input_args, text = text_data)
|
||||
else:
|
||||
input_args = dict(**input_args, text_embed=text_data)
|
||||
|
||||
loss = model(text_embed = emb_text_tensor, image_embed = emb_images_tensor)
|
||||
loss = model(**input_args)
|
||||
|
||||
total_loss += loss.item() * batches
|
||||
total_loss += loss * batches
|
||||
total_samples += batches
|
||||
|
||||
avg_loss = (total_loss / total_samples)
|
||||
wandb.log({f'{phase} {loss_type}': avg_loss})
|
||||
|
||||
def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,NUM_TEST_EMBEDDINGS,device):
|
||||
tracker.log({f'{phase} {loss_type}': avg_loss})
|
||||
|
||||
def report_cosine_sims(diffusion_prior, dataloader, text_conditioned):
|
||||
diffusion_prior.eval()
|
||||
|
||||
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
|
||||
|
||||
tstart = train_set_size
|
||||
tend = train_set_size+NUM_TEST_EMBEDDINGS
|
||||
for test_image_embeddings, text_data in tqdm(dataloader):
|
||||
|
||||
# we are text conditioned, we produce an embedding from the tokenized text
|
||||
if text_conditioned:
|
||||
text_embedding, text_encodings, text_mask = diffusion_prior.clip.embed_text(
|
||||
text_data)
|
||||
text_cond = dict(text_embed=text_embedding,
|
||||
text_encodings=text_encodings, mask=text_mask)
|
||||
else:
|
||||
text_embedding = text_data
|
||||
text_cond = dict(text_embed=text_embedding)
|
||||
|
||||
# make a copy of the text embeddings for shuffling
|
||||
text_embed_shuffled = text_embedding.clone()
|
||||
|
||||
# roll the text to simulate "unrelated" captions
|
||||
rolled_idx = torch.roll(torch.arange(text_embedding.shape[0]), 1)
|
||||
text_embed_shuffled = text_embed_shuffled[rolled_idx]
|
||||
text_embed_shuffled = text_embed_shuffled / \
|
||||
text_embed_shuffled.norm(dim=1, keepdim=True)
|
||||
|
||||
if text_conditioned:
|
||||
text_encodings_shuffled = text_encodings[rolled_idx]
|
||||
text_mask_shuffled = text_mask[rolled_idx]
|
||||
else:
|
||||
text_encodings_shuffled = None
|
||||
text_mask_shuffled = None
|
||||
|
||||
text_cond_shuffled = dict(text_embed=text_embed_shuffled,
|
||||
text_encodings=text_encodings_shuffled, mask=text_mask_shuffled)
|
||||
|
||||
for embt, embi in zip(text_reader(batch_size=NUM_TEST_EMBEDDINGS, start=tstart, end=tend),
|
||||
image_reader(batch_size=NUM_TEST_EMBEDDINGS, start=tstart, end=tend)):
|
||||
# make a copy of the text embeddings for shuffling
|
||||
text_embed = torch.tensor(embt[0]).to(device)
|
||||
text_embed_shuffled = text_embed.clone()
|
||||
# roll the text embeddings to simulate "unrelated" captions
|
||||
rolled_idx = torch.roll(torch.arange(NUM_TEST_EMBEDDINGS), 1)
|
||||
text_embed_shuffled = text_embed_shuffled[rolled_idx]
|
||||
text_embed_shuffled = text_embed_shuffled / \
|
||||
text_embed_shuffled.norm(dim=1, keepdim=True)
|
||||
test_text_shuffled_cond = dict(text_embed=text_embed_shuffled)
|
||||
# prepare the text embedding
|
||||
text_embed = text_embed / text_embed.norm(dim=1, keepdim=True)
|
||||
test_text_cond = dict(text_embed=text_embed)
|
||||
text_embed = text_embedding / text_embedding.norm(dim=1, keepdim=True)
|
||||
|
||||
# prepare image embeddings
|
||||
test_image_embeddings = torch.tensor(embi[0]).to(device)
|
||||
test_image_embeddings = test_image_embeddings / \
|
||||
test_image_embeddings.norm(dim=1, keepdim=True)
|
||||
test_image_embeddings = test_image_embeddings / \
|
||||
test_image_embeddings.norm(dim=1, keepdim=True)
|
||||
|
||||
# predict on the unshuffled text embeddings
|
||||
predicted_image_embeddings = diffusion_prior.p_sample_loop(
|
||||
(NUM_TEST_EMBEDDINGS, 768), text_cond=test_text_cond)
|
||||
predicted_image_embeddings = predicted_image_embeddings / \
|
||||
predicted_image_embeddings.norm(dim=1, keepdim=True)
|
||||
predicted_image_embeddings = diffusion_prior.p_sample_loop(
|
||||
test_image_embeddings.shape, text_cond)
|
||||
predicted_image_embeddings = predicted_image_embeddings / \
|
||||
predicted_image_embeddings.norm(dim=1, keepdim=True)
|
||||
|
||||
# predict on the shuffled embeddings
|
||||
predicted_unrelated_embeddings = diffusion_prior.p_sample_loop(
|
||||
(NUM_TEST_EMBEDDINGS, 768), text_cond=test_text_shuffled_cond)
|
||||
predicted_unrelated_embeddings = predicted_unrelated_embeddings / \
|
||||
predicted_unrelated_embeddings.norm(dim=1, keepdim=True)
|
||||
predicted_unrelated_embeddings = diffusion_prior.p_sample_loop(
|
||||
test_image_embeddings.shape, text_cond_shuffled)
|
||||
predicted_unrelated_embeddings = predicted_unrelated_embeddings / \
|
||||
predicted_unrelated_embeddings.norm(dim=1, keepdim=True)
|
||||
|
||||
# calculate similarities
|
||||
original_similarity = cos(
|
||||
original_similarity = cos(
|
||||
text_embed, test_image_embeddings).cpu().numpy()
|
||||
predicted_similarity = cos(
|
||||
predicted_similarity = cos(
|
||||
text_embed, predicted_image_embeddings).cpu().numpy()
|
||||
unrelated_similarity = cos(
|
||||
unrelated_similarity = cos(
|
||||
text_embed, predicted_unrelated_embeddings).cpu().numpy()
|
||||
predicted_img_similarity = cos(
|
||||
predicted_img_similarity = cos(
|
||||
test_image_embeddings, predicted_image_embeddings).cpu().numpy()
|
||||
wandb.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity),
|
||||
tracker.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity),
|
||||
"CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity),
|
||||
"CosineSimilarity(orig_image_embed,predicted_image_embed)":np.mean(predicted_img_similarity),
|
||||
"CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(unrelated_similarity),
|
||||
"Cosine similarity difference":np.mean(predicted_similarity - original_similarity)})
|
||||
|
||||
def train(image_embed_dim,
|
||||
image_embed_url,
|
||||
text_embed_url,
|
||||
batch_size,
|
||||
train_percent,
|
||||
val_percent,
|
||||
test_percent,
|
||||
num_epochs,
|
||||
dp_loss_type,
|
||||
clip,
|
||||
dp_condition_on_text_encodings,
|
||||
dp_timesteps,
|
||||
dp_normformer,
|
||||
dp_cond_drop_prob,
|
||||
dpn_depth,
|
||||
dpn_dim_head,
|
||||
dpn_heads,
|
||||
save_interval,
|
||||
save_path,
|
||||
device,
|
||||
RESUME,
|
||||
DPRIOR_PATH,
|
||||
config,
|
||||
wandb_entity,
|
||||
wandb_project,
|
||||
learning_rate=0.001,
|
||||
max_grad_norm=0.5,
|
||||
weight_decay=0.01,
|
||||
dropout=0.05,
|
||||
amp=False):
|
||||
|
||||
# DiffusionPriorNetwork
|
||||
@click.command()
|
||||
@click.option("--wandb-entity", default="laion")
|
||||
@click.option("--wandb-project", default="diffusion-prior")
|
||||
@click.option("--wandb-dataset", default="LAION-5B")
|
||||
@click.option("--wandb-arch", default="DiffusionPrior")
|
||||
@click.option("--image-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
|
||||
@click.option("--text-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
|
||||
@click.option("--meta-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/laion2B-en-metadata/")
|
||||
@click.option("--learning-rate", default=1.1e-4)
|
||||
@click.option("--weight-decay", default=6.02e-2)
|
||||
@click.option("--dropout", default=5e-2)
|
||||
@click.option("--max-grad-norm", default=0.5)
|
||||
@click.option("--num-data-points", default=250e6)
|
||||
@click.option("--batch-size", default=320)
|
||||
@click.option("--num-epochs", default=5)
|
||||
@click.option("--image-embed-dim", default=768)
|
||||
@click.option("--train-percent", default=0.9)
|
||||
@click.option("--val-percent", default=1e-7)
|
||||
@click.option("--test-percent", default=0.0999999)
|
||||
@click.option("--dpn-depth", default=12)
|
||||
@click.option("--dpn-dim-head", default=64)
|
||||
@click.option("--dpn-heads", default=12)
|
||||
@click.option("--dp-condition-on-text-encodings", default=True)
|
||||
@click.option("--dp-timesteps", default=1000)
|
||||
@click.option("--dp-normformer", default=True)
|
||||
@click.option("--dp-cond-drop-prob", default=0.1)
|
||||
@click.option("--dp-loss-type", default="l2")
|
||||
@click.option("--clip", default="ViT-L/14")
|
||||
@click.option("--amp", default=False)
|
||||
@click.option("--save-interval", default=120)
|
||||
@click.option("--save-path", default="./diffusion_prior_checkpoints")
|
||||
@click.option("--pretrained-model-path", default=None)
|
||||
@click.option("--gpu-device", default=0)
|
||||
def train(
|
||||
wandb_entity,
|
||||
wandb_project,
|
||||
wandb_dataset,
|
||||
wandb_arch,
|
||||
image_embed_url,
|
||||
text_embed_url,
|
||||
meta_url,
|
||||
learning_rate,
|
||||
weight_decay,
|
||||
dropout,
|
||||
max_grad_norm,
|
||||
num_data_points,
|
||||
batch_size,
|
||||
num_epochs,
|
||||
image_embed_dim,
|
||||
train_percent,
|
||||
val_percent,
|
||||
test_percent,
|
||||
dpn_depth,
|
||||
dpn_dim_head,
|
||||
dpn_heads,
|
||||
dp_condition_on_text_encodings,
|
||||
dp_timesteps,
|
||||
dp_normformer,
|
||||
dp_cond_drop_prob,
|
||||
dp_loss_type,
|
||||
clip,
|
||||
amp,
|
||||
save_interval,
|
||||
save_path,
|
||||
pretrained_model_path,
|
||||
gpu_device
|
||||
):
|
||||
config = {
|
||||
"learning_rate": learning_rate,
|
||||
"architecture": wandb_arch,
|
||||
"dataset": wandb_dataset,
|
||||
"weight_decay": weight_decay,
|
||||
"max_gradient_clipping_norm": max_grad_norm,
|
||||
"batch_size": batch_size,
|
||||
"epochs": num_epochs,
|
||||
"diffusion_prior_network": {
|
||||
"depth": dpn_depth,
|
||||
"dim_head": dpn_dim_head,
|
||||
"heads": dpn_heads,
|
||||
"normformer": dp_normformer
|
||||
},
|
||||
"diffusion_prior": {
|
||||
"condition_on_text_encodings": dp_condition_on_text_encodings,
|
||||
"timesteps": dp_timesteps,
|
||||
"cond_drop_prob": dp_cond_drop_prob,
|
||||
"loss_type": dp_loss_type,
|
||||
"clip": clip
|
||||
}
|
||||
}
|
||||
|
||||
# Check if DPRIOR_PATH exists(saved model path)
|
||||
|
||||
DPRIOR_PATH = pretrained_model_path
|
||||
RESUME = exists(DPRIOR_PATH)
|
||||
|
||||
if not RESUME:
|
||||
tracker.init(
|
||||
entity = wandb_entity,
|
||||
project = wandb_project,
|
||||
config = config
|
||||
)
|
||||
|
||||
# Obtain the utilized device.
|
||||
|
||||
has_cuda = torch.cuda.is_available()
|
||||
if has_cuda:
|
||||
device = torch.device(f"cuda:{gpu_device}")
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
# Training loop
|
||||
# diffusion prior network
|
||||
|
||||
prior_network = DiffusionPriorNetwork(
|
||||
dim = image_embed_dim,
|
||||
depth = dpn_depth,
|
||||
dim_head = dpn_dim_head,
|
||||
heads = dpn_heads,
|
||||
attn_dropout = dropout,
|
||||
ff_dropout = dropout,
|
||||
normformer = dp_normformer).to(device)
|
||||
dim = image_embed_dim,
|
||||
depth = dpn_depth,
|
||||
dim_head = dpn_dim_head,
|
||||
heads = dpn_heads,
|
||||
attn_dropout = dropout,
|
||||
ff_dropout = dropout,
|
||||
normformer = dp_normformer
|
||||
)
|
||||
|
||||
# DiffusionPrior with text embeddings and image embeddings pre-computed
|
||||
# Load clip model if text-conditioning
|
||||
if dp_condition_on_text_encodings:
|
||||
clip_adapter = OpenAIClipAdapter(clip)
|
||||
else:
|
||||
clip_adapter = None
|
||||
|
||||
# diffusion prior with text embeddings and image embeddings pre-computed
|
||||
|
||||
diffusion_prior = DiffusionPrior(
|
||||
net = prior_network,
|
||||
clip = clip,
|
||||
image_embed_dim = image_embed_dim,
|
||||
timesteps = dp_timesteps,
|
||||
cond_drop_prob = dp_cond_drop_prob,
|
||||
loss_type = dp_loss_type,
|
||||
condition_on_text_encodings = dp_condition_on_text_encodings).to(device)
|
||||
net = prior_network,
|
||||
clip = clip_adapter,
|
||||
image_embed_dim = image_embed_dim,
|
||||
timesteps = dp_timesteps,
|
||||
cond_drop_prob = dp_cond_drop_prob,
|
||||
loss_type = dp_loss_type,
|
||||
condition_on_text_encodings = dp_condition_on_text_encodings
|
||||
)
|
||||
|
||||
# Load pre-trained model from DPRIOR_PATH
|
||||
|
||||
if RESUME:
|
||||
diffusion_prior=load_diffusion_model(DPRIOR_PATH,device)
|
||||
wandb.init( entity=wandb_entity, project=wandb_project, config=config)
|
||||
diffusion_prior, loaded_obj = load_diffusion_model(DPRIOR_PATH, device)
|
||||
tracker.init(entity = wandb_entity, project = wandb_project, config = config)
|
||||
|
||||
# diffusion prior trainer
|
||||
|
||||
trainer = DiffusionPriorTrainer(
|
||||
diffusion_prior = diffusion_prior,
|
||||
lr = learning_rate,
|
||||
wd = weight_decay,
|
||||
max_grad_norm = max_grad_norm,
|
||||
amp = amp,
|
||||
).to(device)
|
||||
|
||||
# load optimizer and scaler
|
||||
|
||||
if RESUME:
|
||||
trainer.optimizer.load_state_dict(loaded_obj['optimizer'])
|
||||
trainer.scaler.load_state_dict(loaded_obj['scaler'])
|
||||
|
||||
# Create save_path if it doesn't exist
|
||||
if not os.path.exists(save_path):
|
||||
os.makedirs(save_path)
|
||||
|
||||
# Get image and text embeddings from the servers
|
||||
print_ribbon("Downloading embeddings - image and text")
|
||||
image_reader = EmbeddingReader(embeddings_folder=image_embed_url, file_format="npy")
|
||||
text_reader = EmbeddingReader(embeddings_folder=text_embed_url, file_format="npy")
|
||||
num_data_points = text_reader.count
|
||||
Path(save_path).mkdir(exist_ok = True, parents = True)
|
||||
|
||||
# Utilize wrapper to abstract away loader logic
|
||||
print_ribbon("Downloading Embeddings")
|
||||
loader_args = dict(text_conditioned=dp_condition_on_text_encodings, batch_size=batch_size, num_data_points=num_data_points,
|
||||
train_split=train_percent, eval_split=val_percent, device=device, img_url=image_embed_url)
|
||||
|
||||
if dp_condition_on_text_encodings:
|
||||
loader_args = dict(**loader_args, meta_url=meta_url)
|
||||
else:
|
||||
loader_args = dict(**loader_args, txt_url=text_embed_url)
|
||||
|
||||
train_loader, eval_loader, test_loader = make_splits(**loader_args)
|
||||
|
||||
### Training code ###
|
||||
scaler = GradScaler(enabled=amp)
|
||||
optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
|
||||
|
||||
step = 1
|
||||
timer = Timer()
|
||||
epochs = num_epochs
|
||||
|
||||
step = 0
|
||||
t = time.time()
|
||||
|
||||
train_set_size = int(train_percent*num_data_points)
|
||||
val_set_size = int(val_percent*num_data_points)
|
||||
eval_start = train_set_size
|
||||
|
||||
for _ in range(epochs):
|
||||
|
||||
for emb_images,emb_text in zip(image_reader(batch_size=batch_size, start=0, end=train_set_size),
|
||||
text_reader(batch_size=batch_size, start=0, end=train_set_size)):
|
||||
|
||||
for image, text in tqdm(train_loader):
|
||||
|
||||
diffusion_prior.train()
|
||||
|
||||
emb_images_tensor = torch.tensor(emb_images[0]).to(device)
|
||||
emb_text_tensor = torch.tensor(emb_text[0]).to(device)
|
||||
input_args = dict(image_embed=image)
|
||||
if dp_condition_on_text_encodings:
|
||||
input_args = dict(**input_args, text = text)
|
||||
else:
|
||||
input_args = dict(**input_args, text_embed=text)
|
||||
|
||||
with autocast(enabled=amp):
|
||||
loss = diffusion_prior(text_embed = emb_text_tensor,image_embed = emb_images_tensor)
|
||||
scaler.scale(loss).backward()
|
||||
loss = trainer(**input_args)
|
||||
|
||||
# Samples per second
|
||||
step+=1
|
||||
samples_per_sec = batch_size*step/(time.time()-t)
|
||||
|
||||
samples_per_sec = batch_size * step / timer.elapsed()
|
||||
|
||||
# Save checkpoint every save_interval minutes
|
||||
if(int(time.time()-t) >= 60*save_interval):
|
||||
t = time.time()
|
||||
if(int(timer.elapsed()) >= 60 * save_interval):
|
||||
timer.reset()
|
||||
|
||||
save_diffusion_model(
|
||||
save_path,
|
||||
diffusion_prior,
|
||||
optimizer,
|
||||
scaler,
|
||||
trainer.optimizer,
|
||||
trainer.scaler,
|
||||
config,
|
||||
image_embed_dim)
|
||||
|
||||
# Log to wandb
|
||||
wandb.log({"Training loss": loss.item(),
|
||||
tracker.log({"Training loss": loss,
|
||||
"Steps": step,
|
||||
"Samples per second": samples_per_sec})
|
||||
# Log cosineSim(text_embed,predicted_image_embed) - cosineSim(text_embed,image_embed)
|
||||
# Use NUM_TEST_EMBEDDINGS samples from the test set each time
|
||||
# Get embeddings from the most recently saved model
|
||||
if(step % REPORT_METRICS_EVERY) == 0:
|
||||
report_cosine_sims(diffusion_prior,
|
||||
image_reader,
|
||||
text_reader,
|
||||
train_set_size,
|
||||
NUM_TEST_EMBEDDINGS,
|
||||
device)
|
||||
report_cosine_sims(diffusion_prior, eval_loader, dp_condition_on_text_encodings)
|
||||
### Evaluate model(validation run) ###
|
||||
eval_model(diffusion_prior,
|
||||
device,
|
||||
image_reader,
|
||||
text_reader,
|
||||
eval_start,
|
||||
eval_start+NUM_TEST_EMBEDDINGS,
|
||||
NUM_TEST_EMBEDDINGS,
|
||||
dp_loss_type,
|
||||
phase="Validation")
|
||||
eval_model(diffusion_prior, eval_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Validation")
|
||||
|
||||
scaler.unscale_(optimizer)
|
||||
nn.utils.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm)
|
||||
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
optimizer.zero_grad()
|
||||
step += 1
|
||||
trainer.update()
|
||||
|
||||
### Test run ###
|
||||
test_set_size = int(test_percent*train_set_size)
|
||||
start=train_set_size+val_set_size
|
||||
end=num_data_points
|
||||
eval_model(diffusion_prior,device,image_reader,text_reader,start,end,batch_size,dp_loss_type,phase="Test")
|
||||
eval_model(diffusion_prior, test_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Test")
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
# Logging
|
||||
parser.add_argument("--wandb-entity", type=str, default="laion")
|
||||
parser.add_argument("--wandb-project", type=str, default="diffusion-prior")
|
||||
parser.add_argument("--wandb-dataset", type=str, default="LAION-5B")
|
||||
parser.add_argument("--wandb-arch", type=str, default="DiffusionPrior")
|
||||
# URLs for embeddings
|
||||
parser.add_argument("--image-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
|
||||
parser.add_argument("--text-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
|
||||
# Hyperparameters
|
||||
parser.add_argument("--learning-rate", type=float, default=1.1e-4)
|
||||
parser.add_argument("--weight-decay", type=float, default=6.02e-2)
|
||||
parser.add_argument("--dropout", type=float, default=5e-2)
|
||||
parser.add_argument("--max-grad-norm", type=float, default=0.5)
|
||||
parser.add_argument("--batch-size", type=int, default=10**4)
|
||||
parser.add_argument("--num-epochs", type=int, default=5)
|
||||
# Image embed dimension
|
||||
parser.add_argument("--image-embed-dim", type=int, default=768)
|
||||
# Train-test split
|
||||
parser.add_argument("--train-percent", type=float, default=0.7)
|
||||
parser.add_argument("--val-percent", type=float, default=0.2)
|
||||
parser.add_argument("--test-percent", type=float, default=0.1)
|
||||
# LAION training(pre-computed embeddings)
|
||||
# DiffusionPriorNetwork(dpn) parameters
|
||||
parser.add_argument("--dpn-depth", type=int, default=6)
|
||||
parser.add_argument("--dpn-dim-head", type=int, default=64)
|
||||
parser.add_argument("--dpn-heads", type=int, default=8)
|
||||
# DiffusionPrior(dp) parameters
|
||||
parser.add_argument("--dp-condition-on-text-encodings", type=bool, default=False)
|
||||
parser.add_argument("--dp-timesteps", type=int, default=100)
|
||||
parser.add_argument("--dp-normformer", type=bool, default=False)
|
||||
parser.add_argument("--dp-cond-drop-prob", type=float, default=0.1)
|
||||
parser.add_argument("--dp-loss-type", type=str, default="l2")
|
||||
parser.add_argument("--clip", type=str, default=None)
|
||||
parser.add_argument("--amp", type=bool, default=False)
|
||||
# Model checkpointing interval(minutes)
|
||||
parser.add_argument("--save-interval", type=int, default=30)
|
||||
parser.add_argument("--save-path", type=str, default="./diffusion_prior_checkpoints")
|
||||
# Saved model path
|
||||
parser.add_argument("--pretrained-model-path", type=str, default=None)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
config = ({"learning_rate": args.learning_rate,
|
||||
"architecture": args.wandb_arch,
|
||||
"dataset": args.wandb_dataset,
|
||||
"weight_decay":args.weight_decay,
|
||||
"max_gradient_clipping_norm":args.max_grad_norm,
|
||||
"batch_size":args.batch_size,
|
||||
"epochs": args.num_epochs,
|
||||
"diffusion_prior_network":{"depth":args.dpn_depth,
|
||||
"dim_head":args.dpn_dim_head,
|
||||
"heads":args.dpn_heads,
|
||||
"normformer":args.dp_normformer},
|
||||
"diffusion_prior":{"condition_on_text_encodings": args.dp_condition_on_text_encodings,
|
||||
"timesteps": args.dp_timesteps,
|
||||
"cond_drop_prob":args.dp_cond_drop_prob,
|
||||
"loss_type":args.dp_loss_type,
|
||||
"clip":args.clip}
|
||||
})
|
||||
|
||||
RESUME = False
|
||||
# Check if DPRIOR_PATH exists(saved model path)
|
||||
DPRIOR_PATH = args.pretrained_model_path
|
||||
if(DPRIOR_PATH is not None):
|
||||
RESUME = True
|
||||
else:
|
||||
wandb.init(
|
||||
entity=args.wandb_entity,
|
||||
project=args.wandb_project,
|
||||
config=config)
|
||||
|
||||
# Obtain the utilized device.
|
||||
|
||||
has_cuda = torch.cuda.is_available()
|
||||
if has_cuda:
|
||||
device = torch.device("cuda:0")
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
# Training loop
|
||||
train(args.image_embed_dim,
|
||||
args.image_embed_url,
|
||||
args.text_embed_url,
|
||||
args.batch_size,
|
||||
args.train_percent,
|
||||
args.val_percent,
|
||||
args.test_percent,
|
||||
args.num_epochs,
|
||||
args.dp_loss_type,
|
||||
args.clip,
|
||||
args.dp_condition_on_text_encodings,
|
||||
args.dp_timesteps,
|
||||
args.dp_normformer,
|
||||
args.dp_cond_drop_prob,
|
||||
args.dpn_depth,
|
||||
args.dpn_dim_head,
|
||||
args.dpn_heads,
|
||||
args.save_interval,
|
||||
args.save_path,
|
||||
device,
|
||||
RESUME,
|
||||
DPRIOR_PATH,
|
||||
config,
|
||||
args.wandb_entity,
|
||||
args.wandb_project,
|
||||
args.learning_rate,
|
||||
args.max_grad_norm,
|
||||
args.weight_decay,
|
||||
args.dropout,
|
||||
args.amp)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
train()
|
||||
|
||||
Reference in New Issue
Block a user