mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-14 20:04:21 +01:00
Compare commits
35 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
11d4e11f10 | ||
|
|
99778e12de | ||
|
|
0f0011caf0 | ||
|
|
7b7a62044a | ||
|
|
156fe5ed9f | ||
|
|
5ec34bebe1 | ||
|
|
8eaacf1ac1 | ||
|
|
e66c7b0249 | ||
|
|
f7cd4a0992 | ||
|
|
68e7d2f241 | ||
|
|
74f222596a | ||
|
|
aa6772dcff | ||
|
|
71d0c4edae | ||
|
|
f7eee09d8b | ||
|
|
89de5af63e | ||
|
|
4ec6d0ba81 | ||
|
|
aee92dba4a | ||
|
|
b0cd5f24b6 | ||
|
|
b494ed81d4 | ||
|
|
ff3474f05c | ||
|
|
d5293f19f1 | ||
|
|
e697183849 | ||
|
|
591d37e266 | ||
|
|
d1f02e8f49 | ||
|
|
9faab59b23 | ||
|
|
5d27029e98 | ||
|
|
3115fa17b3 | ||
|
|
124d8577c8 | ||
|
|
2db0c9794c | ||
|
|
2277b47ffd | ||
|
|
28b58e568c | ||
|
|
924455d97d | ||
|
|
6021945fc8 | ||
|
|
6f76652d11 | ||
|
|
3dda2570ed |
128
README.md
128
README.md
@@ -508,7 +508,7 @@ To use a pretrained OpenAI CLIP, simply import `OpenAIClipAdapter` and pass it i
|
|||||||
import torch
|
import torch
|
||||||
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, OpenAIClipAdapter
|
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, OpenAIClipAdapter
|
||||||
|
|
||||||
# openai pretrained clip - defaults to ViT/B-32
|
# openai pretrained clip - defaults to ViT-B/32
|
||||||
|
|
||||||
clip = OpenAIClipAdapter()
|
clip = OpenAIClipAdapter()
|
||||||
|
|
||||||
@@ -706,7 +706,7 @@ mock_image_embed = torch.randn(1, 512).cuda()
|
|||||||
images = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)
|
images = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)
|
||||||
```
|
```
|
||||||
|
|
||||||
## Training wrapper (wip)
|
## Training wrapper
|
||||||
|
|
||||||
### Decoder Training
|
### Decoder Training
|
||||||
|
|
||||||
@@ -732,8 +732,8 @@ clip = CLIP(
|
|||||||
|
|
||||||
# mock data
|
# mock data
|
||||||
|
|
||||||
text = torch.randint(0, 49408, (4, 256)).cuda()
|
text = torch.randint(0, 49408, (32, 256)).cuda()
|
||||||
images = torch.randn(4, 3, 256, 256).cuda()
|
images = torch.randn(32, 3, 256, 256).cuda()
|
||||||
|
|
||||||
# decoder (with unet)
|
# decoder (with unet)
|
||||||
|
|
||||||
@@ -774,8 +774,12 @@ decoder_trainer = DecoderTrainer(
|
|||||||
)
|
)
|
||||||
|
|
||||||
for unet_number in (1, 2):
|
for unet_number in (1, 2):
|
||||||
loss = decoder_trainer(images, text = text, unet_number = unet_number) # use the decoder_trainer forward
|
loss = decoder_trainer(
|
||||||
loss.backward()
|
images,
|
||||||
|
text = text,
|
||||||
|
unet_number = unet_number, # which unet to train on
|
||||||
|
max_batch_size = 4 # gradient accumulation - this sets the maximum batch size in which to do forward and backwards pass - for this example 32 / 4 == 8 times
|
||||||
|
)
|
||||||
|
|
||||||
decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
|
decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
|
||||||
|
|
||||||
@@ -810,8 +814,8 @@ clip = CLIP(
|
|||||||
|
|
||||||
# mock data
|
# mock data
|
||||||
|
|
||||||
text = torch.randint(0, 49408, (4, 256)).cuda()
|
text = torch.randint(0, 49408, (32, 256)).cuda()
|
||||||
images = torch.randn(4, 3, 256, 256).cuda()
|
images = torch.randn(32, 3, 256, 256).cuda()
|
||||||
|
|
||||||
# prior networks (with transformer)
|
# prior networks (with transformer)
|
||||||
|
|
||||||
@@ -838,8 +842,7 @@ diffusion_prior_trainer = DiffusionPriorTrainer(
|
|||||||
ema_update_every = 10,
|
ema_update_every = 10,
|
||||||
)
|
)
|
||||||
|
|
||||||
loss = diffusion_prior_trainer(text, images)
|
loss = diffusion_prior_trainer(text, images, max_batch_size = 4)
|
||||||
loss.backward()
|
|
||||||
diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior
|
diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior
|
||||||
|
|
||||||
# after much of the above three lines in a loop
|
# after much of the above three lines in a loop
|
||||||
@@ -848,6 +851,57 @@ diffusion_prior_trainer.update() # this will update the optimizer as well as th
|
|||||||
image_embeds = diffusion_prior_trainer.sample(text) # (4, 512) - exponential moving averaged image embeddings
|
image_embeds = diffusion_prior_trainer.sample(text) # (4, 512) - exponential moving averaged image embeddings
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Bonus
|
||||||
|
|
||||||
|
### Unconditional Training
|
||||||
|
|
||||||
|
The repository also contains the means to train unconditional DDPM model, or even cascading DDPMs. You simply have to set `unconditional = True` in the `Decoder`
|
||||||
|
|
||||||
|
ex.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from dalle2_pytorch import Unet, Decoder
|
||||||
|
|
||||||
|
# unet for the cascading ddpm
|
||||||
|
|
||||||
|
unet1 = Unet(
|
||||||
|
dim = 128,
|
||||||
|
dim_mults=(1, 2, 4, 8)
|
||||||
|
).cuda()
|
||||||
|
|
||||||
|
unet2 = Unet(
|
||||||
|
dim = 32,
|
||||||
|
dim_mults = (1, 2, 4, 8, 16)
|
||||||
|
).cuda()
|
||||||
|
|
||||||
|
# decoder, which contains the unets
|
||||||
|
|
||||||
|
decoder = Decoder(
|
||||||
|
unet = (unet1, unet2),
|
||||||
|
image_sizes = (256, 512), # first unet up to 256px, then second to 512px
|
||||||
|
timesteps = 1000,
|
||||||
|
unconditional = True
|
||||||
|
).cuda()
|
||||||
|
|
||||||
|
# mock images (get a lot of this)
|
||||||
|
|
||||||
|
images = torch.randn(1, 3, 512, 512).cuda()
|
||||||
|
|
||||||
|
# feed images into decoder
|
||||||
|
|
||||||
|
for i in (1, 2):
|
||||||
|
loss = decoder(images, unet_number = i)
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
# do the above for many many many many steps
|
||||||
|
# then it will learn to generate images
|
||||||
|
|
||||||
|
images = decoder.sample(batch_size = 2) # (2, 3, 512, 512)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Dataloaders
|
||||||
|
|
||||||
### Decoder Dataloaders
|
### Decoder Dataloaders
|
||||||
|
|
||||||
In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.
|
In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.
|
||||||
@@ -892,14 +946,14 @@ dataset = ImageEmbeddingDataset(
|
|||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
## Scripts
|
### Scripts (wip)
|
||||||
|
|
||||||
### Using the `train_diffusion_prior.py` script
|
#### `train_diffusion_prior.py`
|
||||||
|
|
||||||
This script allows training the DiffusionPrior on pre-computed text and image embeddings. The working example below elucidates this process.
|
This script allows training the DiffusionPrior on pre-computed text and image embeddings. The working example below elucidates this process.
|
||||||
Please note that the script internally passes text_embed and image_embed to the DiffusionPrior, unlike the example below.
|
Please note that the script internally passes text_embed and image_embed to the DiffusionPrior, unlike the example below.
|
||||||
|
|
||||||
### Usage
|
#### Usage
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
$ python train_diffusion_prior.py
|
$ python train_diffusion_prior.py
|
||||||
@@ -907,58 +961,49 @@ $ python train_diffusion_prior.py
|
|||||||
|
|
||||||
The most significant parameters for the script are as follows:
|
The most significant parameters for the script are as follows:
|
||||||
|
|
||||||
--image-embed-url, default = "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
|
- `image-embed-url`, default = `"https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/"`
|
||||||
|
|
||||||
--text-embed-url, default = "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
|
- `text-embed-url`, default = `"https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/"`
|
||||||
|
|
||||||
--image-embed-dim, default=768 - 768 corresponds to the ViT iL/14 embedding size,change it to what your chosen ViT generates
|
- `image-embed-dim`, default = `768` - 768 corresponds to the ViT iL/14 embedding size,change it to what your chosen ViT generates
|
||||||
|
|
||||||
--learning-rate, default=1.1e-4
|
- `learning-rate`, default = `1.1e-4`
|
||||||
|
|
||||||
--weight-decay, default=6.02e-2
|
- `weight-decay`, default = `6.02e-2`
|
||||||
|
|
||||||
--max-grad-norm, default=0.5
|
- `max-grad-norm`, default = `0.5`
|
||||||
|
|
||||||
--batch-size, default=10 ** 4
|
- `batch-size`, default = `10 ** 4`
|
||||||
|
|
||||||
--num-epochs, default=5
|
- `num-epochs`, default = `5`
|
||||||
|
|
||||||
--clip, default=None # Signals the prior to use pre-computed embeddings
|
- `clip`, default = `None` # Signals the prior to use pre-computed embeddings
|
||||||
|
|
||||||
### Sample wandb run log
|
#### Loading and Saving the DiffusionPrior model
|
||||||
|
|
||||||
Please find a sample wandb run log at : https://wandb.ai/laion/diffusion-prior/runs/1blxu24j
|
|
||||||
|
|
||||||
### Loading and saving the Diffusion Prior model
|
|
||||||
|
|
||||||
Two methods are provided, load_diffusion_model and save_diffusion_model, the names being self-explanatory.
|
Two methods are provided, load_diffusion_model and save_diffusion_model, the names being self-explanatory.
|
||||||
|
|
||||||
## from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model
|
```python
|
||||||
|
from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model
|
||||||
|
```
|
||||||
|
|
||||||
|
##### Loading
|
||||||
|
|
||||||
load_diffusion_model(dprior_path, device)
|
load_diffusion_model(dprior_path, device)
|
||||||
|
|
||||||
dprior_path : path to saved model(.pth)
|
dprior_path : path to saved model(.pth)
|
||||||
|
|
||||||
device : the cuda device you're running on
|
device : the cuda device you're running on
|
||||||
|
|
||||||
|
##### Saving
|
||||||
|
|
||||||
save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim)
|
save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim)
|
||||||
|
|
||||||
save_path : path to save at
|
save_path : path to save at
|
||||||
|
|
||||||
model : object of Diffusion_Prior
|
model : object of Diffusion_Prior
|
||||||
|
|
||||||
optimizer : optimizer object - see train_diffusion_prior.py for how to create one.
|
optimizer : optimizer object - see train_diffusion_prior.py for how to create one.
|
||||||
|
|
||||||
e.g: optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
|
e.g: optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
|
||||||
|
|
||||||
scaler : a GradScaler object.
|
scaler : a GradScaler object.
|
||||||
|
|
||||||
e.g: scaler = GradScaler(enabled=amp)
|
e.g: scaler = GradScaler(enabled=amp)
|
||||||
|
|
||||||
config : config object created in train_diffusion_prior.py - see file for example.
|
config : config object created in train_diffusion_prior.py - see file for example.
|
||||||
|
|
||||||
image_embed_dim - the dimension of the image_embedding
|
image_embed_dim - the dimension of the image_embedding
|
||||||
|
|
||||||
e.g: 768
|
e.g: 768
|
||||||
|
|
||||||
## CLI (wip)
|
## CLI (wip)
|
||||||
@@ -1004,6 +1049,7 @@ Once built, images will be saved to the same directory the command is invoked
|
|||||||
- [x] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images
|
- [x] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images
|
||||||
- [x] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
|
- [x] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
|
||||||
- [x] cross embed layers for downsampling, as an option
|
- [x] cross embed layers for downsampling, as an option
|
||||||
|
- [x] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
|
||||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
|
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
|
||||||
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
|
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
|
||||||
- [ ] train on a toy task, offer in colab
|
- [ ] train on a toy task, offer in colab
|
||||||
@@ -1011,12 +1057,14 @@ Once built, images will be saved to the same directory the command is invoked
|
|||||||
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
|
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
|
||||||
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
|
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
|
||||||
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
|
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
|
||||||
- [ ] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
|
|
||||||
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
|
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
|
||||||
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
|
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
|
||||||
- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
|
- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
|
||||||
- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
|
- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
|
||||||
- [ ] decoder needs one day worth of refactor for tech debt
|
- [ ] decoder needs one day worth of refactor for tech debt
|
||||||
|
- [ ] allow for unet to be able to condition non-cross attention style as well
|
||||||
|
- [ ] for all model classes with hyperparameters that changes the network architecture, make it requirement that they must expose a config property, and write a simple function that asserts that it restores the object correctly
|
||||||
|
- [ ] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
|
||||||
|
|
||||||
## Citations
|
## Citations
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import math
|
import math
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from inspect import isfunction
|
from inspect import isfunction
|
||||||
from functools import partial
|
from functools import partial, wraps
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -33,6 +33,10 @@ from rotary_embedding_torch import RotaryEmbedding
|
|||||||
from x_clip import CLIP
|
from x_clip import CLIP
|
||||||
from coca_pytorch import CoCa
|
from coca_pytorch import CoCa
|
||||||
|
|
||||||
|
# constants
|
||||||
|
|
||||||
|
NAT = 1. / math.log(2.)
|
||||||
|
|
||||||
# helper functions
|
# helper functions
|
||||||
|
|
||||||
def exists(val):
|
def exists(val):
|
||||||
@@ -41,6 +45,14 @@ def exists(val):
|
|||||||
def identity(t, *args, **kwargs):
|
def identity(t, *args, **kwargs):
|
||||||
return t
|
return t
|
||||||
|
|
||||||
|
def maybe(fn):
|
||||||
|
@wraps(fn)
|
||||||
|
def inner(x):
|
||||||
|
if not exists(x):
|
||||||
|
return x
|
||||||
|
return fn(x)
|
||||||
|
return inner
|
||||||
|
|
||||||
def default(val, d):
|
def default(val, d):
|
||||||
if exists(val):
|
if exists(val):
|
||||||
return val
|
return val
|
||||||
@@ -91,6 +103,9 @@ def freeze_model_and_make_eval_(model):
|
|||||||
|
|
||||||
# tensor helpers
|
# tensor helpers
|
||||||
|
|
||||||
|
def log(t, eps = 1e-12):
|
||||||
|
return torch.log(t.clamp(min = eps))
|
||||||
|
|
||||||
def l2norm(t):
|
def l2norm(t):
|
||||||
return F.normalize(t, dim = -1)
|
return F.normalize(t, dim = -1)
|
||||||
|
|
||||||
@@ -107,10 +122,10 @@ def resize_image_to(image, target_image_size):
|
|||||||
# ddpms expect images to be in the range of -1 to 1
|
# ddpms expect images to be in the range of -1 to 1
|
||||||
# but CLIP may otherwise
|
# but CLIP may otherwise
|
||||||
|
|
||||||
def normalize_img(img):
|
def normalize_neg_one_to_one(img):
|
||||||
return img * 2 - 1
|
return img * 2 - 1
|
||||||
|
|
||||||
def unnormalize_img(normed_img):
|
def unnormalize_zero_to_one(normed_img):
|
||||||
return (normed_img + 1) * 0.5
|
return (normed_img + 1) * 0.5
|
||||||
|
|
||||||
# clip related adapters
|
# clip related adapters
|
||||||
@@ -271,7 +286,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
|
|||||||
def embed_image(self, image):
|
def embed_image(self, image):
|
||||||
assert not self.cleared
|
assert not self.cleared
|
||||||
image = resize_image_to(image, self.image_size)
|
image = resize_image_to(image, self.image_size)
|
||||||
image = self.clip_normalize(unnormalize_img(image))
|
image = self.clip_normalize(image)
|
||||||
image_embed = self.clip.encode_image(image)
|
image_embed = self.clip.encode_image(image)
|
||||||
return EmbeddedImage(l2norm(image_embed.float()), None)
|
return EmbeddedImage(l2norm(image_embed.float()), None)
|
||||||
|
|
||||||
@@ -297,6 +312,36 @@ def noise_like(shape, device, repeat=False):
|
|||||||
noise = lambda: torch.randn(shape, device=device)
|
noise = lambda: torch.randn(shape, device=device)
|
||||||
return repeat_noise() if repeat else noise()
|
return repeat_noise() if repeat else noise()
|
||||||
|
|
||||||
|
def meanflat(x):
|
||||||
|
return x.mean(dim = tuple(range(1, len(x.shape))))
|
||||||
|
|
||||||
|
def normal_kl(mean1, logvar1, mean2, logvar2):
|
||||||
|
return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2))
|
||||||
|
|
||||||
|
def approx_standard_normal_cdf(x):
|
||||||
|
return 0.5 * (1.0 + torch.tanh(((2.0 / math.pi) ** 0.5) * (x + 0.044715 * (x ** 3))))
|
||||||
|
|
||||||
|
def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
|
||||||
|
assert x.shape == means.shape == log_scales.shape
|
||||||
|
|
||||||
|
centered_x = x - means
|
||||||
|
inv_stdv = torch.exp(-log_scales)
|
||||||
|
plus_in = inv_stdv * (centered_x + 1. / 255.)
|
||||||
|
cdf_plus = approx_standard_normal_cdf(plus_in)
|
||||||
|
min_in = inv_stdv * (centered_x - 1. / 255.)
|
||||||
|
cdf_min = approx_standard_normal_cdf(min_in)
|
||||||
|
log_cdf_plus = log(cdf_plus)
|
||||||
|
log_one_minus_cdf_min = log(1. - cdf_min)
|
||||||
|
cdf_delta = cdf_plus - cdf_min
|
||||||
|
|
||||||
|
log_probs = torch.where(x < -thres,
|
||||||
|
log_cdf_plus,
|
||||||
|
torch.where(x > thres,
|
||||||
|
log_one_minus_cdf_min,
|
||||||
|
log(cdf_delta)))
|
||||||
|
|
||||||
|
return log_probs
|
||||||
|
|
||||||
def cosine_beta_schedule(timesteps, s = 0.008):
|
def cosine_beta_schedule(timesteps, s = 0.008):
|
||||||
"""
|
"""
|
||||||
cosine schedule
|
cosine schedule
|
||||||
@@ -398,12 +443,6 @@ class BaseGaussianDiffusion(nn.Module):
|
|||||||
register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
||||||
register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
|
register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
|
||||||
|
|
||||||
def q_mean_variance(self, x_start, t):
|
|
||||||
mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
|
||||||
variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
|
|
||||||
log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
|
|
||||||
return mean, variance, log_variance
|
|
||||||
|
|
||||||
def q_posterior(self, x_start, x_t, t):
|
def q_posterior(self, x_start, x_t, t):
|
||||||
posterior_mean = (
|
posterior_mean = (
|
||||||
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
|
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
|
||||||
@@ -575,7 +614,6 @@ class Attention(nn.Module):
|
|||||||
heads = 8,
|
heads = 8,
|
||||||
dropout = 0.,
|
dropout = 0.,
|
||||||
causal = False,
|
causal = False,
|
||||||
post_norm = False,
|
|
||||||
rotary_emb = None
|
rotary_emb = None
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -585,7 +623,6 @@ class Attention(nn.Module):
|
|||||||
|
|
||||||
self.causal = causal
|
self.causal = causal
|
||||||
self.norm = LayerNorm(dim)
|
self.norm = LayerNorm(dim)
|
||||||
self.post_norm = LayerNorm(dim) # sandwich norm from Coqview paper + Normformer
|
|
||||||
self.dropout = nn.Dropout(dropout)
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
|
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
|
||||||
@@ -596,7 +633,7 @@ class Attention(nn.Module):
|
|||||||
|
|
||||||
self.to_out = nn.Sequential(
|
self.to_out = nn.Sequential(
|
||||||
nn.Linear(inner_dim, dim, bias = False),
|
nn.Linear(inner_dim, dim, bias = False),
|
||||||
LayerNorm(dim) if post_norm else nn.Identity()
|
LayerNorm(dim)
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x, mask = None, attn_bias = None):
|
def forward(self, x, mask = None, attn_bias = None):
|
||||||
@@ -653,8 +690,7 @@ class Attention(nn.Module):
|
|||||||
out = einsum('b h i j, b j d -> b h i d', attn, v)
|
out = einsum('b h i j, b j d -> b h i d', attn, v)
|
||||||
|
|
||||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||||
out = self.to_out(out)
|
return self.to_out(out)
|
||||||
return self.post_norm(out)
|
|
||||||
|
|
||||||
class CausalTransformer(nn.Module):
|
class CausalTransformer(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -680,7 +716,7 @@ class CausalTransformer(nn.Module):
|
|||||||
self.layers = nn.ModuleList([])
|
self.layers = nn.ModuleList([])
|
||||||
for _ in range(depth):
|
for _ in range(depth):
|
||||||
self.layers.append(nn.ModuleList([
|
self.layers.append(nn.ModuleList([
|
||||||
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer, rotary_emb = rotary_emb),
|
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb),
|
||||||
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
|
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
|
||||||
]))
|
]))
|
||||||
|
|
||||||
@@ -831,7 +867,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
|||||||
image_channels = 3,
|
image_channels = 3,
|
||||||
timesteps = 1000,
|
timesteps = 1000,
|
||||||
cond_drop_prob = 0.,
|
cond_drop_prob = 0.,
|
||||||
loss_type = "l1",
|
loss_type = "l2",
|
||||||
predict_x_start = True,
|
predict_x_start = True,
|
||||||
beta_schedule = "cosine",
|
beta_schedule = "cosine",
|
||||||
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
|
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
|
||||||
@@ -1127,6 +1163,7 @@ class CrossAttention(nn.Module):
|
|||||||
dim_head = 64,
|
dim_head = 64,
|
||||||
heads = 8,
|
heads = 8,
|
||||||
dropout = 0.,
|
dropout = 0.,
|
||||||
|
norm_context = False
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.scale = dim_head ** -0.5
|
self.scale = dim_head ** -0.5
|
||||||
@@ -1136,13 +1173,17 @@ class CrossAttention(nn.Module):
|
|||||||
context_dim = default(context_dim, dim)
|
context_dim = default(context_dim, dim)
|
||||||
|
|
||||||
self.norm = LayerNorm(dim)
|
self.norm = LayerNorm(dim)
|
||||||
self.norm_context = LayerNorm(context_dim)
|
self.norm_context = LayerNorm(context_dim) if norm_context else nn.Identity()
|
||||||
self.dropout = nn.Dropout(dropout)
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
|
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
|
||||||
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
||||||
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
|
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
|
||||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
|
||||||
|
self.to_out = nn.Sequential(
|
||||||
|
nn.Linear(inner_dim, dim, bias = False),
|
||||||
|
LayerNorm(dim)
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x, context, mask = None):
|
def forward(self, x, context, mask = None):
|
||||||
b, n, device = *x.shape[:2], x.device
|
b, n, device = *x.shape[:2], x.device
|
||||||
@@ -1264,7 +1305,7 @@ class Unet(nn.Module):
|
|||||||
self,
|
self,
|
||||||
dim,
|
dim,
|
||||||
*,
|
*,
|
||||||
image_embed_dim,
|
image_embed_dim = None,
|
||||||
text_embed_dim = None,
|
text_embed_dim = None,
|
||||||
cond_dim = None,
|
cond_dim = None,
|
||||||
num_image_tokens = 4,
|
num_image_tokens = 4,
|
||||||
@@ -1272,6 +1313,7 @@ class Unet(nn.Module):
|
|||||||
out_dim = None,
|
out_dim = None,
|
||||||
dim_mults=(1, 2, 4, 8),
|
dim_mults=(1, 2, 4, 8),
|
||||||
channels = 3,
|
channels = 3,
|
||||||
|
channels_out = None,
|
||||||
attn_dim_head = 32,
|
attn_dim_head = 32,
|
||||||
attn_heads = 16,
|
attn_heads = 16,
|
||||||
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
|
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
|
||||||
@@ -1302,6 +1344,7 @@ class Unet(nn.Module):
|
|||||||
# determine dimensions
|
# determine dimensions
|
||||||
|
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
|
self.channels_out = default(channels_out, channels)
|
||||||
|
|
||||||
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
|
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
|
||||||
init_dim = default(init_dim, dim // 3 * 2)
|
init_dim = default(init_dim, dim // 3 * 2)
|
||||||
@@ -1334,7 +1377,10 @@ class Unet(nn.Module):
|
|||||||
self.image_to_cond = nn.Sequential(
|
self.image_to_cond = nn.Sequential(
|
||||||
nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
|
nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
|
||||||
Rearrange('b (n d) -> b n d', n = num_image_tokens)
|
Rearrange('b (n d) -> b n d', n = num_image_tokens)
|
||||||
) if image_embed_dim != cond_dim else nn.Identity()
|
) if cond_on_image_embeds and image_embed_dim != cond_dim else nn.Identity()
|
||||||
|
|
||||||
|
self.norm_cond = nn.LayerNorm(cond_dim)
|
||||||
|
self.norm_mid_cond = nn.LayerNorm(cond_dim)
|
||||||
|
|
||||||
# text encoding conditioning (optional)
|
# text encoding conditioning (optional)
|
||||||
|
|
||||||
@@ -1407,11 +1453,9 @@ class Unet(nn.Module):
|
|||||||
Upsample(dim_in)
|
Upsample(dim_in)
|
||||||
]))
|
]))
|
||||||
|
|
||||||
out_dim = default(out_dim, channels)
|
|
||||||
|
|
||||||
self.final_conv = nn.Sequential(
|
self.final_conv = nn.Sequential(
|
||||||
ResnetBlock(dim, dim, groups = resnet_groups[0]),
|
ResnetBlock(dim, dim, groups = resnet_groups[0]),
|
||||||
nn.Conv2d(dim, out_dim, 1)
|
nn.Conv2d(dim, self.channels_out, 1)
|
||||||
)
|
)
|
||||||
|
|
||||||
# if the current settings for the unet are not correct
|
# if the current settings for the unet are not correct
|
||||||
@@ -1421,13 +1465,25 @@ class Unet(nn.Module):
|
|||||||
*,
|
*,
|
||||||
lowres_cond,
|
lowres_cond,
|
||||||
channels,
|
channels,
|
||||||
|
channels_out,
|
||||||
cond_on_image_embeds,
|
cond_on_image_embeds,
|
||||||
cond_on_text_encodings
|
cond_on_text_encodings
|
||||||
):
|
):
|
||||||
if lowres_cond == self.lowres_cond and channels == self.channels and cond_on_image_embeds == self.cond_on_image_embeds and cond_on_text_encodings == self.cond_on_text_encodings:
|
if lowres_cond == self.lowres_cond and \
|
||||||
|
channels == self.channels and \
|
||||||
|
cond_on_image_embeds == self.cond_on_image_embeds and \
|
||||||
|
cond_on_text_encodings == self.cond_on_text_encodings and \
|
||||||
|
channels_out == self.channels_out:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
updated_kwargs = {'lowres_cond': lowres_cond, 'channels': channels, 'cond_on_image_embeds': cond_on_image_embeds, 'cond_on_text_encodings': cond_on_text_encodings}
|
updated_kwargs = dict(
|
||||||
|
lowres_cond = lowres_cond,
|
||||||
|
channels = channels,
|
||||||
|
channels_out = channels_out,
|
||||||
|
cond_on_image_embeds = cond_on_image_embeds,
|
||||||
|
cond_on_text_encodings = cond_on_text_encodings
|
||||||
|
)
|
||||||
|
|
||||||
return self.__class__(**{**self._locals, **updated_kwargs})
|
return self.__class__(**{**self._locals, **updated_kwargs})
|
||||||
|
|
||||||
def forward_with_cond_scale(
|
def forward_with_cond_scale(
|
||||||
@@ -1492,11 +1548,12 @@ class Unet(nn.Module):
|
|||||||
|
|
||||||
if self.cond_on_image_embeds:
|
if self.cond_on_image_embeds:
|
||||||
image_tokens = self.image_to_cond(image_embed)
|
image_tokens = self.image_to_cond(image_embed)
|
||||||
|
null_image_embed = self.null_image_embed.to(image_tokens.dtype) # for some reason pytorch AMP not working
|
||||||
|
|
||||||
image_tokens = torch.where(
|
image_tokens = torch.where(
|
||||||
image_keep_mask,
|
image_keep_mask,
|
||||||
image_tokens,
|
image_tokens,
|
||||||
self.null_image_embed
|
null_image_embed
|
||||||
)
|
)
|
||||||
|
|
||||||
# take care of text encodings (optional)
|
# take care of text encodings (optional)
|
||||||
@@ -1520,10 +1577,12 @@ class Unet(nn.Module):
|
|||||||
text_mask = rearrange(text_mask, 'b n -> b n 1')
|
text_mask = rearrange(text_mask, 'b n -> b n 1')
|
||||||
text_keep_mask = text_mask & text_keep_mask
|
text_keep_mask = text_mask & text_keep_mask
|
||||||
|
|
||||||
|
null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working
|
||||||
|
|
||||||
text_tokens = torch.where(
|
text_tokens = torch.where(
|
||||||
text_keep_mask,
|
text_keep_mask,
|
||||||
text_tokens,
|
text_tokens,
|
||||||
self.null_text_embed
|
null_text_embed
|
||||||
)
|
)
|
||||||
|
|
||||||
# main conditioning tokens (c)
|
# main conditioning tokens (c)
|
||||||
@@ -1538,6 +1597,11 @@ class Unet(nn.Module):
|
|||||||
|
|
||||||
mid_c = c if not exists(text_tokens) else torch.cat((c, text_tokens), dim = -2)
|
mid_c = c if not exists(text_tokens) else torch.cat((c, text_tokens), dim = -2)
|
||||||
|
|
||||||
|
# normalize conditioning tokens
|
||||||
|
|
||||||
|
c = self.norm_cond(c)
|
||||||
|
mid_c = self.norm_mid_cond(mid_c)
|
||||||
|
|
||||||
# go through the layers of the unet, down and up
|
# go through the layers of the unet, down and up
|
||||||
|
|
||||||
hiddens = []
|
hiddens = []
|
||||||
@@ -1611,7 +1675,7 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
timesteps = 1000,
|
timesteps = 1000,
|
||||||
image_cond_drop_prob = 0.1,
|
image_cond_drop_prob = 0.1,
|
||||||
text_cond_drop_prob = 0.5,
|
text_cond_drop_prob = 0.5,
|
||||||
loss_type = 'l1',
|
loss_type = 'l2',
|
||||||
beta_schedule = 'cosine',
|
beta_schedule = 'cosine',
|
||||||
predict_x_start = False,
|
predict_x_start = False,
|
||||||
predict_x_start_for_latent_diffusion = False,
|
predict_x_start_for_latent_diffusion = False,
|
||||||
@@ -1624,6 +1688,8 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
clip_denoised = True,
|
clip_denoised = True,
|
||||||
clip_x_start = True,
|
clip_x_start = True,
|
||||||
clip_adapter_overrides = dict(),
|
clip_adapter_overrides = dict(),
|
||||||
|
learned_variance = True,
|
||||||
|
vb_loss_weight = 0.001,
|
||||||
unconditional = False
|
unconditional = False
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -1635,7 +1701,7 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
self.unconditional = unconditional
|
self.unconditional = unconditional
|
||||||
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
|
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
|
||||||
|
|
||||||
assert exists(clip) ^ exists(image_size), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
|
assert self.unconditional or (exists(clip) ^ exists(image_size)), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
|
||||||
|
|
||||||
self.clip = None
|
self.clip = None
|
||||||
if exists(clip):
|
if exists(clip):
|
||||||
@@ -1662,10 +1728,18 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
unets = cast_tuple(unet)
|
unets = cast_tuple(unet)
|
||||||
vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))
|
vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))
|
||||||
|
|
||||||
|
# whether to use learned variance, defaults to True for the first unet in the cascade, as in paper
|
||||||
|
|
||||||
|
learned_variance = pad_tuple_to_length(cast_tuple(learned_variance), len(unets), fillvalue = False)
|
||||||
|
self.learned_variance = learned_variance
|
||||||
|
self.vb_loss_weight = vb_loss_weight
|
||||||
|
|
||||||
|
# construct unets and vaes
|
||||||
|
|
||||||
self.unets = nn.ModuleList([])
|
self.unets = nn.ModuleList([])
|
||||||
self.vaes = nn.ModuleList([])
|
self.vaes = nn.ModuleList([])
|
||||||
|
|
||||||
for ind, (one_unet, one_vae) in enumerate(zip(unets, vaes)):
|
for ind, (one_unet, one_vae, one_unet_learned_var) in enumerate(zip(unets, vaes, learned_variance)):
|
||||||
assert isinstance(one_unet, Unet)
|
assert isinstance(one_unet, Unet)
|
||||||
assert isinstance(one_vae, (VQGanVAE, NullVQGanVAE))
|
assert isinstance(one_vae, (VQGanVAE, NullVQGanVAE))
|
||||||
|
|
||||||
@@ -1673,12 +1747,14 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
latent_dim = one_vae.encoded_dim if exists(one_vae) else None
|
latent_dim = one_vae.encoded_dim if exists(one_vae) else None
|
||||||
|
|
||||||
unet_channels = default(latent_dim, self.channels)
|
unet_channels = default(latent_dim, self.channels)
|
||||||
|
unet_channels_out = unet_channels * (1 if not one_unet_learned_var else 2)
|
||||||
|
|
||||||
one_unet = one_unet.cast_model_parameters(
|
one_unet = one_unet.cast_model_parameters(
|
||||||
lowres_cond = not is_first,
|
lowres_cond = not is_first,
|
||||||
cond_on_image_embeds = is_first and not unconditional,
|
cond_on_image_embeds = is_first and not unconditional,
|
||||||
cond_on_text_encodings = one_unet.cond_on_text_encodings and not unconditional,
|
cond_on_text_encodings = one_unet.cond_on_text_encodings and not unconditional,
|
||||||
channels = unet_channels
|
channels = unet_channels,
|
||||||
|
channels_out = unet_channels_out
|
||||||
)
|
)
|
||||||
|
|
||||||
self.unets.append(one_unet)
|
self.unets.append(one_unet)
|
||||||
@@ -1741,8 +1817,11 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
yield
|
yield
|
||||||
unet.cpu()
|
unet.cpu()
|
||||||
|
|
||||||
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
|
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
|
||||||
pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
|
pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img))
|
||||||
|
|
||||||
|
if learned_variance:
|
||||||
|
pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1)
|
||||||
|
|
||||||
if predict_x_start:
|
if predict_x_start:
|
||||||
x_recon = pred
|
x_recon = pred
|
||||||
@@ -1753,24 +1832,38 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
x_recon.clamp_(-1., 1.)
|
x_recon.clamp_(-1., 1.)
|
||||||
|
|
||||||
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
||||||
|
|
||||||
|
if learned_variance:
|
||||||
|
# if learned variance, posterio variance and posterior log variance are predicted by the network
|
||||||
|
# by an interpolation of the max and min log beta values
|
||||||
|
# eq 15 - https://arxiv.org/abs/2102.09672
|
||||||
|
min_log = extract(self.posterior_log_variance_clipped, t, x.shape)
|
||||||
|
max_log = extract(torch.log(self.betas), t, x.shape)
|
||||||
|
var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)
|
||||||
|
|
||||||
|
posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
|
||||||
|
posterior_variance = posterior_log_variance.exp()
|
||||||
|
|
||||||
return model_mean, posterior_variance, posterior_log_variance
|
return model_mean, posterior_variance, posterior_log_variance
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, clip_denoised = True, repeat_noise = False):
|
def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True, repeat_noise = False):
|
||||||
b, *_, device = *x.shape, x.device
|
b, *_, device = *x.shape, x.device
|
||||||
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start)
|
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, learned_variance = learned_variance)
|
||||||
noise = noise_like(x.shape, device, repeat_noise)
|
noise = noise_like(x.shape, device, repeat_noise)
|
||||||
# no noise when t == 0
|
# no noise when t == 0
|
||||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
|
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
|
||||||
device = self.betas.device
|
device = self.betas.device
|
||||||
|
|
||||||
b = shape[0]
|
b = shape[0]
|
||||||
img = torch.randn(shape, device = device)
|
img = torch.randn(shape, device = device)
|
||||||
|
|
||||||
|
lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
|
||||||
|
|
||||||
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
|
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
|
||||||
img = self.p_sample(
|
img = self.p_sample(
|
||||||
unet,
|
unet,
|
||||||
@@ -1782,17 +1875,26 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
cond_scale = cond_scale,
|
cond_scale = cond_scale,
|
||||||
lowres_cond_img = lowres_cond_img,
|
lowres_cond_img = lowres_cond_img,
|
||||||
predict_x_start = predict_x_start,
|
predict_x_start = predict_x_start,
|
||||||
|
learned_variance = learned_variance,
|
||||||
clip_denoised = clip_denoised
|
clip_denoised = clip_denoised
|
||||||
)
|
)
|
||||||
|
|
||||||
return img
|
unnormalize_img = unnormalize_zero_to_one(img)
|
||||||
|
return unnormalize_img
|
||||||
|
|
||||||
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None):
|
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False):
|
||||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||||
|
|
||||||
|
# normalize to [-1, 1]
|
||||||
|
|
||||||
|
x_start = normalize_neg_one_to_one(x_start)
|
||||||
|
lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
|
||||||
|
|
||||||
|
# get x_t
|
||||||
|
|
||||||
x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)
|
x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)
|
||||||
|
|
||||||
pred = unet(
|
model_output = unet(
|
||||||
x_noisy,
|
x_noisy,
|
||||||
times,
|
times,
|
||||||
image_embed = image_embed,
|
image_embed = image_embed,
|
||||||
@@ -1803,10 +1905,48 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
text_cond_drop_prob = self.text_cond_drop_prob,
|
text_cond_drop_prob = self.text_cond_drop_prob,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if learned_variance:
|
||||||
|
pred, _ = model_output.chunk(2, dim = 1)
|
||||||
|
else:
|
||||||
|
pred = model_output
|
||||||
|
|
||||||
target = noise if not predict_x_start else x_start
|
target = noise if not predict_x_start else x_start
|
||||||
|
|
||||||
loss = self.loss_fn(pred, target)
|
loss = self.loss_fn(pred, target)
|
||||||
return loss
|
|
||||||
|
if not learned_variance:
|
||||||
|
# return simple loss if not using learned variance
|
||||||
|
return loss
|
||||||
|
|
||||||
|
# most of the code below is transcribed from
|
||||||
|
# https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/diffusion_utils_2.py
|
||||||
|
# the Improved DDPM paper then further modified it so that the mean is detached (shown a couple lines before), and weighted to be smaller than the l1 or l2 "simple" loss
|
||||||
|
# it is questionable whether this is really needed, looking at some of the figures in the paper, but may as well stay faithful to their implementation
|
||||||
|
|
||||||
|
# if learning the variance, also include the extra weight kl loss
|
||||||
|
|
||||||
|
true_mean, _, true_log_variance_clipped = self.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
|
||||||
|
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
|
||||||
|
|
||||||
|
# kl loss with detached model predicted mean, for stability reasons as in paper
|
||||||
|
|
||||||
|
detached_model_mean = model_mean.detach()
|
||||||
|
|
||||||
|
kl = normal_kl(true_mean, true_log_variance_clipped, detached_model_mean, model_log_variance)
|
||||||
|
kl = meanflat(kl) * NAT
|
||||||
|
|
||||||
|
decoder_nll = -discretized_gaussian_log_likelihood(x_start, means = detached_model_mean, log_scales = 0.5 * model_log_variance)
|
||||||
|
decoder_nll = meanflat(decoder_nll) * NAT
|
||||||
|
|
||||||
|
# at the first timestep return the decoder NLL, otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
|
||||||
|
|
||||||
|
vb_losses = torch.where(times == 0, decoder_nll, kl)
|
||||||
|
|
||||||
|
# weight the vb loss smaller, for stability, as in the paper (recommended 0.001)
|
||||||
|
|
||||||
|
vb_loss = vb_losses.mean() * self.vb_loss_weight
|
||||||
|
|
||||||
|
return loss + vb_loss
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
@eval_decorator
|
@eval_decorator
|
||||||
@@ -1833,7 +1973,7 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
img = None
|
img = None
|
||||||
is_cuda = next(self.parameters()).is_cuda
|
is_cuda = next(self.parameters()).is_cuda
|
||||||
|
|
||||||
for unet_number, unet, vae, channel, image_size, predict_x_start in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
|
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance)):
|
||||||
|
|
||||||
context = self.one_unet_in_gpu(unet = unet) if is_cuda else null_context()
|
context = self.one_unet_in_gpu(unet = unet) if is_cuda else null_context()
|
||||||
|
|
||||||
@@ -1848,8 +1988,7 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
image_size = vae.get_encoded_fmap_size(image_size)
|
image_size = vae.get_encoded_fmap_size(image_size)
|
||||||
shape = (batch_size, vae.encoded_dim, image_size, image_size)
|
shape = (batch_size, vae.encoded_dim, image_size, image_size)
|
||||||
|
|
||||||
if exists(lowres_cond_img):
|
lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
|
||||||
lowres_cond_img = vae.encode(lowres_cond_img)
|
|
||||||
|
|
||||||
img = self.p_sample_loop(
|
img = self.p_sample_loop(
|
||||||
unet,
|
unet,
|
||||||
@@ -1859,6 +1998,7 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
text_mask = text_mask,
|
text_mask = text_mask,
|
||||||
cond_scale = cond_scale,
|
cond_scale = cond_scale,
|
||||||
predict_x_start = predict_x_start,
|
predict_x_start = predict_x_start,
|
||||||
|
learned_variance = learned_variance,
|
||||||
clip_denoised = not is_latent_diffusion,
|
clip_denoised = not is_latent_diffusion,
|
||||||
lowres_cond_img = lowres_cond_img
|
lowres_cond_img = lowres_cond_img
|
||||||
)
|
)
|
||||||
@@ -1888,6 +2028,7 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
target_image_size = self.image_sizes[unet_index]
|
target_image_size = self.image_sizes[unet_index]
|
||||||
predict_x_start = self.predict_x_start[unet_index]
|
predict_x_start = self.predict_x_start[unet_index]
|
||||||
random_crop_size = self.random_crop_sizes[unet_index]
|
random_crop_size = self.random_crop_sizes[unet_index]
|
||||||
|
learned_variance = self.learned_variance[unet_index]
|
||||||
b, c, h, w, device, = *image.shape, image.device
|
b, c, h, w, device, = *image.shape, image.device
|
||||||
|
|
||||||
check_shape(image, 'b c h w', c = self.channels)
|
check_shape(image, 'b c h w', c = self.channels)
|
||||||
@@ -1895,12 +2036,12 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
|
|
||||||
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
|
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
|
||||||
|
|
||||||
if not exists(image_embed):
|
if not exists(image_embed) and not self.unconditional:
|
||||||
assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
|
assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
|
||||||
image_embed, _ = self.clip.embed_image(image)
|
image_embed, _ = self.clip.embed_image(image)
|
||||||
|
|
||||||
text_encodings = text_mask = None
|
text_encodings = text_mask = None
|
||||||
if exists(text) and not exists(text_encodings):
|
if exists(text) and not exists(text_encodings) and not self.unconditional:
|
||||||
assert exists(self.clip), 'if you are passing in raw text, you need to supply `clip` to the decoder'
|
assert exists(self.clip), 'if you are passing in raw text, you need to supply `clip` to the decoder'
|
||||||
_, text_encodings, text_mask = self.clip.embed_text(text)
|
_, text_encodings, text_mask = self.clip.embed_text(text)
|
||||||
|
|
||||||
@@ -1921,11 +2062,9 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
vae.eval()
|
vae.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
image = vae.encode(image)
|
image = vae.encode(image)
|
||||||
|
lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
|
||||||
|
|
||||||
if exists(lowres_cond_img):
|
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance)
|
||||||
lowres_cond_img = vae.encode(lowres_cond_img)
|
|
||||||
|
|
||||||
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start)
|
|
||||||
|
|
||||||
# main class
|
# main class
|
||||||
|
|
||||||
@@ -1975,4 +2114,3 @@ class DALLE2(nn.Module):
|
|||||||
return images[0]
|
return images[0]
|
||||||
|
|
||||||
return images
|
return images
|
||||||
|
|
||||||
|
|||||||
@@ -7,16 +7,17 @@ def separate_weight_decayable_params(params):
|
|||||||
|
|
||||||
def get_optimizer(
|
def get_optimizer(
|
||||||
params,
|
params,
|
||||||
lr = 3e-4,
|
lr = 2e-5,
|
||||||
wd = 1e-2,
|
wd = 1e-2,
|
||||||
betas = (0.9, 0.999),
|
betas = (0.9, 0.999),
|
||||||
|
eps = 1e-8,
|
||||||
filter_by_requires_grad = False
|
filter_by_requires_grad = False
|
||||||
):
|
):
|
||||||
if filter_by_requires_grad:
|
if filter_by_requires_grad:
|
||||||
params = list(filter(lambda t: t.requires_grad, params))
|
params = list(filter(lambda t: t.requires_grad, params))
|
||||||
|
|
||||||
if wd == 0:
|
if wd == 0:
|
||||||
return Adam(params, lr = lr, betas = betas)
|
return Adam(params, lr = lr, betas = betas, eps = eps)
|
||||||
|
|
||||||
params = set(params)
|
params = set(params)
|
||||||
wd_params, no_wd_params = separate_weight_decayable_params(params)
|
wd_params, no_wd_params = separate_weight_decayable_params(params)
|
||||||
@@ -26,4 +27,4 @@ def get_optimizer(
|
|||||||
{'params': list(no_wd_params), 'weight_decay': 0},
|
{'params': list(no_wd_params), 'weight_decay': 0},
|
||||||
]
|
]
|
||||||
|
|
||||||
return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas)
|
return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas, eps = eps)
|
||||||
|
|||||||
49
dalle2_pytorch/trackers.py
Normal file
49
dalle2_pytorch/trackers.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
import os
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
# helper functions
|
||||||
|
|
||||||
|
def exists(val):
|
||||||
|
return val is not None
|
||||||
|
|
||||||
|
# base class
|
||||||
|
|
||||||
|
class BaseTracker(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def init(self, config, **kwargs):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def log(self, log, **kwargs):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
# basic stdout class
|
||||||
|
|
||||||
|
class ConsoleTracker(BaseTracker):
|
||||||
|
def init(self, **config):
|
||||||
|
print(config)
|
||||||
|
|
||||||
|
def log(self, log, **kwargs):
|
||||||
|
print(log)
|
||||||
|
|
||||||
|
# basic wandb class
|
||||||
|
|
||||||
|
class WandbTracker(BaseTracker):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
try:
|
||||||
|
import wandb
|
||||||
|
except ImportError as e:
|
||||||
|
print('`pip install wandb` to use the wandb experiment tracker')
|
||||||
|
raise e
|
||||||
|
|
||||||
|
os.environ["WANDB_SILENT"] = "true"
|
||||||
|
self.wandb = wandb
|
||||||
|
|
||||||
|
def init(self, **config):
|
||||||
|
self.wandb.init(**config)
|
||||||
|
|
||||||
|
def log(self, log, **kwargs):
|
||||||
|
self.wandb.log(log, **kwargs)
|
||||||
@@ -1,6 +1,8 @@
|
|||||||
import time
|
import time
|
||||||
import copy
|
import copy
|
||||||
from functools import partial
|
from math import ceil
|
||||||
|
from functools import partial, wraps
|
||||||
|
from collections.abc import Iterable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -9,11 +11,16 @@ from torch.cuda.amp import autocast, GradScaler
|
|||||||
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
|
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
|
||||||
from dalle2_pytorch.optimizer import get_optimizer
|
from dalle2_pytorch.optimizer import get_optimizer
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
# helper functions
|
# helper functions
|
||||||
|
|
||||||
def exists(val):
|
def exists(val):
|
||||||
return val is not None
|
return val is not None
|
||||||
|
|
||||||
|
def default(val, d):
|
||||||
|
return val if exists(val) else d
|
||||||
|
|
||||||
def cast_tuple(val, length = 1):
|
def cast_tuple(val, length = 1):
|
||||||
return val if isinstance(val, tuple) else ((val,) * length)
|
return val if isinstance(val, tuple) else ((val,) * length)
|
||||||
|
|
||||||
@@ -40,6 +47,79 @@ def groupby_prefix_and_trim(prefix, d):
|
|||||||
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
|
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
|
||||||
return kwargs_without_prefix, kwargs
|
return kwargs_without_prefix, kwargs
|
||||||
|
|
||||||
|
# decorators
|
||||||
|
|
||||||
|
def cast_torch_tensor(fn):
|
||||||
|
@wraps(fn)
|
||||||
|
def inner(model, *args, **kwargs):
|
||||||
|
device = kwargs.pop('_device', next(model.parameters()).device)
|
||||||
|
cast_device = kwargs.pop('_cast_device', True)
|
||||||
|
|
||||||
|
kwargs_keys = kwargs.keys()
|
||||||
|
all_args = (*args, *kwargs.values())
|
||||||
|
split_kwargs_index = len(all_args) - len(kwargs_keys)
|
||||||
|
all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args))
|
||||||
|
|
||||||
|
if cast_device:
|
||||||
|
all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
|
||||||
|
|
||||||
|
args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
|
||||||
|
kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))
|
||||||
|
|
||||||
|
out = fn(model, *args, **kwargs)
|
||||||
|
return out
|
||||||
|
return inner
|
||||||
|
|
||||||
|
# gradient accumulation functions
|
||||||
|
|
||||||
|
def split_iterable(it, split_size):
|
||||||
|
accum = []
|
||||||
|
for ind in range(ceil(len(it) / split_size)):
|
||||||
|
start_index = ind * split_size
|
||||||
|
accum.append(it[start_index: (start_index + split_size)])
|
||||||
|
return accum
|
||||||
|
|
||||||
|
def split(t, split_size = None):
|
||||||
|
if not exists(split_size):
|
||||||
|
return t
|
||||||
|
|
||||||
|
if isinstance(t, torch.Tensor):
|
||||||
|
return t.split(split_size, dim = 0)
|
||||||
|
|
||||||
|
if isinstance(t, Iterable):
|
||||||
|
return split_iterable(t, split_size)
|
||||||
|
|
||||||
|
return TypeError
|
||||||
|
|
||||||
|
def find_first(cond, arr):
|
||||||
|
for el in arr:
|
||||||
|
if cond(el):
|
||||||
|
return el
|
||||||
|
return None
|
||||||
|
|
||||||
|
def split_args_and_kwargs(*args, split_size = None, **kwargs):
|
||||||
|
all_args = (*args, *kwargs.values())
|
||||||
|
len_all_args = len(all_args)
|
||||||
|
first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args)
|
||||||
|
assert exists(first_tensor)
|
||||||
|
|
||||||
|
batch_size = len(first_tensor)
|
||||||
|
split_size = default(split_size, batch_size)
|
||||||
|
num_chunks = ceil(batch_size / split_size)
|
||||||
|
|
||||||
|
dict_len = len(kwargs)
|
||||||
|
dict_keys = kwargs.keys()
|
||||||
|
split_kwargs_index = len_all_args - dict_len
|
||||||
|
|
||||||
|
split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args]
|
||||||
|
chunk_sizes = tuple(map(len, split_all_args[0]))
|
||||||
|
|
||||||
|
for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
|
||||||
|
chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:]
|
||||||
|
chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
|
||||||
|
chunk_size_frac = chunk_size / batch_size
|
||||||
|
yield chunk_size_frac, (chunked_args, chunked_kwargs)
|
||||||
|
|
||||||
# print helpers
|
# print helpers
|
||||||
|
|
||||||
def print_ribbon(s, symbol = '=', repeat = 40):
|
def print_ribbon(s, symbol = '=', repeat = 40):
|
||||||
@@ -71,7 +151,7 @@ def load_diffusion_model(dprior_path, device):
|
|||||||
# Load state dict from saved model
|
# Load state dict from saved model
|
||||||
diffusion_prior.load_state_dict(loaded_obj['model'])
|
diffusion_prior.load_state_dict(loaded_obj['model'])
|
||||||
|
|
||||||
return diffusion_prior
|
return diffusion_prior, loaded_obj
|
||||||
|
|
||||||
def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim):
|
def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim):
|
||||||
# Saving State Dict
|
# Saving State Dict
|
||||||
@@ -90,7 +170,7 @@ class EMA(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model,
|
model,
|
||||||
beta = 0.99,
|
beta = 0.9999,
|
||||||
update_after_step = 1000,
|
update_after_step = 1000,
|
||||||
update_every = 10,
|
update_every = 10,
|
||||||
):
|
):
|
||||||
@@ -105,6 +185,10 @@ class EMA(nn.Module):
|
|||||||
self.register_buffer('initted', torch.Tensor([False]))
|
self.register_buffer('initted', torch.Tensor([False]))
|
||||||
self.register_buffer('step', torch.tensor([0.]))
|
self.register_buffer('step', torch.tensor([0.]))
|
||||||
|
|
||||||
|
def restore_ema_model_device(self):
|
||||||
|
device = self.initted.device
|
||||||
|
self.ema_model.to(device)
|
||||||
|
|
||||||
def update(self):
|
def update(self):
|
||||||
self.step += 1
|
self.step += 1
|
||||||
|
|
||||||
@@ -143,6 +227,7 @@ class DiffusionPriorTrainer(nn.Module):
|
|||||||
use_ema = True,
|
use_ema = True,
|
||||||
lr = 3e-4,
|
lr = 3e-4,
|
||||||
wd = 1e-2,
|
wd = 1e-2,
|
||||||
|
eps = 1e-6,
|
||||||
max_grad_norm = None,
|
max_grad_norm = None,
|
||||||
amp = False,
|
amp = False,
|
||||||
**kwargs
|
**kwargs
|
||||||
@@ -169,6 +254,7 @@ class DiffusionPriorTrainer(nn.Module):
|
|||||||
diffusion_prior.parameters(),
|
diffusion_prior.parameters(),
|
||||||
lr = lr,
|
lr = lr,
|
||||||
wd = wd,
|
wd = wd,
|
||||||
|
eps = eps,
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -176,6 +262,8 @@ class DiffusionPriorTrainer(nn.Module):
|
|||||||
|
|
||||||
self.max_grad_norm = max_grad_norm
|
self.max_grad_norm = max_grad_norm
|
||||||
|
|
||||||
|
self.register_buffer('step', torch.tensor([0.]))
|
||||||
|
|
||||||
def update(self):
|
def update(self):
|
||||||
if exists(self.max_grad_norm):
|
if exists(self.max_grad_norm):
|
||||||
self.scaler.unscale_(self.optimizer)
|
self.scaler.unscale_(self.optimizer)
|
||||||
@@ -188,11 +276,15 @@ class DiffusionPriorTrainer(nn.Module):
|
|||||||
if self.use_ema:
|
if self.use_ema:
|
||||||
self.ema_diffusion_prior.update()
|
self.ema_diffusion_prior.update()
|
||||||
|
|
||||||
|
self.step += 1
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
|
@cast_torch_tensor
|
||||||
def p_sample_loop(self, *args, **kwargs):
|
def p_sample_loop(self, *args, **kwargs):
|
||||||
return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
|
return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
|
@cast_torch_tensor
|
||||||
def sample(self, *args, **kwargs):
|
def sample(self, *args, **kwargs):
|
||||||
return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
|
return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
|
||||||
|
|
||||||
@@ -200,15 +292,26 @@ class DiffusionPriorTrainer(nn.Module):
|
|||||||
def sample_batch_size(self, *args, **kwargs):
|
def sample_batch_size(self, *args, **kwargs):
|
||||||
return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
|
return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
|
||||||
|
|
||||||
|
@cast_torch_tensor
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
*args,
|
*args,
|
||||||
divisor = 1,
|
max_batch_size = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
with autocast(enabled = self.amp):
|
total_loss = 0.
|
||||||
loss = self.diffusion_prior(*args, **kwargs)
|
|
||||||
return self.scaler.scale(loss / divisor)
|
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
|
||||||
|
with autocast(enabled = self.amp):
|
||||||
|
loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
|
||||||
|
loss = loss * chunk_size_frac
|
||||||
|
|
||||||
|
total_loss += loss.item()
|
||||||
|
|
||||||
|
if self.training:
|
||||||
|
self.scaler.scale(loss).backward()
|
||||||
|
|
||||||
|
return total_loss
|
||||||
|
|
||||||
# decoder trainer
|
# decoder trainer
|
||||||
|
|
||||||
@@ -217,8 +320,9 @@ class DecoderTrainer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
decoder,
|
decoder,
|
||||||
use_ema = True,
|
use_ema = True,
|
||||||
lr = 3e-4,
|
lr = 2e-5,
|
||||||
wd = 1e-2,
|
wd = 1e-2,
|
||||||
|
eps = 1e-8,
|
||||||
max_grad_norm = None,
|
max_grad_norm = None,
|
||||||
amp = False,
|
amp = False,
|
||||||
**kwargs
|
**kwargs
|
||||||
@@ -243,13 +347,14 @@ class DecoderTrainer(nn.Module):
|
|||||||
# be able to finely customize learning rate, weight decay
|
# be able to finely customize learning rate, weight decay
|
||||||
# per unet
|
# per unet
|
||||||
|
|
||||||
lr, wd = map(partial(cast_tuple, length = self.num_unets), (lr, wd))
|
lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))
|
||||||
|
|
||||||
for ind, (unet, unet_lr, unet_wd) in enumerate(zip(self.decoder.unets, lr, wd)):
|
for ind, (unet, unet_lr, unet_wd, unet_eps) in enumerate(zip(self.decoder.unets, lr, wd, eps)):
|
||||||
optimizer = get_optimizer(
|
optimizer = get_optimizer(
|
||||||
unet.parameters(),
|
unet.parameters(),
|
||||||
lr = unet_lr,
|
lr = unet_lr,
|
||||||
wd = unet_wd,
|
wd = unet_wd,
|
||||||
|
eps = unet_eps,
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -265,6 +370,8 @@ class DecoderTrainer(nn.Module):
|
|||||||
|
|
||||||
self.max_grad_norm = max_grad_norm
|
self.max_grad_norm = max_grad_norm
|
||||||
|
|
||||||
|
self.register_buffer('step', torch.tensor([0.]))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def unets(self):
|
def unets(self):
|
||||||
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
|
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
|
||||||
@@ -295,7 +402,10 @@ class DecoderTrainer(nn.Module):
|
|||||||
ema_unet = self.ema_unets[index]
|
ema_unet = self.ema_unets[index]
|
||||||
ema_unet.update()
|
ema_unet.update()
|
||||||
|
|
||||||
|
self.step += 1
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
@cast_torch_tensor
|
||||||
def sample(self, *args, **kwargs):
|
def sample(self, *args, **kwargs):
|
||||||
if self.use_ema:
|
if self.use_ema:
|
||||||
trainable_unets = self.decoder.unets
|
trainable_unets = self.decoder.unets
|
||||||
@@ -305,16 +415,31 @@ class DecoderTrainer(nn.Module):
|
|||||||
|
|
||||||
if self.use_ema:
|
if self.use_ema:
|
||||||
self.decoder.unets = trainable_unets # restore original training unets
|
self.decoder.unets = trainable_unets # restore original training unets
|
||||||
|
|
||||||
|
# cast the ema_model unets back to original device
|
||||||
|
for ema in self.ema_unets:
|
||||||
|
ema.restore_ema_model_device()
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
@cast_torch_tensor
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x,
|
*args,
|
||||||
*,
|
|
||||||
unet_number,
|
unet_number,
|
||||||
divisor = 1,
|
max_batch_size = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
with autocast(enabled = self.amp):
|
total_loss = 0.
|
||||||
loss = self.decoder(x, unet_number = unet_number, **kwargs)
|
|
||||||
return self.scale(loss / divisor, unet_number = unet_number)
|
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
|
||||||
|
with autocast(enabled = self.amp):
|
||||||
|
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
|
||||||
|
loss = loss * chunk_size_frac
|
||||||
|
|
||||||
|
total_loss += loss.item()
|
||||||
|
|
||||||
|
if self.training:
|
||||||
|
self.scale(loss, unet_number = unet_number).backward()
|
||||||
|
|
||||||
|
return total_loss
|
||||||
|
|||||||
3
setup.py
3
setup.py
@@ -10,7 +10,7 @@ setup(
|
|||||||
'dream = dalle2_pytorch.cli:dream'
|
'dream = dalle2_pytorch.cli:dream'
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
version = '0.2.9',
|
version = '0.2.35',
|
||||||
license='MIT',
|
license='MIT',
|
||||||
description = 'DALL-E 2',
|
description = 'DALL-E 2',
|
||||||
author = 'Phil Wang',
|
author = 'Phil Wang',
|
||||||
@@ -30,6 +30,7 @@ setup(
|
|||||||
'einops-exts>=0.0.3',
|
'einops-exts>=0.0.3',
|
||||||
'embedding-reader',
|
'embedding-reader',
|
||||||
'kornia>=0.5.4',
|
'kornia>=0.5.4',
|
||||||
|
'numpy',
|
||||||
'pillow',
|
'pillow',
|
||||||
'resize-right>=0.0.2',
|
'resize-right>=0.0.2',
|
||||||
'rotary-embedding-torch',
|
'rotary-embedding-torch',
|
||||||
|
|||||||
@@ -1,24 +1,42 @@
|
|||||||
import os
|
from pathlib import Path
|
||||||
|
import click
|
||||||
import math
|
import math
|
||||||
import argparse
|
import time
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from embedding_reader import EmbeddingReader
|
|
||||||
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
|
|
||||||
from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model, print_ribbon
|
|
||||||
from dalle2_pytorch.optimizer import get_optimizer
|
|
||||||
from torch.cuda.amp import autocast,GradScaler
|
|
||||||
|
|
||||||
import time
|
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
|
||||||
|
from dalle2_pytorch.train import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon
|
||||||
|
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
|
||||||
|
|
||||||
|
from embedding_reader import EmbeddingReader
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import wandb
|
# constants
|
||||||
os.environ["WANDB_SILENT"] = "true"
|
|
||||||
NUM_TEST_EMBEDDINGS = 100 # for cosine similarity reporting during training
|
NUM_TEST_EMBEDDINGS = 100 # for cosine similarity reporting during training
|
||||||
REPORT_METRICS_EVERY = 100 # for cosine similarity and other metric reporting during training
|
REPORT_METRICS_EVERY = 100 # for cosine similarity and other metric reporting during training
|
||||||
|
|
||||||
|
tracker = WandbTracker()
|
||||||
|
|
||||||
|
# helpers functions
|
||||||
|
|
||||||
|
def exists(val):
|
||||||
|
val is not None
|
||||||
|
|
||||||
|
class Timer:
|
||||||
|
def __init__(self):
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.last_time = time.time()
|
||||||
|
|
||||||
|
def elapsed(self):
|
||||||
|
return time.time() - self.last_time
|
||||||
|
# functions
|
||||||
|
|
||||||
def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_type,phase="Validation"):
|
def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_type,phase="Validation"):
|
||||||
model.eval()
|
model.eval()
|
||||||
@@ -40,7 +58,7 @@ def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_t
|
|||||||
total_samples += batches
|
total_samples += batches
|
||||||
|
|
||||||
avg_loss = (total_loss / total_samples)
|
avg_loss = (total_loss / total_samples)
|
||||||
wandb.log({f'{phase} {loss_type}': avg_loss})
|
tracker.log({f'{phase} {loss_type}': avg_loss})
|
||||||
|
|
||||||
def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,NUM_TEST_EMBEDDINGS,device):
|
def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,NUM_TEST_EMBEDDINGS,device):
|
||||||
diffusion_prior.eval()
|
diffusion_prior.eval()
|
||||||
@@ -87,85 +105,177 @@ def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,N
|
|||||||
text_embed, predicted_unrelated_embeddings).cpu().numpy()
|
text_embed, predicted_unrelated_embeddings).cpu().numpy()
|
||||||
predicted_img_similarity = cos(
|
predicted_img_similarity = cos(
|
||||||
test_image_embeddings, predicted_image_embeddings).cpu().numpy()
|
test_image_embeddings, predicted_image_embeddings).cpu().numpy()
|
||||||
wandb.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity),
|
tracker.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity),
|
||||||
"CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity),
|
"CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity),
|
||||||
"CosineSimilarity(orig_image_embed,predicted_image_embed)":np.mean(predicted_img_similarity),
|
"CosineSimilarity(orig_image_embed,predicted_image_embed)":np.mean(predicted_img_similarity),
|
||||||
"CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(unrelated_similarity),
|
"CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(unrelated_similarity),
|
||||||
"Cosine similarity difference":np.mean(predicted_similarity - original_similarity)})
|
"Cosine similarity difference":np.mean(predicted_similarity - original_similarity)})
|
||||||
|
|
||||||
def train(image_embed_dim,
|
@click.command()
|
||||||
image_embed_url,
|
@click.option("--wandb-entity", default="laion")
|
||||||
text_embed_url,
|
@click.option("--wandb-project", default="diffusion-prior")
|
||||||
batch_size,
|
@click.option("--wandb-dataset", default="LAION-5B")
|
||||||
train_percent,
|
@click.option("--wandb-arch", default="DiffusionPrior")
|
||||||
val_percent,
|
@click.option("--image-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
|
||||||
test_percent,
|
@click.option("--text-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
|
||||||
num_epochs,
|
@click.option("--learning-rate", default=1.1e-4)
|
||||||
dp_loss_type,
|
@click.option("--weight-decay", default=6.02e-2)
|
||||||
clip,
|
@click.option("--dropout", default=5e-2)
|
||||||
dp_condition_on_text_encodings,
|
@click.option("--max-grad-norm", default=0.5)
|
||||||
dp_timesteps,
|
@click.option("--batch-size", default=10**4)
|
||||||
dp_normformer,
|
@click.option("--num-epochs", default=5)
|
||||||
dp_cond_drop_prob,
|
@click.option("--image-embed-dim", default=768)
|
||||||
dpn_depth,
|
@click.option("--train-percent", default=0.7)
|
||||||
dpn_dim_head,
|
@click.option("--val-percent", default=0.2)
|
||||||
dpn_heads,
|
@click.option("--test-percent", default=0.1)
|
||||||
save_interval,
|
@click.option("--dpn-depth", default=6)
|
||||||
save_path,
|
@click.option("--dpn-dim-head", default=64)
|
||||||
device,
|
@click.option("--dpn-heads", default=8)
|
||||||
RESUME,
|
@click.option("--dp-condition-on-text-encodings", default=False)
|
||||||
DPRIOR_PATH,
|
@click.option("--dp-timesteps", default=100)
|
||||||
config,
|
@click.option("--dp-normformer", default=False)
|
||||||
wandb_entity,
|
@click.option("--dp-cond-drop-prob", default=0.1)
|
||||||
wandb_project,
|
@click.option("--dp-loss-type", default="l2")
|
||||||
learning_rate=0.001,
|
@click.option("--clip", default=None)
|
||||||
max_grad_norm=0.5,
|
@click.option("--amp", default=False)
|
||||||
weight_decay=0.01,
|
@click.option("--save-interval", default=30)
|
||||||
dropout=0.05,
|
@click.option("--save-path", default="./diffusion_prior_checkpoints")
|
||||||
amp=False):
|
@click.option("--pretrained-model-path", default=None)
|
||||||
|
def train(
|
||||||
|
wandb_entity,
|
||||||
|
wandb_project,
|
||||||
|
wandb_dataset,
|
||||||
|
wandb_arch,
|
||||||
|
image_embed_url,
|
||||||
|
text_embed_url,
|
||||||
|
learning_rate,
|
||||||
|
weight_decay,
|
||||||
|
dropout,
|
||||||
|
max_grad_norm,
|
||||||
|
batch_size,
|
||||||
|
num_epochs,
|
||||||
|
image_embed_dim,
|
||||||
|
train_percent,
|
||||||
|
val_percent,
|
||||||
|
test_percent,
|
||||||
|
dpn_depth,
|
||||||
|
dpn_dim_head,
|
||||||
|
dpn_heads,
|
||||||
|
dp_condition_on_text_encodings,
|
||||||
|
dp_timesteps,
|
||||||
|
dp_normformer,
|
||||||
|
dp_cond_drop_prob,
|
||||||
|
dp_loss_type,
|
||||||
|
clip,
|
||||||
|
amp,
|
||||||
|
save_interval,
|
||||||
|
save_path,
|
||||||
|
pretrained_model_path
|
||||||
|
):
|
||||||
|
config = {
|
||||||
|
"learning_rate": learning_rate,
|
||||||
|
"architecture": wandb_arch,
|
||||||
|
"dataset": wandb_dataset,
|
||||||
|
"weight_decay": weight_decay,
|
||||||
|
"max_gradient_clipping_norm": max_grad_norm,
|
||||||
|
"batch_size": batch_size,
|
||||||
|
"epochs": num_epochs,
|
||||||
|
"diffusion_prior_network": {
|
||||||
|
"depth": dpn_depth,
|
||||||
|
"dim_head": dpn_dim_head,
|
||||||
|
"heads": dpn_heads,
|
||||||
|
"normformer": dp_normformer
|
||||||
|
},
|
||||||
|
"diffusion_prior": {
|
||||||
|
"condition_on_text_encodings": dp_condition_on_text_encodings,
|
||||||
|
"timesteps": dp_timesteps,
|
||||||
|
"cond_drop_prob": dp_cond_drop_prob,
|
||||||
|
"loss_type": dp_loss_type,
|
||||||
|
"clip": clip
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check if DPRIOR_PATH exists(saved model path)
|
||||||
|
|
||||||
|
DPRIOR_PATH = args.pretrained_model_path
|
||||||
|
RESUME = exists(DPRIOR_PATH)
|
||||||
|
|
||||||
|
if not RESUME:
|
||||||
|
tracker.init(
|
||||||
|
entity = wandb_entity,
|
||||||
|
project = wandb_project,
|
||||||
|
config = config
|
||||||
|
)
|
||||||
|
|
||||||
|
# Obtain the utilized device.
|
||||||
|
|
||||||
|
has_cuda = torch.cuda.is_available()
|
||||||
|
if has_cuda:
|
||||||
|
device = torch.device("cuda:0")
|
||||||
|
torch.cuda.set_device(device)
|
||||||
|
|
||||||
|
# Training loop
|
||||||
|
# diffusion prior network
|
||||||
|
|
||||||
# DiffusionPriorNetwork
|
|
||||||
prior_network = DiffusionPriorNetwork(
|
prior_network = DiffusionPriorNetwork(
|
||||||
dim = image_embed_dim,
|
dim = image_embed_dim,
|
||||||
depth = dpn_depth,
|
depth = dpn_depth,
|
||||||
dim_head = dpn_dim_head,
|
dim_head = dpn_dim_head,
|
||||||
heads = dpn_heads,
|
heads = dpn_heads,
|
||||||
attn_dropout = dropout,
|
attn_dropout = dropout,
|
||||||
ff_dropout = dropout,
|
ff_dropout = dropout,
|
||||||
normformer = dp_normformer).to(device)
|
normformer = dp_normformer
|
||||||
|
)
|
||||||
|
|
||||||
# DiffusionPrior with text embeddings and image embeddings pre-computed
|
# diffusion prior with text embeddings and image embeddings pre-computed
|
||||||
|
|
||||||
diffusion_prior = DiffusionPrior(
|
diffusion_prior = DiffusionPrior(
|
||||||
net = prior_network,
|
net = prior_network,
|
||||||
clip = clip,
|
clip = clip,
|
||||||
image_embed_dim = image_embed_dim,
|
image_embed_dim = image_embed_dim,
|
||||||
timesteps = dp_timesteps,
|
timesteps = dp_timesteps,
|
||||||
cond_drop_prob = dp_cond_drop_prob,
|
cond_drop_prob = dp_cond_drop_prob,
|
||||||
loss_type = dp_loss_type,
|
loss_type = dp_loss_type,
|
||||||
condition_on_text_encodings = dp_condition_on_text_encodings).to(device)
|
condition_on_text_encodings = dp_condition_on_text_encodings
|
||||||
|
)
|
||||||
|
|
||||||
# Load pre-trained model from DPRIOR_PATH
|
# Load pre-trained model from DPRIOR_PATH
|
||||||
|
|
||||||
if RESUME:
|
if RESUME:
|
||||||
diffusion_prior=load_diffusion_model(DPRIOR_PATH,device)
|
diffusion_prior, loaded_obj = load_diffusion_model(DPRIOR_PATH, device)
|
||||||
wandb.init( entity=wandb_entity, project=wandb_project, config=config)
|
tracker.init(entity = wandb_entity, project = wandb_project, config = config)
|
||||||
|
|
||||||
|
# diffusion prior trainer
|
||||||
|
|
||||||
|
trainer = DiffusionPriorTrainer(
|
||||||
|
diffusion_prior = diffusion_prior,
|
||||||
|
lr = learning_rate,
|
||||||
|
wd = weight_decay,
|
||||||
|
max_grad_norm = max_grad_norm,
|
||||||
|
amp = amp,
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
# load optimizer and scaler
|
||||||
|
|
||||||
|
if RESUME:
|
||||||
|
trainer.optimizer.load_state_dict(loaded_obj['optimizer'])
|
||||||
|
trainer.scaler.load_state_dict(loaded_obj['scaler'])
|
||||||
|
|
||||||
# Create save_path if it doesn't exist
|
# Create save_path if it doesn't exist
|
||||||
if not os.path.exists(save_path):
|
|
||||||
os.makedirs(save_path)
|
Path(save_path).mkdir(exist_ok = True, parents = True)
|
||||||
|
|
||||||
# Get image and text embeddings from the servers
|
# Get image and text embeddings from the servers
|
||||||
|
|
||||||
print_ribbon("Downloading embeddings - image and text")
|
print_ribbon("Downloading embeddings - image and text")
|
||||||
image_reader = EmbeddingReader(embeddings_folder=image_embed_url, file_format="npy")
|
image_reader = EmbeddingReader(embeddings_folder=image_embed_url, file_format="npy")
|
||||||
text_reader = EmbeddingReader(embeddings_folder=text_embed_url, file_format="npy")
|
text_reader = EmbeddingReader(embeddings_folder=text_embed_url, file_format="npy")
|
||||||
num_data_points = text_reader.count
|
num_data_points = text_reader.count
|
||||||
|
|
||||||
### Training code ###
|
### Training code ###
|
||||||
scaler = GradScaler(enabled=amp)
|
|
||||||
optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
|
|
||||||
epochs = num_epochs
|
|
||||||
|
|
||||||
step = 0
|
timer = Timer()
|
||||||
t = time.time()
|
epochs = num_epochs
|
||||||
|
|
||||||
train_set_size = int(train_percent*num_data_points)
|
train_set_size = int(train_percent*num_data_points)
|
||||||
val_set_size = int(val_percent*num_data_points)
|
val_set_size = int(val_percent*num_data_points)
|
||||||
@@ -176,32 +286,31 @@ def train(image_embed_dim,
|
|||||||
for emb_images,emb_text in zip(image_reader(batch_size=batch_size, start=0, end=train_set_size),
|
for emb_images,emb_text in zip(image_reader(batch_size=batch_size, start=0, end=train_set_size),
|
||||||
text_reader(batch_size=batch_size, start=0, end=train_set_size)):
|
text_reader(batch_size=batch_size, start=0, end=train_set_size)):
|
||||||
|
|
||||||
diffusion_prior.train()
|
trainer.train()
|
||||||
|
|
||||||
emb_images_tensor = torch.tensor(emb_images[0]).to(device)
|
emb_images_tensor = torch.tensor(emb_images[0]).to(device)
|
||||||
emb_text_tensor = torch.tensor(emb_text[0]).to(device)
|
emb_text_tensor = torch.tensor(emb_text[0]).to(device)
|
||||||
|
|
||||||
with autocast(enabled=amp):
|
loss = trainer(text_embed = emb_text_tensor, image_embed = emb_images_tensor)
|
||||||
loss = diffusion_prior(text_embed = emb_text_tensor,image_embed = emb_images_tensor)
|
|
||||||
scaler.scale(loss).backward()
|
|
||||||
|
|
||||||
# Samples per second
|
# Samples per second
|
||||||
step+=1
|
|
||||||
samples_per_sec = batch_size*step/(time.time()-t)
|
samples_per_sec = batch_size * step / timer.elapsed()
|
||||||
|
|
||||||
# Save checkpoint every save_interval minutes
|
# Save checkpoint every save_interval minutes
|
||||||
if(int(time.time()-t) >= 60*save_interval):
|
if(int(timer.elapsed()) >= 60 * save_interval):
|
||||||
t = time.time()
|
timer.reset()
|
||||||
|
|
||||||
save_diffusion_model(
|
save_diffusion_model(
|
||||||
save_path,
|
save_path,
|
||||||
diffusion_prior,
|
diffusion_prior,
|
||||||
optimizer,
|
trainer.optimizer,
|
||||||
scaler,
|
trainer.scaler,
|
||||||
config,
|
config,
|
||||||
image_embed_dim)
|
image_embed_dim)
|
||||||
|
|
||||||
# Log to wandb
|
# Log to wandb
|
||||||
wandb.log({"Training loss": loss.item(),
|
tracker.log({"Training loss": loss.item(),
|
||||||
"Steps": step,
|
"Steps": step,
|
||||||
"Samples per second": samples_per_sec})
|
"Samples per second": samples_per_sec})
|
||||||
# Log cosineSim(text_embed,predicted_image_embed) - cosineSim(text_embed,image_embed)
|
# Log cosineSim(text_embed,predicted_image_embed) - cosineSim(text_embed,image_embed)
|
||||||
@@ -225,130 +334,13 @@ def train(image_embed_dim,
|
|||||||
dp_loss_type,
|
dp_loss_type,
|
||||||
phase="Validation")
|
phase="Validation")
|
||||||
|
|
||||||
scaler.unscale_(optimizer)
|
trainer.update()
|
||||||
nn.utils.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm)
|
|
||||||
|
|
||||||
scaler.step(optimizer)
|
|
||||||
scaler.update()
|
|
||||||
optimizer.zero_grad()
|
|
||||||
|
|
||||||
### Test run ###
|
### Test run ###
|
||||||
test_set_size = int(test_percent*train_set_size)
|
test_set_size = int(test_percent*train_set_size)
|
||||||
start=train_set_size+val_set_size
|
start = train_set_size+val_set_size
|
||||||
end=num_data_points
|
end = num_data_points
|
||||||
eval_model(diffusion_prior,device,image_reader,text_reader,start,end,batch_size,dp_loss_type,phase="Test")
|
eval_model(diffusion_prior,device,image_reader,text_reader,start,end,batch_size,dp_loss_type,phase="Test")
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
# Logging
|
|
||||||
parser.add_argument("--wandb-entity", type=str, default="laion")
|
|
||||||
parser.add_argument("--wandb-project", type=str, default="diffusion-prior")
|
|
||||||
parser.add_argument("--wandb-dataset", type=str, default="LAION-5B")
|
|
||||||
parser.add_argument("--wandb-arch", type=str, default="DiffusionPrior")
|
|
||||||
# URLs for embeddings
|
|
||||||
parser.add_argument("--image-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
|
|
||||||
parser.add_argument("--text-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
|
|
||||||
# Hyperparameters
|
|
||||||
parser.add_argument("--learning-rate", type=float, default=1.1e-4)
|
|
||||||
parser.add_argument("--weight-decay", type=float, default=6.02e-2)
|
|
||||||
parser.add_argument("--dropout", type=float, default=5e-2)
|
|
||||||
parser.add_argument("--max-grad-norm", type=float, default=0.5)
|
|
||||||
parser.add_argument("--batch-size", type=int, default=10**4)
|
|
||||||
parser.add_argument("--num-epochs", type=int, default=5)
|
|
||||||
# Image embed dimension
|
|
||||||
parser.add_argument("--image-embed-dim", type=int, default=768)
|
|
||||||
# Train-test split
|
|
||||||
parser.add_argument("--train-percent", type=float, default=0.7)
|
|
||||||
parser.add_argument("--val-percent", type=float, default=0.2)
|
|
||||||
parser.add_argument("--test-percent", type=float, default=0.1)
|
|
||||||
# LAION training(pre-computed embeddings)
|
|
||||||
# DiffusionPriorNetwork(dpn) parameters
|
|
||||||
parser.add_argument("--dpn-depth", type=int, default=6)
|
|
||||||
parser.add_argument("--dpn-dim-head", type=int, default=64)
|
|
||||||
parser.add_argument("--dpn-heads", type=int, default=8)
|
|
||||||
# DiffusionPrior(dp) parameters
|
|
||||||
parser.add_argument("--dp-condition-on-text-encodings", type=bool, default=False)
|
|
||||||
parser.add_argument("--dp-timesteps", type=int, default=100)
|
|
||||||
parser.add_argument("--dp-normformer", type=bool, default=False)
|
|
||||||
parser.add_argument("--dp-cond-drop-prob", type=float, default=0.1)
|
|
||||||
parser.add_argument("--dp-loss-type", type=str, default="l2")
|
|
||||||
parser.add_argument("--clip", type=str, default=None)
|
|
||||||
parser.add_argument("--amp", type=bool, default=False)
|
|
||||||
# Model checkpointing interval(minutes)
|
|
||||||
parser.add_argument("--save-interval", type=int, default=30)
|
|
||||||
parser.add_argument("--save-path", type=str, default="./diffusion_prior_checkpoints")
|
|
||||||
# Saved model path
|
|
||||||
parser.add_argument("--pretrained-model-path", type=str, default=None)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
config = ({"learning_rate": args.learning_rate,
|
|
||||||
"architecture": args.wandb_arch,
|
|
||||||
"dataset": args.wandb_dataset,
|
|
||||||
"weight_decay":args.weight_decay,
|
|
||||||
"max_gradient_clipping_norm":args.max_grad_norm,
|
|
||||||
"batch_size":args.batch_size,
|
|
||||||
"epochs": args.num_epochs,
|
|
||||||
"diffusion_prior_network":{"depth":args.dpn_depth,
|
|
||||||
"dim_head":args.dpn_dim_head,
|
|
||||||
"heads":args.dpn_heads,
|
|
||||||
"normformer":args.dp_normformer},
|
|
||||||
"diffusion_prior":{"condition_on_text_encodings": args.dp_condition_on_text_encodings,
|
|
||||||
"timesteps": args.dp_timesteps,
|
|
||||||
"cond_drop_prob":args.dp_cond_drop_prob,
|
|
||||||
"loss_type":args.dp_loss_type,
|
|
||||||
"clip":args.clip}
|
|
||||||
})
|
|
||||||
|
|
||||||
RESUME = False
|
|
||||||
# Check if DPRIOR_PATH exists(saved model path)
|
|
||||||
DPRIOR_PATH = args.pretrained_model_path
|
|
||||||
if(DPRIOR_PATH is not None):
|
|
||||||
RESUME = True
|
|
||||||
else:
|
|
||||||
wandb.init(
|
|
||||||
entity=args.wandb_entity,
|
|
||||||
project=args.wandb_project,
|
|
||||||
config=config)
|
|
||||||
|
|
||||||
# Obtain the utilized device.
|
|
||||||
|
|
||||||
has_cuda = torch.cuda.is_available()
|
|
||||||
if has_cuda:
|
|
||||||
device = torch.device("cuda:0")
|
|
||||||
torch.cuda.set_device(device)
|
|
||||||
|
|
||||||
# Training loop
|
|
||||||
train(args.image_embed_dim,
|
|
||||||
args.image_embed_url,
|
|
||||||
args.text_embed_url,
|
|
||||||
args.batch_size,
|
|
||||||
args.train_percent,
|
|
||||||
args.val_percent,
|
|
||||||
args.test_percent,
|
|
||||||
args.num_epochs,
|
|
||||||
args.dp_loss_type,
|
|
||||||
args.clip,
|
|
||||||
args.dp_condition_on_text_encodings,
|
|
||||||
args.dp_timesteps,
|
|
||||||
args.dp_normformer,
|
|
||||||
args.dp_cond_drop_prob,
|
|
||||||
args.dpn_depth,
|
|
||||||
args.dpn_dim_head,
|
|
||||||
args.dpn_heads,
|
|
||||||
args.save_interval,
|
|
||||||
args.save_path,
|
|
||||||
device,
|
|
||||||
RESUME,
|
|
||||||
DPRIOR_PATH,
|
|
||||||
config,
|
|
||||||
args.wandb_entity,
|
|
||||||
args.wandb_project,
|
|
||||||
args.learning_rate,
|
|
||||||
args.max_grad_norm,
|
|
||||||
args.weight_decay,
|
|
||||||
args.dropout,
|
|
||||||
args.amp)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
train()
|
||||||
|
|||||||
Reference in New Issue
Block a user